You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2019/01/10 23:13:03 UTC

[incubator-mxnet] branch master updated: [MXNET-1260] Float64 DType computation support in Scala/Java (#13678)

This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ed7ca26  [MXNET-1260] Float64 DType computation support in Scala/Java (#13678)
ed7ca26 is described below

commit ed7ca26a23881f9b7474fbcb12a576c2b544bee6
Author: Piyush Ghai <gh...@osu.edu>
AuthorDate: Thu Jan 10 15:12:43 2019 -0800

    [MXNET-1260] Float64 DType computation support in Scala/Java (#13678)
    
    * Added Float64 as a supported datatype in NDArray
    
    * Added unit tests for Float64 in NDArray
    
    * Fix for failing Clojure unit tests
    
    * Added Float and Double as MX_PRIMITIVES for computation in Scala
    
    * Trying out second approach --> Private Impl methods with generic signature, and public methods calling the Impls
    
    * Fixed errors in *= method
    
    * Added Float64 in IO.scala and DataIter.scala
    
    * Added another testcase for IO.DataDesc creation
    
    * Fixed failing CI
    
    * Added Float64 in Predictor class
    
    * Added Float64 in Classifier class
    
    * Added Double as a possible return type to classifyWithNDArray
    
    * Added unit tests for Classifier and Predictor.scala classes for Float64/Double
    
    * Approach 3 --> Using a trait to mirror Float and Double in Scala
    
    * Added comments on MX_PRIMITIVES.scala
    
    * Added Float64/Double support for inference in ImageClassifier APIs
    
    * Added unary- and compareTo in MX_NUMBER_LIKE
    
    * Renamed MX_NUMBER_LIKE to MX_PRIMITIVE_TYPE
    
    * Fixed linting issue
    
    * Now specifying dType from the available data in copyTo and MXDataIter.scala for creating a new DataIterator
    
    * Add primitives support handling to the generator for proper conversion
    
    * Reduced code duplication in classify method in Classifier.scala
    
    * Fix infer package for new signatures and address some bugs
    
    * Removed code duplication in getPixelsArray
    
    * remove debugging
    
    * Changed classifyWithNDArray method in Classifier.scala
    
    * Removed code duplication in predictImpl
    
    * Satisfying lint god _/\_
    
    * Fixed failing PredictorSuite test
    
    * Renamed MX_FLOAT to Camel case
    
    * Revert "Renamed MX_FLOAT to Camel case"
    
    This reverts commit 9d7c3ce6f9c4d6ed2c46041a02e23c0f1df8dfe5.
    
    * Added an implicit conversion from Int to Float to support Int operations on NDArrays (these operations were already supported in previous versions)
    
    * Added Float64 as a training option to ImClassification Suite. Also added integration tests for it
    
    * Satisfy Lint God _/\_
    
    * Added Float64 support in Java NDArray
    
    * Added Float64 support in Java's Predictor API
    
    * Added yours truly to the Contributors list
    
    * Added method comments on Predictor.predict with Array[Double] as a possible input
    
    * Added method comments explaining what MX_PRIMITIVE_TYPE is
    
    * Fixed errors caused by rebasing with master
    
    * Added Apache license headers to the new files
---
 CONTRIBUTORS.md                                    |   1 +
 .../src/org/apache/clojure_mxnet/infer.clj         | 242 ++++++-------
 .../src/org/apache/clojure_mxnet/primitives.clj    |  46 +++
 .../src/org/apache/clojure_mxnet/util.clj          |   7 +-
 contrib/clojure-package/test/good-test-ndarray.clj |   7 +-
 .../clojure_mxnet/infer/imageclassifier_test.clj   |  12 +-
 .../clojure_mxnet/infer/objectdetector_test.clj    |   4 +
 .../test/org/apache/clojure_mxnet/ndarray_test.clj |   2 +-
 .../org/apache/clojure_mxnet/primitives_test.clj   |  45 +++
 .../test/org/apache/clojure_mxnet/util_test.clj    |  10 +
 .../src/main/scala/org/apache/mxnet/Base.scala     |   7 +-
 .../src/main/scala/org/apache/mxnet/LibInfo.scala  |   3 +
 .../scala/org/apache/mxnet/MX_PRIMITIVES.scala     |  85 +++++
 .../src/main/scala/org/apache/mxnet/NDArray.scala  | 230 +++++++++---
 .../scala/org/apache/mxnet/io/MXDataIter.scala     |   6 +-
 .../scala/org/apache/mxnet/io/NDArrayIter.scala    |   7 +-
 .../scala/org/apache/mxnet/javaapi/NDArray.scala   |  65 ++++
 .../java/org/apache/mxnet/javaapi/NDArrayTest.java |  15 +
 .../src/test/scala/org/apache/mxnet/IOSuite.scala  |  27 ++
 .../test/scala/org/apache/mxnet/NDArraySuite.scala | 396 ++++++++++++++++++---
 .../imclassification/TrainModel.scala              |  24 +-
 .../datasets/SyntheticDataIter.scala               |   8 +-
 .../imclassification/models/Lenet.scala            |   5 +-
 .../models/MultiLayerPerceptron.scala              |   5 +-
 .../imclassification/models/Resnet.scala           |  16 +-
 .../IMClassificationExampleSuite.scala             |  10 +-
 .../scala/org/apache/mxnet/infer/Classifier.scala  |  39 +-
 .../org/apache/mxnet/infer/ImageClassifier.scala   |  48 ++-
 .../scala/org/apache/mxnet/infer/Predictor.scala   |  46 ++-
 .../org/apache/mxnet/infer/javaapi/Predictor.scala |  24 ++
 .../org/apache/mxnet/infer/ClassifierSuite.scala   |  47 ++-
 .../apache/mxnet/infer/ImageClassifierSuite.scala  |   7 +
 .../org/apache/mxnet/infer/PredictorSuite.scala    |  32 +-
 .../main/native/org_apache_mxnet_native_c_api.cc   |   9 +
 .../main/native/org_apache_mxnet_native_c_api.h    |   8 +
 35 files changed, 1251 insertions(+), 294 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index b9f84d5..5b5fdce 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -193,6 +193,7 @@ List of Contributors
 * [Yuxi Hu](https://github.com/yuxihu)
 * [Harsh Patel](https://github.com/harshp8l)
 * [Xiao Wang](https://github.com/BeyonderXX)
+* [Piyush Ghai](https://github.com/piyushghai)
 
 Label Bot
 ---------
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index b2b23da..224a392 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -18,6 +18,7 @@
 (ns org.apache.clojure-mxnet.infer
   (:refer-clojure :exclude [type])
   (:require [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.shape :as shape]
             [org.apache.clojure-mxnet.util :as util]
@@ -62,10 +63,12 @@
 (defprotocol AImageClassifier
   (classify-image
     [wrapped-image-classifier image]
-    [wrapped-image-classifier image topk])
+    [wrapped-image-classifier image topk]
+    [wrapped-image-classifier image topk dtype])
   (classify-image-batch
     [wrapped-image-classifier images]
-    [wrapped-image-classifier images topk]))
+    [wrapped-image-classifier images topk]
+    [wrapped-image-classifier images topk dtype]))
 
 (defprotocol AObjectDetector
   (detect-objects
@@ -80,7 +83,8 @@
 
 (extend-protocol APredictor
   WrappedPredictor
-  (predict [wrapped-predictor inputs]
+  (predict
+    [wrapped-predictor inputs]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
     (util/validate! ::vec-of-float-arrays inputs
@@ -101,62 +105,50 @@
 
 (extend-protocol AClassifier
   WrappedClassifier
-  (classify [wrapped-classifier inputs]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (classify wrapped-classifier inputs nil))
-  (classify [wrapped-classifier inputs topk]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classify (:classifier wrapped-classifier)
-                (util/vec->indexed-seq inputs)
-                (util/->int-option topk))))
-  (classify-with-ndarray [wrapped-classifier inputs]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (classify-with-ndarray wrapped-classifier inputs nil))
-  (classify-with-ndarray [wrapped-classifier inputs topk]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyWithNDArray (:classifier wrapped-classifier)
-                           (util/vec->indexed-seq inputs)
-                           (util/->int-option topk))))
+  (classify
+    ([wrapped-classifier inputs]
+     (classify wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:classifier wrapped-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-classifier inputs]
+     (classify-with-ndarray wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-ndarrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyWithNDArray (:classifier wrapped-classifier)
+                            (util/vec->indexed-seq inputs)
+                           (util/->int-option topk)))))
   WrappedImageClassifier
-  (classify [wrapped-image-classifier inputs]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (classify wrapped-image-classifier inputs nil))
-  (classify [wrapped-image-classifier inputs topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classify (:image-classifier wrapped-image-classifier)
-                (util/vec->indexed-seq inputs)
-                (util/->int-option topk))))
-  (classify-with-ndarray [wrapped-image-classifier inputs]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (classify-with-ndarray wrapped-image-classifier inputs nil))
-  (classify-with-ndarray [wrapped-image-classifier inputs topk]
+  (classify
+    ([wrapped-image-classifier inputs]
+     (classify wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:image-classifier wrapped-image-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-image-classifier inputs]
+     (classify-with-ndarray wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                     "Invalid classifier")
     (util/validate! ::vec-of-ndarrays inputs
@@ -165,83 +157,83 @@
     (util/coerce-return-recursive
      (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
                            (util/vec->indexed-seq inputs)
-                           (util/->int-option topk)))))
+                           (util/->int-option topk))))))
 
 (s/def ::image #(instance? BufferedImage %))
+(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
 
 (extend-protocol AImageClassifier
   WrappedImageClassifier
-  (classify-image [wrapped-image-classifier image]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::image image "Invalid image")
-    (classify-image wrapped-image-classifier image nil))
-  (classify-image [wrapped-image-classifier image topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::image image "Invalid image")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyImage (:image-classifier wrapped-image-classifier)
-                     image
-                     (util/->int-option topk))))
-  (classify-image-batch [wrapped-image-classifier images]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (classify-image-batch wrapped-image-classifier images nil))
-  (classify-image-batch [wrapped-image-classifier images topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                          images
-                          (util/->int-option topk)))))
+  (classify-image
+    ([wrapped-image-classifier image]
+     (classify-image wrapped-image-classifier image nil dtype/FLOAT32))
+    ([wrapped-image-classifier image topk]
+     (classify-image wrapped-image-classifier image topk dtype/FLOAT32))
+    ([wrapped-image-classifier image topk dtype]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/validate! ::dtype dtype "Invalid dtype")
+     (util/coerce-return-recursive
+      (.classifyImage (:image-classifier wrapped-image-classifier)
+                      image
+                      (util/->int-option topk)
+                      dtype))))
+  (classify-image-batch
+    ([wrapped-image-classifier images]
+     (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
+    ([wrapped-image-classifier images topk]
+         (classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
+    ([wrapped-image-classifier images topk dtype]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/validate! ::dtype dtype "Invalid dtype")
+     (util/coerce-return-recursive
+      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                           images
+                           (util/->int-option topk)
+                           dtype)))))
 
 (extend-protocol AObjectDetector
   WrappedObjectDetector
-  (detect-objects [wrapped-detector image]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::image image "Invalid image")
-    (detect-objects wrapped-detector image nil))
-  (detect-objects [wrapped-detector image topk]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::image image "Invalid image")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.imageObjectDetect (:object-detector wrapped-detector)
-                         image
-                         (util/->int-option topk))))
-  (detect-objects-batch [wrapped-detector images]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (detect-objects-batch wrapped-detector images nil))
-  (detect-objects-batch [wrapped-detector images topk]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                              images
-                              (util/->int-option topk))))
-  (detect-objects-with-ndarrays [wrapped-detector input-arrays]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::vec-of-ndarrays input-arrays
-                    "Invalid inputs")
-    (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
-  (detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
+  (detect-objects
+    ([wrapped-detector image]
+     (detect-objects wrapped-detector image nil))
+    ([wrapped-detector image topk]
     (util/validate! ::wrapped-detector wrapped-detector
                     "Invalid object detector")
-    (util/validate! ::vec-of-ndarrays input-arrays
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.objectDetectWithNDArray (:object-detector wrapped-detector)
-                               (util/vec->indexed-seq input-arrays)
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageObjectDetect (:object-detector wrapped-detector)
+                          image
+                          (util/->int-option topk)))))
+  (detect-objects-batch
+    ([wrapped-detector images]
+     (detect-objects-batch wrapped-detector images nil))
+    ([wrapped-detector images topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                               images
                                (util/->int-option topk)))))
+  (detect-objects-with-ndarrays
+    ([wrapped-detector input-arrays]
+     (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
+    ([wrapped-detector input-arrays topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::vec-of-ndarrays input-arrays
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                (util/vec->indexed-seq input-arrays)
+                                (util/->int-option topk))))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
@@ -335,7 +327,7 @@
   [image input-shape-vec]
   (util/validate! ::image image "Invalid image")
   (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
-  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
+  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
 
 (s/def ::image-path string?)
 (s/def ::image-paths (s/coll-of ::image-path))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
new file mode 100644
index 0000000..0967df2
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
@@ -0,0 +1,46 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives
+  (:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
+                             MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))
+
+
+;;; Defines customer mx primitives that can be used for mathematical computations
+;;; in NDArrays to control precision. Currently Float and Double are supported
+
+;;; For purposes of automatic conversion in ndarray functions, doubles are default
+;; to specify using floats you must use a Float
+
+(defn mx-float
+  "Creates a MXNet float primitive"
+  [num]
+  (new MX_PRIMITIVES$MX_FLOAT num))
+
+(defn mx-double
+  "Creates a MXNet double primitive"
+  [num]
+  (new MX_PRIMITIVES$MX_Double num))
+
+(defn ->num
+  "Returns the underlying number value"
+  [primitive]
+  (.data primitive))
+
+(defn primitive? [x]
+  (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))
+
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 21e31ba..43970c0 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -19,6 +19,7 @@
   (:require [clojure.spec.alpha :as s]
             [t6.from-scala.core :refer [$ $$] :as $]
             [clojure.string :as string]
+            [org.apache.clojure-mxnet.primitives :as primitives]
             [org.apache.clojure-mxnet.shape :as mx-shape])
   (:import (org.apache.mxnet NDArray)
            (scala Product Tuple2 Tuple3)
@@ -36,7 +37,8 @@
                            "byte<>" "byte-array"
                            "java.lang.String<>" "vec-or-strings"
                            "org.apache.mxnet.NDArray" "ndarray"
-                           "org.apache.mxnet.Symbol" "sym"})
+                           "org.apache.mxnet.Symbol" "sym"
+                           "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
 
 (def symbol-param-coerce {"java.lang.String" "sym-name"
                           "float" "num"
@@ -144,6 +146,8 @@
     (and (get targets "int<>") (vector? param)) (int-array param)
     (and (get targets "float<>") (vector? param)) (float-array param)
     (and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+    (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
+    (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
     :else param))
 
 (defn nil-or-coerce-param [param targets]
@@ -177,6 +181,7 @@
     (instance? Map return-val) (scala-map->map return-val)
     (instance? Tuple2 return-val) (tuple->vec return-val)
     (instance? Tuple3 return-val) (tuple->vec return-val)
+    (primitives/primitive? return-val) (primitives/->num return-val)
     :else return-val))
 
 (defn coerce-return-recursive [return-val]
diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj
index 3b53b19..b048a81 100644
--- a/contrib/clojure-package/test/good-test-ndarray.clj
+++ b/contrib/clojure-package/test/good-test-ndarray.clj
@@ -27,11 +27,12 @@
 
 (defn
  div
- ([ndarray num-or-ndarray]
+ ([ndarray ndarray-or-double-or-float]
   (util/coerce-return
    (.$div
     ndarray
     (util/coerce-param
-     num-or-ndarray
-     #{"float" "org.apache.mxnet.NDArray"})))))
+     ndarray-or-double-or-float
+     #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
+       "org.apache.mxnet.NDArray"})))))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 9badfed..b459b06 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -40,7 +40,11 @@
 (deftest test-single-classification
   (let [classifier (create-classifier)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
-        [predictions] (infer/classify-image classifier image 5)]
+        [predictions-all] (infer/classify-image classifier image)
+        [predictions-with-default-dtype] (infer/classify-image classifier image 10)
+        [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 10 (count predictions-with-default-dtype)))
     (is (some? predictions))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
@@ -58,8 +62,12 @@
   (let [classifier (create-classifier)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
-        batch-predictions (infer/classify-image-batch classifier image-batch 5)
+        batch-predictions-all (infer/classify-image-batch classifier image-batch)
+        batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
+        batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
         predictions (first batch-predictions)]
+    (is (= 1000 (count (first batch-predictions-all))))
+    (is (= 10 (count (first batch-predictions-with-default-dtype))))
     (is (some? batch-predictions))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 788a594..3a0e3d3 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -40,9 +40,11 @@
 (deftest test-single-detection
   (let [detector (create-detector)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
+        [predictions-all] (infer/detect-objects detector image)
         [predictions] (infer/detect-objects detector image 5)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
+    (is (= 13 (count predictions-all)))
     (is (every? #(= 2 (count %)) predictions))
     (is (every? #(string? (first %)) predictions))
     (is (every? #(= 5 (count (second %))) predictions))
@@ -53,9 +55,11 @@
   (let [detector (create-detector)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
+        batch-predictions-all (infer/detect-objects-batch detector image-batch)
         batch-predictions (infer/detect-objects-batch detector image-batch 5)
         predictions (first batch-predictions)]
     (is (some? batch-predictions))
+    (is (= 13 (count (first batch-predictions-all))))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
     (is (every? #(string? (first %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index 79e9441..9ffd3ab 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -97,7 +97,7 @@
     (is (= [1.0 1.0] (->vec ndhalves)))))
 
 (deftest test-full
-  (let [nda (full [1 2] 3)]
+  (let [nda (full [1 2] 3.0)]
     (is (= (shape nda) (mx-shape/->shape [1 2])))
     (is (= [3.0 3.0] (->vec nda)))))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
new file mode 100644
index 0000000..1a538e5
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
@@ -0,0 +1,45 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives-test
+  (:require [org.apache.clojure-mxnet.primitives :as primitives]
+            [clojure.test :refer :all])
+  (:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE
+                             MX_PRIMITIVES$MX_FLOAT
+                             MX_PRIMITIVES$MX_Double)))
+
+(deftest test-primitive-types
+  (is (not (primitives/primitive? 3)))
+  (is (primitives/primitive? (primitives/mx-float 3)))
+  (is (primitives/primitive? (primitives/mx-double 3))))
+
+(deftest test-float-primitives
+  (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3)))
+  (is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3)))
+  (is (instance? Float (-> (primitives/mx-float 3)
+                           (primitives/->num))))
+  (is (= 3.0 (-> (primitives/mx-float 3)
+                 (primitives/->num)))))
+
+(deftest test-double-primitives
+  (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2)))
+  (is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2)))
+  (is (instance? Double (-> (primitives/mx-double 2)
+                            (primitives/->num))))
+  (is (= 2.0 (-> (primitives/mx-double 2)
+                 (primitives/->num)))))
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index bd77a8a..c26f83d 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -20,6 +20,7 @@
             [org.apache.clojure-mxnet.shape :as mx-shape]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.primitives :as primitives]
             [org.apache.clojure-mxnet.symbol :as sym]
             [org.apache.clojure-mxnet.test-util :as test-util]
             [clojure.spec.alpha :as s])
@@ -133,6 +134,9 @@
   (is (= "[F"  (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str))))
   (is (= "[L"  (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str))))
 
+  (is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+  (is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+
   (is (= 1 (util/coerce-param 1 #{"unknown"}))))
 
 (deftest test-nil-or-coerce-param
@@ -171,6 +175,12 @@
                 (util/convert-tuple [1 2]))))
   (is (= [1 2 3] (util/coerce-return
                   (util/convert-tuple [1 2 3]))))
+
+  (is (instance? Double (util/coerce-return (primitives/mx-double 3))))
+  (is (= 3.0 (util/coerce-return (primitives/mx-double 3))))
+  (is (instance? Float (util/coerce-return (primitives/mx-float 2))))
+  (is (= 2.0 (util/coerce-return (primitives/mx-float 2))))
+
   (is (= "foo" (util/coerce-return "foo"))))
 
 (deftest test-translate-keyword-shape
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index ed7aff6..001bd04 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -18,7 +18,9 @@
 package org.apache.mxnet
 
 import org.apache.mxnet.util.NativeLibraryLoader
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.Specializable.Group
 
 private[mxnet] object Base {
   private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")
@@ -57,6 +59,9 @@ private[mxnet] object Base {
 
   val MX_REAL_TYPE = DType.Float32
 
+  // Specializable.Group of the scalar primitives supported for NDArray operations.
+  val MX_PRIMITIVES = new Group ((Double, Float)) // NOTE(review): confirm still in use
+
   try {
     try {
       tryLoadLibraryOS("mxnet-scala")
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 0a5683a..20b6ed9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -93,6 +93,9 @@ private[mxnet] class LibInfo {
   @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
                                        source: Array[MXFloat],
                                        size: Int): Int
+  @native def mxFloat64NDArraySyncCopyFromCPU(handle: NDArrayHandle,
+                                              source: Array[Double],
+                                              size: Int): Int
   @native def mxNDArrayLoad(fname: String,
                             outSize: MXUintRef,
                             handles: ArrayBuffer[NDArrayHandle],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
new file mode 100644
index 0000000..cb97885
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+object MX_PRIMITIVES {
+
+  /**
+    * This defines the basic primitives we can use in Scala for mathematical
+    * computations in NDArrays. This gives us the flexibility to expand to
+    * more supported primitives in the future. Currently Float and Double
+    * are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept
+    * plain old Float and Double data as inputs because of the underlying
+    * implicit conversions from those primitives to MX_PRIMITIVE_TYPE.
+    */
+  trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{
+
+    def toString: String
+
+    def unary_- : MX_PRIMITIVE_TYPE
+  }
+
+  /**
+    * Ordering over MX_PRIMITIVE_TYPE that delegates to each value's own compare.
+    */
+  trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] {
+
+    def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE): Int = x.compare(y)
+
+  }
+
+  implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering
+
+  /**
+    * Wrapper over Float in Scala.
+    * @param data the wrapped Float value
+    */
+  class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE {
+
+    override def toString: String = data.toString
+
+    override def unary_- : MX_PRIMITIVE_TYPE = new MX_FLOAT(-data)
+
+    /**
+      * Numeric comparison. An MX_Double operand is compared after widening this
+      * value to Double, so mixed Float/Double operands compare numerically.
+      */
+    override def compare(that: MX_PRIMITIVE_TYPE): Int = that match {
+      case f: MX_FLOAT => java.lang.Float.compare(this.data, f.data)
+      case d: MX_Double => java.lang.Double.compare(this.data.toDouble, d.data)
+      case other => throw new IllegalArgumentException(
+        s"Cannot compare MX_FLOAT with unsupported type ${other.getClass.getName}")
+    }
+  }
+
+  implicit def FloatToMX_Float(d : Float): MX_FLOAT = new MX_FLOAT(d)
+
+  implicit def MX_FloatToFloat(d: MX_FLOAT) : Float = d.data
+
+  implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat)
+
+  /**
+    * Wrapper over Double in Scala.
+    * @param data the wrapped Double value
+    */
+  class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE {
+
+    override def toString: String = data.toString
+
+    override def unary_- : MX_PRIMITIVE_TYPE = new MX_Double(-data)
+
+    /**
+      * Numeric comparison. An MX_FLOAT operand is widened to Double before
+      * comparing, so mixed Float/Double operands compare numerically.
+      */
+    override def compare(that: MX_PRIMITIVE_TYPE): Int = that match {
+      case d: MX_Double => java.lang.Double.compare(this.data, d.data)
+      case f: MX_FLOAT => java.lang.Double.compare(this.data, f.data.toDouble)
+      case other => throw new IllegalArgumentException(
+        s"Cannot compare MX_Double with unsupported type ${other.getClass.getName}")
+    }
+  }
+
+  implicit def DoubleToMX_Double(d : Double): MX_Double = new MX_Double(d)
+
+  implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
+
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 1259581..163ed26 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
+import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
 import org.slf4j.LoggerFactory
 
 import scala.collection.mutable
@@ -262,16 +263,46 @@ object NDArray extends NDArrayBase {
     arr
   }
 
-  // Perform power operator
+  /**
+    * Create a new Float64 NDArray filled with `value`, with the specified `shape`.
+    * @param ctx device context of the NDArray; null means the default context
+    * @return the newly allocated and filled NDArray
+    */
+  def full(shape: Shape, value: Double, ctx: Context): NDArray = {
+    val filled = empty(shape, ctx, DType.Float64)
+    filled.set(value)
+    filled
+  }
+
+  /**
+    * Create a new Float64 NDArray filled with `value` in the default context.
+    */
+  def full(shape: Shape, value: Double): NDArray = full(shape, value, null)
+
+  /**
+    * Perform power operation on NDArray. Returns result as NDArray
+    * @param lhs
+    * @param rhs
+    */
   def power(lhs: NDArray, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
   }
 
-  def power(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform scalar power operation on NDArray. Returns result as NDArray
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
   }
 
-  def power(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform scalar power operation on NDArray. Returns result as NDArray
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
   }
 
@@ -280,11 +311,21 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
   }
 
-  def maximum(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform the max operation on NDArray. Returns the result as NDArray.
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
   }
 
-  def maximum(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform the max operation on NDArray. Returns the result as NDArray.
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
   }
 
@@ -293,11 +334,21 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
   }
 
-  def minimum(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform the min operation on NDArray. Returns the result as NDArray.
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
   }
 
-  def minimum(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform the min operation on NDArray. Returns the result as NDArray.
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
   }
 
@@ -310,7 +361,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
   }
 
-  def equal(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are same,
+    * otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -324,7 +383,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
   }
 
-  def notEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **not equal to** (!=) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are different,
+    * otherwise return 0(false).
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -338,7 +405,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
   }
 
-  def greater(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **greater than** (>) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+    * otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
   }
 
@@ -352,7 +428,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
   }
 
-  def greaterEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **greater than or equal to** (>=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than equal to
+    * rhs, otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -366,7 +451,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
   }
 
-  def lesser(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **lesser than** (<) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+    * otherwise return 0(false).
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
   }
 
@@ -380,7 +473,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
   }
 
-  def lesserEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **lesser than or equal to** (<=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are
+    * lesser than equal to rhs, otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -397,6 +499,16 @@ object NDArray extends NDArrayBase {
     arr
   }
 
+  /** Create a Float64 NDArray of the given shape holding a copy of `sourceArr`. */
+  def array(sourceArr: Array[Double], shape: Shape, ctx: Context): NDArray = {
+    val result = empty(shape, ctx, dtype = DType.Float64)
+    result.set(sourceArr)
+    result
+  }
+
+  /** Same as `array(sourceArr, shape, ctx)` in the default context (ctx = null). */
+  def array(sourceArr: Array[Double], shape: Shape): NDArray = array(sourceArr, shape, null)
+
   /**
    * Returns evenly spaced values within a given interval.
    * Values are generated within the half-open interval [`start`, `stop`). In other
@@ -645,6 +757,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
   }
 
+  private def syncCopyfrom(source: Array[Double]): Unit = {
+    require(source.length == size, // Float64 copy path; element counts must match exactly
+      s"array size (${source.length}) do not match the size of NDArray ($size)")
+    checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
+  }
+
   /**
    * Return a sliced NDArray that shares memory with current one.
    * NDArray only support continuous slicing on axis 0
@@ -759,7 +877,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @param value Value to set
    * @return Current NDArray
    */
-  def set(value: Float): NDArray = {
+  def set(value: MX_PRIMITIVE_TYPE): NDArray = {
     require(writable, "trying to assign to a readonly NDArray")
     NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this))
     this
@@ -776,11 +894,17 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
+  def set(other: Array[Double]): NDArray = { // copies `other` into this NDArray in place
+    require(writable, "trying to assign to a readonly NDArray")
+    syncCopyfrom(other)
+    this
+  }
+
   def +(other: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other))
   }
 
-  def +(other: Float): NDArray = {
+  def +(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other))
   }
 
@@ -792,7 +916,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def +=(other: Float): NDArray = {
+  def +=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to add to a readonly NDArray")
     }
@@ -804,7 +928,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other))
   }
 
-  def -(other: Float): NDArray = {
+  def -(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other))
   }
 
@@ -816,7 +940,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def -=(other: Float): NDArray = {
+  def -=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
     }
@@ -828,7 +952,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other))
   }
 
-  def *(other: Float): NDArray = {
+  def *(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other))
   }
 
@@ -844,7 +968,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def *=(other: Float): NDArray = {
+  def *=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
     }
@@ -856,7 +980,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other))
   }
 
-  def /(other: Float): NDArray = {
+  def /(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other))
   }
 
@@ -868,7 +992,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def /=(other: Float): NDArray = {
+  def /=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to divide from a readonly NDArray")
     }
@@ -880,7 +1004,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.power(this, other)
   }
 
-  def **(other: Float): NDArray = {
+  def **(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.power(this, other)
   }
 
@@ -888,7 +1012,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this))
   }
 
-  def **=(other: Float): NDArray = {
+  def **=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this))
   }
 
@@ -896,7 +1020,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.greater(this, other)
   }
 
-  def >(other: Float): NDArray = {
+  def >(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.greater(this, other)
   }
 
@@ -904,7 +1028,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.greaterEqual(this, other)
   }
 
-  def >=(other: Float): NDArray = {
+  def >=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.greaterEqual(this, other)
   }
 
@@ -912,7 +1036,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.lesser(this, other)
   }
 
-  def <(other: Float): NDArray = {
+  def <(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.lesser(this, other)
   }
 
@@ -920,7 +1044,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.lesserEqual(this, other)
   }
 
-  def <=(other: Float): NDArray = {
+  def <=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.lesserEqual(this, other)
   }
 
@@ -928,7 +1052,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other))
   }
 
-  def %(other: Float): NDArray = {
+  def %(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other))
   }
 
@@ -940,7 +1064,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def %=(other: Float): NDArray = {
+  def %=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
     }
@@ -956,6 +1080,14 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     internal.toFloatArray
   }
 
+  /**
+    * Return a flat, row-major copy of this array's contents as an Array[Double].
+    * @return A copy of the array content, produced by NDArrayInternal.toDoubleArray.
+    */
+  def toFloat64Array: Array[Double] = {
+    internal.toDoubleArray
+  }
+
   def internal: NDArrayInternal = {
     val myType = dtype
     val arrLength = DType.numOfBytes(myType) * size
@@ -975,6 +1107,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this.toArray(0)
   }
 
+  def toFloat64Scalar: Double = { // first element as a Double; requires shape == Shape(1)
+    require(shape == Shape(1), "The current array is not a scalar")
+    this.toFloat64Array(0)
+  }
+
   /**
    * Copy the content of current array to other.
    *
@@ -997,7 +1134,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @return The copy target NDArray
    */
   def copyTo(ctx: Context): NDArray = {
-    val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true))
+    val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true, dtype = dtype))
     copyTo(ret)
   }
 
@@ -1047,11 +1184,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
 
 private[mxnet] object NDArrayConversions {
   implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
-  implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat)
+  implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x)
   implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x)
 }
 
-private[mxnet] class NDArrayConversions(val value: Float) {
+private[mxnet] class NDArrayConversions(val value: MX_PRIMITIVE_TYPE) {
   def +(other: NDArray): NDArray = {
     other + value
   }
@@ -1145,34 +1282,39 @@ private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) {
   def waitToRead(): Unit = head.waitToRead()
   def context: Context = head.context
   def set(value: Float): NDArray = head.set(value)
+  def set(value: Double): NDArray = head.set(value)
   def set(other: NDArray): NDArray = head.set(other)
   def set(other: Array[Float]): NDArray = head.set(other)
+  def set(other: Array[Double]): NDArray = head.set(other)
   def +(other: NDArray): NDArray = head + other
-  def +(other: Float): NDArray = head + other
+  def +(other: MX_PRIMITIVE_TYPE): NDArray = head + other
   def +=(other: NDArray): NDArray = head += other
-  def +=(other: Float): NDArray = head += other
+  def +=(other: MX_PRIMITIVE_TYPE): NDArray = head += other
   def -(other: NDArray): NDArray = head - other
-  def -(other: Float): NDArray = head - other
+  def -(other: MX_PRIMITIVE_TYPE): NDArray = head - other
   def -=(other: NDArray): NDArray = head -= other
-  def -=(other: Float): NDArray = head -= other
+  def -=(other: MX_PRIMITIVE_TYPE): NDArray = head -= other
   def *(other: NDArray): NDArray = head * other
-  def *(other: Float): NDArray = head * other
+  def *(other: MX_PRIMITIVE_TYPE): NDArray = head * other
   def unary_-(): NDArray = -head
   def *=(other: NDArray): NDArray = head *= other
-  def *=(other: Float): NDArray = head *= other
+  def *=(other: MX_PRIMITIVE_TYPE): NDArray = head *= other
   def /(other: NDArray): NDArray = head / other
+  def /(other: MX_PRIMITIVE_TYPE): NDArray = head / other
   def **(other: NDArray): NDArray = head ** other
-  def **(other: Float): NDArray = head ** other
+  def **(other: MX_PRIMITIVE_TYPE): NDArray = head ** other
   def >(other: NDArray): NDArray = head > other
-  def >(other: Float): NDArray = head > other
+  def >(other: MX_PRIMITIVE_TYPE): NDArray = head > other
   def >=(other: NDArray): NDArray = head >= other
-  def >=(other: Float): NDArray = head >= other
+  def >=(other: MX_PRIMITIVE_TYPE): NDArray = head >= other
   def <(other: NDArray): NDArray = head < other
-  def <(other: Float): NDArray = head < other
+  def <(other: MX_PRIMITIVE_TYPE): NDArray = head < other
   def <=(other: NDArray): NDArray = head <= other
-  def <=(other: Float): NDArray = head <= other
+  def <=(other: MX_PRIMITIVE_TYPE): NDArray = head <= other
   def toArray: Array[Float] = head.toArray
+  def toFloat64Array: Array[Double] = head.toFloat64Array
   def toScalar: Float = head.toScalar
+  def toFloat64Scalar: Double = head.toFloat64Scalar
   def copyTo(other: NDArray): NDArray = head.copyTo(other)
   def copyTo(ctx: Context): NDArray = head.copyTo(ctx)
   def copy(): NDArray = head.copy()
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index a84bd10..e30098c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -53,9 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
       val label = currentBatch.label(0)
       // properties
       val res = (
-        // TODO: need to allow user to specify DType and Layout
-        IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
-        IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
+        // TODO: need to allow user to specify Layout
+        IndexedSeq(new DataDesc(dataName, data.shape, data.dtype, Layout.UNDEFINED)),
+        IndexedSeq(new DataDesc(labelName, label.shape, label.dtype, Layout.UNDEFINED)),
         ListMap(dataName -> data.shape),
         ListMap(labelName -> label.shape),
         data.shape(0))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 0032a54..e690abb 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -61,7 +61,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
            dataBatchSize: Int = 1, shuffle: Boolean = false,
            lastBatchHandle: String = "pad",
            dataName: String = "data", labelName: String = "label") {
-    this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED),
+    this(IO.initDataDesc(data, allowEmpty = false, dataName,
+      if (data == null || data.isEmpty)  MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
       IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
   }
@@ -272,7 +273,7 @@ object NDArrayIter {
      */
     def addData(name: String, data: NDArray): Builder = {
       this.data = this.data ++ IndexedSeq((new DataDesc(name,
-        data.shape, DType.Float32, Layout.UNDEFINED), data))
+        data.shape, data.dtype, Layout.UNDEFINED), data))
       this
     }
 
@@ -284,7 +285,7 @@ object NDArrayIter {
      */
     def addLabel(name: String, label: NDArray): Builder = {
       this.label = this.label ++ IndexedSeq((new DataDesc(name,
-        label.shape, DType.Float32, Layout.UNDEFINED), label))
+        label.shape, label.dtype, Layout.UNDEFINED), label))
       this
     }
 
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 198102d..67809c1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -91,17 +91,26 @@ object NDArray extends NDArrayBase {
   def full(shape: Shape, value: Float, ctx: Context): NDArray
   = org.apache.mxnet.NDArray.full(shape, value, ctx)
 
+  def full(shape: Shape, value: Double, ctx: Context): NDArray
+  = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
   def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
   def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
   def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
 
   def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
   def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
   def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
 
   def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
   def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
   def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
 
 
   /**
@@ -111,6 +120,7 @@ object NDArray extends NDArrayBase {
     */
   def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
   def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+  def equal(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
 
   /**
     * Returns the result of element-wise **not equal to** (!=) comparison operation
@@ -120,6 +130,7 @@ object NDArray extends NDArrayBase {
     */
   def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
   def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+  def notEqual(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
 
   /**
     * Returns the result of element-wise **greater than** (>) comparison operation
@@ -129,6 +140,7 @@ object NDArray extends NDArrayBase {
     */
   def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
   def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+  def greater(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
 
   /**
     * Returns the result of element-wise **greater than or equal to** (>=) comparison
@@ -140,6 +152,8 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
   def greaterEqual(lhs: NDArray, rhs: Float): NDArray
   = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+  def greaterEqual(lhs: NDArray, rhs: Double): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
 
   /**
     * Returns the result of element-wise **lesser than** (<) comparison operation
@@ -149,6 +163,7 @@ object NDArray extends NDArrayBase {
     */
   def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
   def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+  def lesser(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
 
   /**
     * Returns the result of element-wise **lesser than or equal to** (<=) comparison
@@ -160,6 +175,8 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
   def lesserEqual(lhs: NDArray, rhs: Float): NDArray
   = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+  def lesserEqual(lhs: NDArray, rhs: Double): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
 
   /**
     * Create a new NDArray that copies content from source_array.
@@ -173,6 +190,18 @@ object NDArray extends NDArrayBase {
     sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
 
   /**
+    * Create a new Float64 NDArray that copies content from source_array.
+    * @param sourceArr Source data (list of Doubles) to create NDArray from.
+    * @param shape shape of the NDArray
+    * @param ctx The context of the NDArray, default to current default context.
+    * @return The created NDArray.
+    */
+  def arrayWithDouble(sourceArr: java.util.List[java.lang.Double], shape: Shape,
+                      ctx: Context = null): NDArray
+  = org.apache.mxnet.NDArray.array(
+    sourceArr.asScala.map(ele => Double.unbox(ele)).toArray, shape, ctx)
+
+  /**
     * Returns evenly spaced values within a given interval.
     * Values are generated within the half-open interval [`start`, `stop`). In other
     * words, the interval includes `start` but excludes `stop`.
@@ -205,6 +234,10 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
   }
 
+  def this(arr: Array[Double], shape: Shape, ctx: Context) = {
+    this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+  }
+
   def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
     this(NDArray.array(arr, shape, ctx))
   }
@@ -304,41 +337,59 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     * @return Current NDArray
     */
   def set(value: Float): NDArray = nd.set(value)
+  def set(value: Double): NDArray = nd.set(value)
   def set(other: NDArray): NDArray = nd.set(other)
   def set(other: Array[Float]): NDArray = nd.set(other)
+  def set(other: Array[Double]): NDArray = nd.set(other)
 
   def add(other: NDArray): NDArray = this.nd + other.nd
   def add(other: Float): NDArray = this.nd + other
+  def add(other: Double): NDArray = this.nd + other
   def addInplace(other: NDArray): NDArray = this.nd += other
   def addInplace(other: Float): NDArray = this.nd += other
+  def addInplace(other: Double): NDArray = this.nd += other
   def subtract(other: NDArray): NDArray = this.nd - other
   def subtract(other: Float): NDArray = this.nd - other
+  def subtract(other: Double): NDArray = this.nd - other
   def subtractInplace(other: NDArray): NDArray = this.nd -= other
   def subtractInplace(other: Float): NDArray = this.nd -= other
+  def subtractInplace(other: Double): NDArray = this.nd -= other
   def multiply(other: NDArray): NDArray = this.nd * other
   def multiply(other: Float): NDArray = this.nd * other
+  def multiply(other: Double): NDArray = this.nd * other
   def multiplyInplace(other: NDArray): NDArray = this.nd *= other
   def multiplyInplace(other: Float): NDArray = this.nd *= other
+  def multiplyInplace(other: Double): NDArray = this.nd *= other
   def div(other: NDArray): NDArray = this.nd / other
   def div(other: Float): NDArray = this.nd / other
+  def div(other: Double): NDArray = this.nd / other
   def divInplace(other: NDArray): NDArray = this.nd /= other
   def divInplace(other: Float): NDArray = this.nd /= other
+  def divInplace(other: Double): NDArray = this.nd /= other
   def pow(other: NDArray): NDArray = this.nd ** other
   def pow(other: Float): NDArray = this.nd ** other
+  def pow(other: Double): NDArray = this.nd ** other
   def powInplace(other: NDArray): NDArray = this.nd **= other
   def powInplace(other: Float): NDArray = this.nd **= other
+  def powInplace(other: Double): NDArray = this.nd **= other
   def mod(other: NDArray): NDArray = this.nd % other
   def mod(other: Float): NDArray = this.nd % other
+  def mod(other: Double): NDArray = this.nd % other
   def modInplace(other: NDArray): NDArray = this.nd %= other
   def modInplace(other: Float): NDArray = this.nd %= other
+  def modInplace(other: Double): NDArray = this.nd %= other
   def greater(other: NDArray): NDArray = this.nd > other
   def greater(other: Float): NDArray = this.nd > other
+  def greater(other: Double): NDArray = this.nd > other
   def greaterEqual(other: NDArray): NDArray = this.nd >= other
   def greaterEqual(other: Float): NDArray = this.nd >= other
+  def greaterEqual(other: Double): NDArray = this.nd >= other
   def lesser(other: NDArray): NDArray = this.nd < other
   def lesser(other: Float): NDArray = this.nd < other
+  def lesser(other: Double): NDArray = this.nd < other
   def lesserEqual(other: NDArray): NDArray = this.nd <= other
   def lesserEqual(other: Float): NDArray = this.nd <= other
+  def lesserEqual(other: Double): NDArray = this.nd <= other
 
   /**
     * Return a copied flat java array of current array (row-major).
@@ -347,6 +398,12 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
   def toArray: Array[Float] = nd.toArray
 
   /**
+    * Return a copied flat java array of Doubles of current array (row-major).
+    * @return  A copy of array content.
+    */
+  def toFloat64Array: Array[Double] = nd.toFloat64Array
+
+  /**
     * Return a CPU scalar(float) of current ndarray.
     * This ndarray must have shape (1,)
     *
@@ -355,6 +412,14 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
   def toScalar: Float = nd.toScalar
 
   /**
+    * Return a CPU scalar(double) of current ndarray.
+    * This ndarray must have shape (1,)
+    *
+    * @return The scalar representation of the ndarray.
+    */
+  def toFloat64Scalar: Double = nd.toFloat64Scalar
+
+  /**
     * Copy the content of current array to other.
     *
     * @param other Target NDArray or context we want to copy data to.
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b78..86c7eb2 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -40,6 +40,15 @@ public class NDArrayTest {
                 new Shape(new int[]{1, 3}),
                 new Context("cpu", 0));
         assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+
+        List<Double> list2 = Arrays.asList(1d, 1d, 1d);
+        nd = NDArray.arrayWithDouble(list2,
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+
+        // Float64 assertion
+        assertTrue(nd.dtype() == DType.Float64());
+
     }
 
     @Test
@@ -64,6 +73,12 @@ public class NDArrayTest {
         nd = nd.subtract(nd2);
         float[] lesser = new float[]{0, 0, 0};
         assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+
+        NDArray nd3 = new NDArray(new double[]{1.0, 2.0, 3.0}, new Shape(new int[]{3}), new Context("cpu", 0));
+        nd3 = nd3.add(1.0);
+        double[] smaller = new double[] {2, 3, 4};
+        assertTrue(Arrays.equals(smaller, nd3.toFloat64Array()));
+
     }
 
     @Test
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 2ec6f66..d3969b0 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -303,5 +303,32 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(dataDesc(0).layout == Layout.NTC)
     assert(labelDesc(0).dtype == DType.Int32)
     assert(labelDesc(0).layout == Layout.NT)
+
+
+    // Test with passing Float64 hardcoded as Dtype of data
+    val dataIter4 = new NDArrayIter(
+      IO.initDataDesc(data, false, "data", DType.Float64, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc4 = dataIter4.provideDataDesc
+    val labelDesc4 = dataIter4.provideLabelDesc
+    assert(dataDesc4(0).dtype == DType.Float64)
+    assert(dataDesc4(0).layout == Layout.NTC)
+    assert(labelDesc4(0).dtype == DType.Int32)
+    assert(labelDesc4(0).layout == Layout.NT)
+
+    // Test with Float64 data arrays (DType.Float64 is also passed explicitly)
+    val dataF64 = IndexedSeq(NDArray.ones(shape0, dtype = DType.Float64),
+      NDArray.zeros(shape0, dtype = DType.Float64))
+
+    val dataIter5 = new NDArrayIter(
+      IO.initDataDesc(dataF64, false, "data", DType.Float64, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc5 = dataIter5.provideDataDesc
+    assert(dataDesc5(0).dtype == DType.Float64)
+    assert(dataDesc5(0).dtype != DType.Float32)
+    assert(dataDesc5(0).layout == Layout.NTC)
+
   }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 2f3b167..bc7a0a0 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.mxnet.NDArrayConversions._
-import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
 
 class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   private val sequence: AtomicInteger = new AtomicInteger(0)
@@ -29,6 +29,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("to java array") {
     val ndarray = NDArray.zeros(2, 2)
     assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
+
+    val float64Array = NDArray.zeros(Shape(2, 2), dtype = DType.Float64)
+    assert(float64Array.toFloat64Array === Array(0d, 0d, 0d, 0d))
   }
 
   test("to scalar") {
@@ -38,8 +41,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(ndones.toScalar === 1f)
   }
 
+  test("to float 64 scalar") {
+    val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64)
+    assert(ndzeros.toFloat64Scalar === 0d)
+    val ndones = NDArray.ones(Shape(1), dtype = DType.Float64)
+    assert(ndones.toFloat64Scalar === 1d)
+  }
+
   test ("call toScalar on an ndarray which is not a scalar") {
     intercept[Exception] { NDArray.zeros(1, 1).toScalar }
+    intercept[Exception] { NDArray.zeros(shape = Shape (1, 1),
+      dtype = DType.Float64).toFloat64Scalar }
   }
 
   test("size and shape") {
@@ -51,12 +63,20 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("dtype") {
     val arr = NDArray.zeros(3, 2)
     assert(arr.dtype === DType.Float32)
+
+    val float64Array = NDArray.zeros(shape = Shape(3, 2), dtype = DType.Float64)
+    assert(float64Array.dtype === DType.Float64)
   }
 
   test("set scalar value") {
     val ndarray = NDArray.empty(2, 1)
     ndarray.set(10f)
     assert(ndarray.toArray === Array(10f, 10f))
+
+    val float64array = NDArray.empty(shape = Shape(2, 1), dtype = DType.Float64)
+    float64array.set(10d)
+    assert(float64array.toFloat64Array === Array(10d, 10d))
+
   }
 
   test("copy from java array") {
@@ -66,19 +86,29 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("plus") {
-    val ndzeros = NDArray.zeros(2, 1)
-    val ndones = ndzeros + 1f
+    var ndzeros = NDArray.zeros(2, 1)
+    var ndones = ndzeros + 1f
     assert(ndones.toArray === Array(1f, 1f))
     assert((ndones + ndzeros).toArray === Array(1f, 1f))
     assert((1 + ndones).toArray === Array(2f, 2f))
     // in-place
     ndones += ndones
     assert(ndones.toArray === Array(2f, 2f))
+
+    // Float64 method test
+    ndzeros = NDArray.zeros(shape = Shape(2, 1), dtype = DType.Float64)
+    ndones = ndzeros + 1d
+    assert(ndones.toFloat64Array === Array(1d, 1d))
+    assert((ndones + ndzeros).toFloat64Array === Array(1d, 1d))
+    assert((1d + ndones).toFloat64Array === Array(2d, 2d))
+    // in-place
+    ndones += ndones
+    assert(ndones.toFloat64Array === Array(2d, 2d))
   }
 
   test("minus") {
-    val ndones = NDArray.ones(2, 1)
-    val ndzeros = ndones - 1f
+    var ndones = NDArray.ones(2, 1)
+    var ndzeros = ndones - 1f
     assert(ndzeros.toArray === Array(0f, 0f))
     assert((ndones - ndzeros).toArray === Array(1f, 1f))
     assert((ndzeros - ndones).toArray === Array(-1f, -1f))
@@ -86,23 +116,46 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     // in-place
     ndones -= ndones
     assert(ndones.toArray === Array(0f, 0f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape(2, 1), dtype = DType.Float64)
+    ndzeros = ndones - 1d
+    assert(ndzeros.toFloat64Array === Array(0d, 0d))
+    assert((ndones - ndzeros).toFloat64Array === Array(1d, 1d))
+    assert((ndzeros - ndones).toFloat64Array === Array(-1d, -1d))
+    assert((ndones - 1d).toFloat64Array === Array(0d, 0d))
+    // in-place
+    ndones -= ndones
+    assert(ndones.toFloat64Array === Array(0d, 0d))
+
   }
 
   test("multiplication") {
-    val ndones = NDArray.ones(2, 1)
-    val ndtwos = ndones * 2
+    var ndones = NDArray.ones(2, 1)
+    var ndtwos = ndones * 2
     assert(ndtwos.toArray === Array(2f, 2f))
     assert((ndones * ndones).toArray === Array(1f, 1f))
     assert((ndtwos * ndtwos).toArray === Array(4f, 4f))
     ndtwos *= ndtwos
     // in-place
     assert(ndtwos.toArray === Array(4f, 4f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape(2, 1), dtype = DType.Float64)
+    ndtwos = ndones * 2d
+    assert(ndtwos.toFloat64Array === Array(2d, 2d))
+    assert((ndones * ndones).toFloat64Array === Array(1d, 1d))
+    assert((ndtwos * ndtwos).toFloat64Array === Array(4d, 4d))
+    ndtwos *= ndtwos
+    // in-place
+    assert(ndtwos.toFloat64Array === Array(4d, 4d))
+
   }
 
   test("division") {
-    val ndones = NDArray.ones(2, 1)
-    val ndzeros = ndones - 1f
-    val ndhalves = ndones / 2
+    var ndones = NDArray.ones(2, 1)
+    var ndzeros = ndones - 1f
+    var ndhalves = ndones / 2
     assert(ndhalves.toArray === Array(0.5f, 0.5f))
     assert((ndhalves / ndhalves).toArray === Array(1f, 1f))
     assert((ndones / ndones).toArray === Array(1f, 1f))
@@ -110,37 +163,75 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     ndhalves /= ndhalves
     // in-place
     assert(ndhalves.toArray === Array(1f, 1f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape (2, 1), dtype = DType.Float64)
+    ndzeros = ndones - 1d
+    ndhalves = ndones / 2d
+    assert(ndhalves.toFloat64Array === Array(0.5d, 0.5d))
+    assert((ndhalves / ndhalves).toFloat64Array === Array(1d, 1d))
+    assert((ndones / ndones).toFloat64Array === Array(1d, 1d))
+    assert((ndzeros / ndones).toFloat64Array === Array(0d, 0d))
+    ndhalves /= ndhalves
+    // in-place
+    assert(ndhalves.toFloat64Array === Array(1d, 1d))
   }
 
   test("full") {
-    val arr = NDArray.full(Shape(1, 2), 3f)
+    var arr = NDArray.full(Shape(1, 2), 3f)
     assert(arr.shape === Shape(1, 2))
     assert(arr.toArray === Array(3f, 3f))
+
+    // Float64 methods test
+    arr = NDArray.full(Shape(1, 2), value = 5d, Context.cpu())
+    assert(arr.toFloat64Array === Array (5d, 5d))
   }
 
   test("clip") {
-    val ndarray = NDArray.empty(3, 2)
+    var ndarray = NDArray.empty(3, 2)
     ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f))
     assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f))
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(3, 2), dtype = DType.Float64)
+    ndarray.set(Array(1d, 2d, 3d, 4d, 5d, 6d))
+    assert(NDArray.clip(ndarray, 2d, 5d).toFloat64Array === Array(2d, 2d, 3d, 4d, 5d, 5d))
   }
 
   test("sqrt") {
-    val ndarray = NDArray.empty(4, 1)
+    var ndarray = NDArray.empty(4, 1)
     ndarray.set(Array(0f, 1f, 4f, 9f))
     assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f))
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(4, 1), dtype = DType.Float64)
+    ndarray.set(Array(0d, 1d, 4d, 9d))
+    assert(NDArray.sqrt(ndarray).toFloat64Array === Array(0d, 1d, 2d, 3d))
   }
 
   test("rsqrt") {
-    val ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
+    var ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
     assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f))
+
+    // Float64 methods test
+    ndarray = NDArray.array(Array(1d, 4d, 25d), shape = Shape(3, 1), Context.cpu())
+    assert(NDArray.rsqrt(ndarray).toFloat64Array === Array(1d, 0.5d, 0.2d))
   }
 
   test("norm") {
-    val ndarray = NDArray.empty(3, 1)
+    var ndarray = NDArray.empty(3, 1)
     ndarray.set(Array(1f, 2f, 3f))
-    val normed = NDArray.norm(ndarray)
+    var normed = NDArray.norm(ndarray)
     assert(normed.shape === Shape(1))
     assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f)
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(3, 1), dtype = DType.Float64)
+    ndarray.set(Array(1d, 2d, 3d))
+    normed = NDArray.norm(ndarray)
+    assert(normed.get.dtype === DType.Float64)
+    assert(normed.shape === Shape(1))
+    assert(normed.toFloat64Scalar === math.sqrt(14.0) +- 1e-3d)
   }
 
   test("one hot encode") {
@@ -176,25 +267,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("power") {
-    val arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
+    var arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
 
-    val arrPower1 = NDArray.power(2f, arr)
+    var arrPower1 = NDArray.power(2f, arr)
     assert(arrPower1.shape === Shape(2, 1))
     assert(arrPower1.toArray === Array(8f, 32f))
 
-    val arrPower2 = NDArray.power(arr, 2f)
+    var arrPower2 = NDArray.power(arr, 2f)
     assert(arrPower2.shape === Shape(2, 1))
     assert(arrPower2.toArray === Array(9f, 25f))
 
-    val arrPower3 = NDArray.power(arr, arr)
+    var arrPower3 = NDArray.power(arr, arr)
     assert(arrPower3.shape === Shape(2, 1))
     assert(arrPower3.toArray === Array(27f, 3125f))
 
-    val arrPower4 = arr ** 2f
+    var arrPower4 = arr ** 2f
+
     assert(arrPower4.shape === Shape(2, 1))
     assert(arrPower4.toArray === Array(9f, 25f))
 
-    val arrPower5 = arr ** arr
+    var arrPower5 = arr ** arr
     assert(arrPower5.shape === Shape(2, 1))
     assert(arrPower5.toArray === Array(27f, 3125f))
 
@@ -206,84 +298,211 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     arr **= arr
     assert(arr.shape === Shape(2, 1))
     assert(arr.toArray === Array(27f, 3125f))
+
+    // Float64 tests
+    arr = NDArray.array(Array(3d, 5d), shape = Shape(2, 1))
+
+    arrPower1 = NDArray.power(2d, arr)
+    assert(arrPower1.shape === Shape(2, 1))
+    assert(arrPower1.dtype === DType.Float64)
+    assert(arrPower1.toFloat64Array === Array(8d, 32d))
+
+    arrPower2 = NDArray.power(arr, 2d)
+    assert(arrPower2.shape === Shape(2, 1))
+    assert(arrPower2.dtype === DType.Float64)
+    assert(arrPower2.toFloat64Array === Array(9d, 25d))
+
+    arrPower3 = NDArray.power(arr, arr)
+    assert(arrPower3.shape === Shape(2, 1))
+    assert(arrPower3.dtype === DType.Float64)
+    assert(arrPower3.toFloat64Array === Array(27d, 3125d))
+
+    arrPower4 = arr ** 2d
+    assert(arrPower4.shape === Shape(2, 1))
+    assert(arrPower4.dtype === DType.Float64)
+    assert(arrPower4.toFloat64Array === Array(9d, 25d))
+
+    arrPower5 = arr ** arr
+    assert(arrPower5.shape === Shape(2, 1))
+    assert(arrPower5.dtype === DType.Float64)
+    assert(arrPower5.toFloat64Array === Array(27d, 3125d))
+
+    arr **= 2d
+    assert(arr.shape === Shape(2, 1))
+    assert(arr.dtype === DType.Float64)
+    assert(arr.toFloat64Array === Array(9d, 25d))
+
+    arr.set(Array(3d, 5d))
+    arr **= arr
+    assert(arr.shape === Shape(2, 1))
+    assert(arr.dtype === DType.Float64)
+    assert(arr.toFloat64Array === Array(27d, 3125d))
   }
 
   test("equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = NDArray.equal(arr1, arr2)
+    var arrEqual1 = NDArray.equal(arr1, arr2)
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
 
-    val arrEqual2 = NDArray.equal(arr1, 3f)
+    var arrEqual2 = NDArray.equal(arr1, 3f)
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 0f, 1f, 0f))
+
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = NDArray.equal(arr1, arr2)
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+    arrEqual2 = NDArray.equal(arr1, 3d)
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 0d))
   }
 
   test("not_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = NDArray.notEqual(arr1, arr2)
+    var arrEqual1 = NDArray.notEqual(arr1, arr2)
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
 
-    val arrEqual2 = NDArray.notEqual(arr1, 3f)
+    var arrEqual2 = NDArray.notEqual(arr1, 3f)
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 1f, 0f, 1f))
+
+    // Float64 methods test
+
+    arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = NDArray.notEqual(arr1, arr2)
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+    arrEqual2 = NDArray.notEqual(arr1, 3d)
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
   }
 
   test("greater") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 > arr2
+    var arrEqual1 = arr1 > arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 0f, 1f, 0f))
 
-    val arrEqual2 = arr1 > 2f
+    var arrEqual2 = arr1 > 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 0f, 1f, 1f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 > arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 0d, 1d, 0d))
+
+    arrEqual2 = arr1 > 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 1d))
   }
 
   test("greater_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 >= arr2
+    var arrEqual1 = arr1 >= arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
 
-    val arrEqual2 = arr1 >= 2f
+    var arrEqual2 = arr1 >= 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 1f, 1f, 1f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 >= arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+    arrEqual2 = arr1 >= 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 1d, 1d, 1d))
   }
 
   test("lesser") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 < arr2
+    var arrEqual1 = arr1 < arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
 
-    val arrEqual2 = arr1 < 2f
+    var arrEqual2 = arr1 < 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 0f, 0f, 0f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 < arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+    arrEqual2 = arr1 < 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 0d, 0d, 0d))
+
   }
 
   test("lesser_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 <= arr2
+    var arrEqual1 = arr1 <= arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 1f, 0f, 1f))
 
-    val arrEqual2 = arr1 <= 2f
+    var arrEqual2 = arr1 <= 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 1f, 0f, 0f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 <= arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
+    arrEqual2 = arr1 <= 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 0d))
   }
 
   test("choose_element_0index") {
@@ -294,11 +513,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("copy to") {
-    val source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
-    val dest = NDArray.empty(1, 3)
+    var source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
+    var dest = NDArray.empty(1, 3)
     source.copyTo(dest)
     assert(dest.shape === Shape(1, 3))
     assert(dest.toArray === Array(1f, 2f, 3f))
+
+    // Float64 methods test
+    source = NDArray.array(Array(1d, 2d, 3d), shape = Shape(1, 3))
+    dest = NDArray.empty(shape = Shape(1, 3), dtype = DType.Float64)
+    source.copyTo(dest)
+    assert(dest.dtype === DType.Float64)
+    assert(dest.toFloat64Array === Array(1d, 2d, 3d))
   }
 
   test("abs") {
@@ -365,6 +591,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.maximum(arr1, arr2)
     assert(arr.shape === Shape(3, 1))
     assert(arr.toArray === Array(4f, 2.1f, 3.7f))
+
+    // Float64 methods test
+    val arr3 = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1))
+    val maxArr = NDArray.maximum(arr3, 10d)
+    assert(maxArr.shape === Shape(3, 1))
+    assert(maxArr.toFloat64Array === Array(10d, 10d, 10d))
   }
 
   test("min") {
@@ -378,11 +610,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.minimum(arr1, arr2)
     assert(arr.shape === Shape(3, 1))
     assert(arr.toArray === Array(1.5f, 1f, 3.5f))
+
+    // Float64 methods test
+    val arr3 = NDArray.array(Array(4d, 5d, 6d), shape = Shape(3, 1))
+    val minArr = NDArray.minimum(arr3, 5d)
+    assert(minArr.shape === Shape(3, 1))
+    assert(minArr.toFloat64Array === Array(4d, 5d, 5d))
   }
 
   test("sum") {
     val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2))
     assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f)
+
   }
 
   test("argmaxChannel") {
@@ -398,6 +637,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.concatenate(arr1, arr2)
     assert(arr.shape === Shape(3, 3))
     assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f))
+
+    // Try concatenating float32 arr with float64 arr. Should get exception
+    intercept[Exception] {
+      val arr3 = NDArray.array(Array (5d, 6d, 7d), shape = Shape(1, 3))
+      NDArray.concatenate(Array(arr1, arr3))
+    }
   }
 
   test("concatenate axis-1") {
@@ -406,6 +651,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.concatenate(Array(arr1, arr2), axis = 1)
     assert(arr.shape === Shape(2, 3))
     assert(arr.toArray === Array(1f, 2f, 5f, 3f, 4f, 6f))
+
+    // Try concatenating float32 arr with float64 arr. Should get exception
+    intercept[Exception] {
+      val arr3 = NDArray.array(Array (5d, 6d), shape = Shape(2, 1))
+      NDArray.concatenate(Array(arr1, arr3), axis = 1)
+    }
   }
 
   test("transpose") {
@@ -428,6 +679,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val loadedArray = arrays(0)
       assert(loadedArray.shape === Shape(3, 1))
       assert(loadedArray.toArray === Array(1f, 2f, 3f))
+      assert(loadedArray.dtype === DType.Float32)
+    } finally {
+      val file = new File(filename)
+      file.delete()
+    }
+
+    // Try the same for Float64 array
+    try {
+      val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+      NDArray.save(filename, Map("local" -> ndarray))
+      val (keys, arrays) = NDArray.load(filename)
+      assert(keys.length === 1)
+      assert(keys(0) === "local")
+      assert(arrays.length === 1)
+      val loadedArray = arrays(0)
+      assert(loadedArray.shape === Shape(3, 1))
+      assert(loadedArray.toFloat64Array === Array(1d, 2d, 3d))
+      assert(loadedArray.dtype === DType.Float64)
     } finally {
       val file = new File(filename)
       file.delete()
@@ -446,6 +715,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val loadedArray = arrays(0)
       assert(loadedArray.shape === Shape(3, 1))
       assert(loadedArray.toArray === Array(1f, 2f, 3f))
+      assert(loadedArray.dtype === DType.Float32)
+    } finally {
+      val file = new File(filename)
+      file.delete()
+    }
+
+    // Try the same thing for Float64 array :
+
+    try {
+      val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+      NDArray.save(filename, Array(ndarray))
+      val (keys, arrays) = NDArray.load(filename)
+      assert(keys.length === 0)
+      assert(arrays.length === 1)
+      val loadedArray = arrays(0)
+      assert(loadedArray.shape === Shape(3, 1))
+      assert(loadedArray.toArray === Array(1d, 2d, 3d))
+      assert(loadedArray.dtype === DType.Float64)
     } finally {
       val file = new File(filename)
       file.delete()
@@ -464,9 +751,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1))
     val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
     val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Shape(3, 1))
+    val ndarray5 = NDArray.array(Array(3d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
     ndarray1 shouldEqual ndarray2
     ndarray1 shouldNot equal(ndarray3)
     ndarray1 shouldNot equal(ndarray4)
+    ndarray5 shouldNot equal(ndarray3)
   }
 
   test("slice") {
@@ -545,6 +834,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val bytes = arr.serialize()
     val arrCopy = NDArray.deserialize(bytes)
     assert(arr === arrCopy)
+    assert(arrCopy.dtype === DType.Float32)
   }
 
   test("dtype int32") {
@@ -580,18 +870,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("NDArray random module is generated properly") {
     val lam = NDArray.ones(1, 2)
     val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
-    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)),
+      dtype = Some("float64"))
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
+    assert(rnd2.head.dtype === DType.Float64)
   }
 
   test("NDArray random module is generated properly - special case of 'normal'") {
     val mu = NDArray.ones(1, 2)
     val sigma = NDArray.ones(1, 2) * 2
     val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
-    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)),
+      dtype = Some("float64"))
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
+    assert(rnd2.head.dtype === DType.Float64)
   }
 
   test("Generated api") {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index f6c283c..9f0430e 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -19,6 +19,7 @@ package org.apache.mxnetexamples.imclassification
 
 import java.util.concurrent._
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnetexamples.imclassification.models._
 import org.apache.mxnetexamples.imclassification.util.Trainer
 import org.apache.mxnet._
@@ -42,12 +43,13 @@ object TrainModel {
     * @return The final validation accuracy
     */
   def test(model: String, dataPath: String, numExamples: Int = 60000,
-           numEpochs: Int = 10, benchmark: Boolean = false): Float = {
+           numEpochs: Int = 10, benchmark: Boolean = false,
+           dtype: DType = DType.Float32): Float = {
     ResourceScope.using() {
       val devs = Array(Context.cpu(0))
       val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
       val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
-        numExamples = numExamples, benchmark = benchmark)
+        numExamples = numExamples, benchmark = benchmark, dtype = dtype)
       val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
         network = net, dataLoader = dataLoader,
         kvStore = "local", numEpochs = numEpochs)
@@ -69,7 +71,7 @@ object TrainModel {
     */
   def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
                          numLayers: Int = 50, numExamples: Int = 60000,
-                         benchmark: Boolean = false
+                         benchmark: Boolean = false, dtype: DType = DType.Float32
                         ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
     val (imageShape, numClasses) = dataset match {
       case "mnist" => (List(1, 28, 28), 10)
@@ -80,16 +82,17 @@ object TrainModel {
     val List(channels, height, width) = imageShape
     val dataSize: Int = channels * height * width
     val (datumShape, net) = model match {
-      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses))
-      case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses))
+      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses, dtype = dtype))
+      case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses, dtype = dtype))
       case "resnet" => (List(channels, height, width), Resnet.getSymbol(numClasses,
-        numLayers, imageShape))
+        numLayers, imageShape, dtype = dtype))
       case _ => throw new Exception("Invalid model name")
     }
 
     val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
       (batchSize: Int, kv: KVStore) => {
-        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples)
+        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples,
+          dtype)
         (iter, iter)
       }
     } else {
@@ -116,8 +119,12 @@ object TrainModel {
         val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
         else inst.dataDir
 
+        val dtype = DType.withName(inst.dType)
+        require(inst.benchmark || dtype == DType.Float32,
+          s"dtype $dtype is currently only supported with synthetic (benchmark) data")
+
         val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
-          inst.numLayers, inst.numExamples, inst.benchmark)
+          inst.numLayers, inst.numExamples, inst.benchmark, dtype)
 
         val devs =
           if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
@@ -210,5 +217,8 @@ class TrainModel {
   private val numWorker: Int = 1
   @Option(name = "--num-server", usage = "# of servers")
   private val numServer: Int = 1
+  @Option(name = "--dtype", usage = "data type of the model to train. " +
+    "Can be float32/float64. Works only with synthetic data currently")
+  private val dType: String = "float32"
 }
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
index 9421f10..e4d3b2a 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable.ListMap
 import scala.util.Random
 
 class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int],
-                        labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32
+                        labelShape: List[Int], maxIter: Int, dType: DType = DType.Float32
                        ) extends DataIter {
   var curIter = 0
   val random = new Random()
@@ -35,12 +35,12 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In
   var label: IndexedSeq[NDArray] = IndexedSeq(
     NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape = Some(batchLabelShape)))
   var data: IndexedSeq[NDArray] = IndexedSeq(
-    NDArray.api.random_uniform(shape = Some(shape)))
+    NDArray.api.random_uniform(shape = Some(shape), dtype = Some(dType.toString)))
 
   val provideDataDesc: IndexedSeq[DataDesc] = IndexedSeq(
-    new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+    new DataDesc("data", shape, data(0).dtype, Layout.UNDEFINED)) // dtype of the generated array
   val provideLabelDesc: IndexedSeq[DataDesc] = IndexedSeq(
-    new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+    new DataDesc("softmax_label", batchLabelShape, label(0).dtype, Layout.UNDEFINED))
   val getPad: Int = 0
 
   override def getData(): IndexedSeq[NDArray] = data
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
index 76fb7bb..6f8b138 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object Lenet {
@@ -26,8 +27,9 @@ object Lenet {
     * @param numClasses Number of classes to classify into
+    * @param dtype      Data type of the "data" input variable (defaults to Float32)
     * @return model symbol
     */
-  def getSymbol(numClasses: Int): Symbol = {
-    val data = Symbol.Variable("data")
+  def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+    val data = Symbol.Variable("data", dType = dtype)
     // first conv
     val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
     val tanh1 = Symbol.api.tanh(data = Some(conv1))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
index 5d880bb..089b65f 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object MultiLayerPerceptron {
@@ -26,8 +27,9 @@ object MultiLayerPerceptron {
     * @param numClasses Number of classes to classify into
+    * @param dtype      Data type of the "data" input variable (defaults to Float32)
     * @return model symbol
     */
-  def getSymbol(numClasses: Int): Symbol = {
-    val data = Symbol.Variable("data")
+  def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+    val data = Symbol.Variable("data", dType = dtype)
 
     val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
     val act1 = Symbol.api.Activation(data = Some(fc1), "relu", name = "relu")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
index c3f43d9..e5f5976 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object Resnet {
@@ -77,13 +78,16 @@ object Resnet {
     */
   def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int,
              imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f,
-             workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = {
+             workspace: Int = 256, dtype: DType = DType.Float32,
+             memonger: Boolean = false): Symbol = {
     assert(units.size == numStages)
-    var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = DType.Float32)
+    // Float16 models are fed Float32 data and cast below; other dtypes are fed unchanged
+    val inputDType = if (dtype == DType.Float16) DType.Float32 else dtype
+    var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = inputDType)
-    if (dtype == "float32") {
+    if (dtype == DType.Float32) {
       data = Symbol.api.identity(Some(data), "id")
-    } else if (dtype == "float16") {
-      data = Symbol.api.cast(Some(data), "float16")
+    } else if (dtype == DType.Float16) {
+      data = Symbol.api.cast(Some(data), DType.Float16.toString)
     }
     data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps = Some(2e-5),
       momentum = Some(bnMom), name = "bn_data")
@@ -118,8 +122,8 @@ object Resnet {
       kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
     val flat = Symbol.api.Flatten(Some(pool1))
     var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1")
-    if (dtype == "float16") {
-      fc1 = Symbol.api.cast(Some(fc1), "float32")
+    if (dtype == DType.Float16) {
+      fc1 = Symbol.api.cast(Some(fc1), DType.Float32.toString)
     }
     Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax")
   }
@@ -134,7 +138,7 @@ object Resnet {
     * @return Model symbol
     */
   def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256,
-                dtype: String = "float32"): Symbol = {
+                dtype: DType = DType.Float32): Symbol = {
     val List(channels, height, width) = imageShape
     val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) =
       if (height <= 28) {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 6e9667a..0daba5a 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.imclassification
 
 import java.io.File
 
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, DType}
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
@@ -55,9 +55,15 @@ class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
 
   for(model <- List("mlp", "lenet", "resnet")) {
     test(s"Example CI: Test Image Classification Model ${model}") {
-      var context = Context.cpu()
       val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
     }
   }
 
+  for(model <- List("mlp", "lenet", "resnet")) { // same models on Float64 synthetic data
+    test(s"Example CI: Test Image Classification Model ${model} with Float64 input") {
+      val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true,
+        dtype = DType.Float64)
+    }
+  }
+
 }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index 5208923..bf65815 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -17,9 +17,10 @@
 
 package org.apache.mxnet.infer
 
-import org.apache.mxnet.{Context, DataDesc, NDArray}
+import org.apache.mxnet._
 import java.io.File
 
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
 import org.slf4j.LoggerFactory
 
 import scala.io
@@ -30,13 +31,13 @@ trait ClassifierBase {
 
   /**
     * Takes an array of floats and returns corresponding (Label, Score) tuples
-    * @param input            Indexed sequence one-dimensional array of floats
+    * @param input            Indexed sequence one-dimensional array of floats/doubles
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
     * @return                 Indexed sequence of (Label, Score) tuples
     */
-  def classify(input: IndexedSeq[Array[Float]],
-               topK: Option[Int] = None): IndexedSeq[(String, Float)]
+  def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+               topK: Option[Int] = None): IndexedSeq[(String, T)]
 
   /**
     * Takes a sequence of NDArrays and returns (Label, Score) tuples
@@ -78,17 +79,35 @@ class Classifier(modelPathPrefix: String,
 
   /**
     * Takes flat arrays as input and returns (Label, Score) tuples.
-    * @param input            Indexed sequence one-dimensional array of floats
+    * @param input            Indexed sequence one-dimensional array of floats/doubles
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
     * @return                 Indexed sequence of (Label, Score) tuples
     */
-  override def classify(input: IndexedSeq[Array[Float]],
-                        topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
+  override def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+                        topK: Option[Int] = None): IndexedSeq[(String, T)] = {
+
+    // Dispatch on the runtime element type of the first input array
+    val result = input(0) match {
+      case _: Array[Double] => {
+        classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
+      }
+      case _ => {
+        classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
+      }
+    }
+
+    result.asInstanceOf[IndexedSeq[(String, T)]]
+  }
+
+  private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE]
+  (input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A)
+  : IndexedSeq[(String, B)] = {
 
     // considering only the first output
     val predictResult = predictor.predict(input)(0)
-    var result: IndexedSeq[(String, Float)] = IndexedSeq.empty
+
+    var result: IndexedSeq[(String, B)] = IndexedSeq.empty
 
     if (topK.isDefined) {
       val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
@@ -105,7 +124,7 @@ class Classifier(modelPathPrefix: String,
     * @param input            Indexed sequence of NDArrays
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
-    * @return                 Traversable sequence of (Label, Score) tuples
+    * @return                 Traversable sequence of (Label, Score) tuples.
     */
   override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
   : IndexedSeq[IndexedSeq[(String, Float)]] = {
@@ -113,7 +132,7 @@ class Classifier(modelPathPrefix: String,
     // considering only the first output
     // Copy NDArray to CPU to avoid frequent GPU to CPU copying
     val predictResultND: NDArray =
       predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
     // Parallel Execution with ParArray for better performance
     val predictResultPar: ParArray[Array[Float]] =
       new ParArray[Array[Float]](predictResultND.shape(0))
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
index 96be121..3c80f92 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.infer
 
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 
 import scala.collection.mutable.ListBuffer
 
@@ -70,14 +71,18 @@ class ImageClassifier(modelPathPrefix: String,
     *
     * @param inputImage       Path prefix of the input image
     * @param topK             Number of result elements to return, sorted by probability
+    * @param dType            The precision at which to run the inference.
+    *                         Specify DType.Float64 for double precision.
+    *                         Defaults to DType.Float32
     * @return                 List of list of tuples of (Label, Probability)
     */
-  def classifyImage(inputImage: BufferedImage,
-                    topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = {
+  def classifyImage(inputImage: BufferedImage,
+                    topK: Option[Int] = None,
+                    dType: DType = DType.Float32): IndexedSeq[IndexedSeq[(String, Float)]] = {
 
     val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
     val imageShape = inputShape.drop(1)
-    val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+    val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
     val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0)
     inputImage.flush()
     scaledImage.flush()
@@ -95,16 +100,19 @@ class ImageClassifier(modelPathPrefix: String,
     *
     * @param inputBatch       Input array of buffered images
     * @param topK             Number of result elements to return, sorted by probability
+    * @param dType            The precision at which to run the inference.
+    *                         Specify DType.Float64 for double precision.
+    *                         Defaults to DType.Float32
     * @return                 List of list of tuples of (Label, Probability)
     */
-  def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
-  IndexedSeq[IndexedSeq[(String, Float)]] = {
+  def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None,
+                         dType: DType = DType.Float32): IndexedSeq[IndexedSeq[(String, Float)]] = {
 
     val inputBatchSeq = inputBatch.toIndexedSeq
     val imageBatch = inputBatchSeq.indices.par.map(idx => {
       val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
       val imageShape = inputShape.drop(1)
-      val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+      val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
       val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get
       handler.execute(imgND.dispose())
       imgWithBatch
@@ -152,11 +160,29 @@ object ImageClassifier {
     * returned by this method after the use.
     * </p>
     * @param resizedImage     BufferedImage to get pixels from
+    *
     * @param inputImageShape  Input shape; for example for resnet it is (3,224,224).
                               Should be same as inputDescriptor shape.
+    * @param dType            The DataType of the NDArray created from the image
+    *                         that should be returned.
+    *                         Must be DType.Float32 (default) or DType.Float64.
     * @return                 NDArray pixels array with shape (3, 224, 224) in CHW format
     */
-  def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
+  def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape,
+                            dType: DType = DType.Float32): NDArray = {
+    require(dType == DType.Float32 || dType == DType.Float64,
+      s"Unsupported dType $dType: only Float32 and Float64 are supported")
+
+    // Pixel extraction is the same for both precisions; only the NDArray element type differs
+    val result = getFloatPixelsArray(resizedImage)
+    if (dType == DType.Float64) {
+      NDArray.array(result.map(_.toDouble), shape = inputImageShape)
+    } else {
+      NDArray.array(result, shape = inputImageShape)
+    }
+  }
+
+  private def getFloatPixelsArray(resizedImage: BufferedImage): Array[Float] = {
     // Get height and width of the image
     val w = resizedImage.getWidth()
     val h = resizedImage.getHeight()
@@ -166,7 +192,6 @@ object ImageClassifier {
 
     // 3 times height and width for R,G,B channels
     val result = new Array[Float](3 * h * w)
-
     var row = 0
     // copy pixels to array vertically
     while (row < h) {
@@ -184,11 +209,10 @@ object ImageClassifier {
       }
       row += 1
     }
+
     resizedImage.flush()
 
-    // creating NDArray according to the input shape
-    val pixelsArray = NDArray.array(result, shape = inputImageShape)
-    pixelsArray
+    result
   }
 
   /**
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index d4bce9f..67692a3 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -17,8 +17,9 @@
 
 package org.apache.mxnet.infer
 
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
 import org.apache.mxnet.io.NDArrayIter
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet._
 import org.apache.mxnet.module.Module
 
 import scala.collection.mutable.ListBuffer
@@ -36,11 +37,13 @@ private[infer] trait PredictBase {
    * <p>
    * This method will take input as IndexedSeq one dimensional arrays and creates the
    * NDArray needed for inference. The array will be reshaped based on the input descriptors.
-   * @param input:            An IndexedSequence of a one-dimensional array.
+   * @param input:            An Indexed Sequence of a one-dimensional array of datatype
+   *                          Float or Double
                               An IndexedSequence is needed when the model has more than one input.
    * @return                  Indexed sequence array of outputs
    */
-  def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+  def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+  : IndexedSeq[Array[T]]
 
   /**
    * Predict using NDArray as input.
@@ -123,13 +126,13 @@ class Predictor(modelPathPrefix: String,
    * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
    * The array will be reshaped based on the input descriptors.
    *
-   * @param input:            An IndexedSequence of a one-dimensional array.
+   * @param input:            An IndexedSequence of a one-dimensional array
+   *                          of data type Float or Double.
                               An IndexedSequence is needed when the model has more than one input.
    * @return                  Indexed sequence array of outputs
    */
-  override def predict(input: IndexedSeq[Array[Float]])
-  : IndexedSeq[Array[Float]] = {
-
+  override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+  : IndexedSeq[Array[T]] = {
     require(input.length == inputDescriptors.length,
       s"number of inputs provided: ${input.length} does not match number of inputs " +
         s"in inputDescriptors: ${inputDescriptors.length}")
@@ -139,12 +142,31 @@ class Predictor(modelPathPrefix: String,
         s"number of elements:${i.length} in the input does not match the shape:" +
           s"${d.shape.toString()}")
     }
+
+    // Dispatch on the runtime element type of the first input array
+    val result = input(0) match {
+      case _: Array[Double] => predictImpl(input.asInstanceOf[IndexedSeq[Array[Double]]])
+      case _ => predictImpl(input.asInstanceOf[IndexedSeq[Array[Float]]])
+    }
+
+    result.asInstanceOf[IndexedSeq[Array[T]]]
+  }
+
+  private def predictImpl[B, A <: MX_PRIMITIVE_TYPE]
+  (input: IndexedSeq[Array[B]])(implicit ev: B => A)
+  : IndexedSeq[Array[B]] = {
+
     var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]
 
     for((i, d) <- input.zip(inputDescriptors)) {
       val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
-
-      inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
+      require(i.isInstanceOf[Array[Double]] == (d.dtype == DType.Float64),
+        s"element type of input ${d.name} does not match descriptor dtype ${d.dtype}")
+      if (d.dtype == DType.Float64) {
+        inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Double]], Shape(shape)))
+      } else {
+        inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Float]], Shape(shape)))
+      }
     }
 
     // rebind with batchsize 1
@@ -158,7 +180,8 @@ class Predictor(modelPathPrefix: String,
     val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
       inputND.toIndexedSeq, dataBatchSize = 1)))
 
-    val result = resultND.map((f : NDArray) => f.toArray)
+    val result =
+      resultND.map((f : NDArray) => if (f.dtype == DType.Float64) f.toFloat64Array else f.toArray)
 
     mxNetHandler.execute(inputND.foreach(_.dispose))
     mxNetHandler.execute(resultND.foreach(_.dispose))
@@ -168,9 +191,9 @@ class Predictor(modelPathPrefix: String,
         mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
       }
 
-    result
+    result.asInstanceOf[IndexedSeq[Array[B]]]
   }
 
   /**
    * Predict using NDArray as input
    * This method is useful when the input is a batch of data
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 0466693..146fe93 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -73,6 +73,29 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
   }
 
   /**
+    * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference.
+    * The array will be reshaped based on the input descriptors. Example of calling in Java:
+    *
+    * <pre>
+    * {@code
+    * double tmp[][] = new double[1][224];
+    * for (int x = 0; x < 1; x++)
+    *   for (int y = 0; y < 224; y++)
+    *     tmp[x][y] = Math.random() * 10;
+    * predictor.predict(tmp);
+    * }
+    * </pre>
+    *
+    * @param input:            An Array of a one-dimensional array of doubles.
+                              An extra Array is needed when the model has more than one input.
+    * @return                  Array of one-dimensional output arrays
+    */
+  def predict(input: Array[Array[Double]]):
+  Array[Array[Double]] = {
+    predictor.predict(input).toArray
+  }
+
+  /**
     * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
     * The array will be reshaped based on the input descriptors.
     *
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
index b28aeba..d9ccec4 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
@@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths}
 import java.util
 
 import org.apache.mxnet.module.Module
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.{Context, DType, DataDesc, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -127,6 +127,29 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
 
   }
 
+  test("ClassifierSuite-flatFloat64Array-topK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Double](12)(1d)
+
+    val predictResult : IndexedSeq[Array[Double]] =
+      IndexedSeq[Array[Double]](Array(0.98d, 0.97d, 0.96d, 0.99d))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+    // topK (10) exceeds the 4 mocked scores, so every score comes back, highest first.
+    val result: IndexedSeq[(String, Double)] = testClassifier.
+      classify(IndexedSeq(inputData), topK = Some(10))
+
+    // Double input must yield Double scores.
+    assert(result(0)._2.getClass == 1d.getClass)
+
+    assertResult(predictResult(0).sortBy(-_)) {
+      result.map(_._2).toArray
+    }
+  }
+
   test("ClassifierSuite-flatArrayInput") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
     val inputData = Array.fill[Float](12)(1)
@@ -147,6 +170,28 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("ClassifierSuite-flatArrayFloat64Input") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Double](12)(1d)
+
+    val predictResult : IndexedSeq[Array[Double]] =
+      IndexedSeq[Array[Double]](Array(0.98d, 0.97d, 0.96d, 0.99d))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+    val result: IndexedSeq[(String, Double)] = testClassifier.
+      classify(IndexedSeq(inputData))
+
+    // Double input must yield Double scores, kept in the predictor's order when topK is unset.
+    assert(result(0)._2.getClass == 1d.getClass)
+
+    assertResult(predictResult(0)) {
+      result.map(_._2).toArray
+    }
+  }
+
   test("ClassifierSuite-NDArray1InputWithoutTopK") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
     val inputDataShape = Shape(1, 3, 2, 2)
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
index 1c291e1..5198c4a 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
@@ -68,6 +68,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
     val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))
 
     assert(result.shape == inputDescriptor(0).shape.drop(1))
+    assert(result.dtype == DType.Float32)
+
+    val resultFloat64 = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2), DType.Float64)
+    assert(resultFloat64.dtype == DType.Float64)
   }
 
   test("ImageClassifierSuite-testWithInputImage") {
@@ -106,8 +110,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
         predictResult(i).map(_._2).toArray
       }
     }
+
   }
 
+
   test("ImageClassifierSuite-testWithInputBatchImage") {
     val dType = DType.Float32
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
@@ -152,4 +158,5 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
       }
     }
   }
+
 }
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 509ffb3..9afbc9b 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
+import org.apache.mxnet._
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -91,6 +91,36 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
       , any[Option[BaseModule]], any[String])
   }
 
+  test("PredictorSuite-testWithFlatFloat64Arrays") {
+    // Float64 input descriptor and Double input: the predictor must hand back Double outputs.
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
+      layout = Layout.NCHW, dtype = DType.Float64))
+    val inputData = Array.fill[Double](12)(1d)
+
+    // This will be disposed at the end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1 ")
+
+    assert(testFun(0)(0).getClass == 1d.getClass)
+
+    assert(testFun(0).sameElements(Array.fill[Double](12)(1d)))
+
+    // Verify that the module was bound with batch size 1 and rebound back to the original
+    // input descriptor. The number of times is twice here because loadModule overrides the
+    // initial bind.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+
   test("PredictorSuite-testWithNDArray") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
       layout = Layout.NCHW))
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index d684c6d..ea6e9c8 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -424,6 +424,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
+  jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, nullptr);
+  if (sourcePtr == nullptr) return -1;  // pin/copy failed; OutOfMemoryError is pending
+  int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr), sourcePtr, arrSize);
+  env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, JNI_ABORT);  // read-only: no copy-back
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
   (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
   int outDevType;
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 40230ac..7e8e03d 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -177,6 +177,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxFloat64NDArraySyncCopyFromCPU
+ * Signature: (J[DI)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+  (JNIEnv *, jobject, jlong, jdoubleArray, jint);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxNDArrayLoad
  * Signature: (Ljava/lang/String;Lorg/apache/mxnet/Base/RefInt;Lscala/collection/mutable/ArrayBuffer;Lorg/apache/mxnet/Base/RefInt;Lscala/collection/mutable/ArrayBuffer;)I
  */