You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/10 23:12:46 UTC

[GitHub] lanking520 closed pull request #13678: [MXNET-1260] Float64 DType computation support in Scala/Java

lanking520 closed pull request #13678: [MXNET-1260] Float64 DType computation support in Scala/Java
URL: https://github.com/apache/incubator-mxnet/pull/13678
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index b9f84d592a7..5b5fdce712f 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -193,6 +193,7 @@ List of Contributors
 * [Yuxi Hu](https://github.com/yuxihu)
 * [Harsh Patel](https://github.com/harshp8l)
 * [Xiao Wang](https://github.com/BeyonderXX)
+* [Piyush Ghai](https://github.com/piyushghai)
 
 Label Bot
 ---------
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
index b2b23da6274..224a39275da 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/infer.clj
@@ -18,6 +18,7 @@
 (ns org.apache.clojure-mxnet.infer
   (:refer-clojure :exclude [type])
   (:require [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.dtype :as dtype]
             [org.apache.clojure-mxnet.io :as mx-io]
             [org.apache.clojure-mxnet.shape :as shape]
             [org.apache.clojure-mxnet.util :as util]
@@ -62,10 +63,12 @@
 (defprotocol AImageClassifier
   (classify-image
     [wrapped-image-classifier image]
-    [wrapped-image-classifier image topk])
+    [wrapped-image-classifier image topk]
+    [wrapped-image-classifier image topk dtype])
   (classify-image-batch
     [wrapped-image-classifier images]
-    [wrapped-image-classifier images topk]))
+    [wrapped-image-classifier images topk]
+    [wrapped-image-classifier images topk dtype]))
 
 (defprotocol AObjectDetector
   (detect-objects
@@ -80,7 +83,8 @@
 
 (extend-protocol APredictor
   WrappedPredictor
-  (predict [wrapped-predictor inputs]
+  (predict
+    [wrapped-predictor inputs]
     (util/validate! ::wrapped-predictor wrapped-predictor
                     "Invalid predictor")
     (util/validate! ::vec-of-float-arrays inputs
@@ -101,62 +105,50 @@
 
 (extend-protocol AClassifier
   WrappedClassifier
-  (classify [wrapped-classifier inputs]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (classify wrapped-classifier inputs nil))
-  (classify [wrapped-classifier inputs topk]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classify (:classifier wrapped-classifier)
-                (util/vec->indexed-seq inputs)
-                (util/->int-option topk))))
-  (classify-with-ndarray [wrapped-classifier inputs]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (classify-with-ndarray wrapped-classifier inputs nil))
-  (classify-with-ndarray [wrapped-classifier inputs topk]
-    (util/validate! ::wrapped-classifier wrapped-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyWithNDArray (:classifier wrapped-classifier)
-                           (util/vec->indexed-seq inputs)
-                           (util/->int-option topk))))
+  (classify
+    ([wrapped-classifier inputs]
+     (classify wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:classifier wrapped-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-classifier inputs]
+     (classify-with-ndarray wrapped-classifier inputs nil))
+    ([wrapped-classifier inputs topk]
+     (util/validate! ::wrapped-classifier wrapped-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-ndarrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classifyWithNDArray (:classifier wrapped-classifier)
+                            (util/vec->indexed-seq inputs)
+                           (util/->int-option topk)))))
   WrappedImageClassifier
-  (classify [wrapped-image-classifier inputs]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (classify wrapped-image-classifier inputs nil))
-  (classify [wrapped-image-classifier inputs topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-float-arrays inputs
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classify (:image-classifier wrapped-image-classifier)
-                (util/vec->indexed-seq inputs)
-                (util/->int-option topk))))
-  (classify-with-ndarray [wrapped-image-classifier inputs]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::vec-of-ndarrays inputs
-                    "Invalid inputs")
-    (classify-with-ndarray wrapped-image-classifier inputs nil))
-  (classify-with-ndarray [wrapped-image-classifier inputs topk]
+  (classify
+    ([wrapped-image-classifier inputs]
+     (classify wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::vec-of-float-arrays inputs
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.classify (:image-classifier wrapped-image-classifier)
+                 (util/vec->indexed-seq inputs)
+                 (util/->int-option topk)))))
+  (classify-with-ndarray
+    ([wrapped-image-classifier inputs]
+     (classify-with-ndarray wrapped-image-classifier inputs nil))
+    ([wrapped-image-classifier inputs topk]
     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
                     "Invalid classifier")
     (util/validate! ::vec-of-ndarrays inputs
@@ -165,83 +157,83 @@
     (util/coerce-return-recursive
      (.classifyWithNDArray (:image-classifier wrapped-image-classifier)
                            (util/vec->indexed-seq inputs)
-                           (util/->int-option topk)))))
+                           (util/->int-option topk))))))
 
 (s/def ::image #(instance? BufferedImage %))
+(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
 
 (extend-protocol AImageClassifier
   WrappedImageClassifier
-  (classify-image [wrapped-image-classifier image]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::image image "Invalid image")
-    (classify-image wrapped-image-classifier image nil))
-  (classify-image [wrapped-image-classifier image topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::image image "Invalid image")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyImage (:image-classifier wrapped-image-classifier)
-                     image
-                     (util/->int-option topk))))
-  (classify-image-batch [wrapped-image-classifier images]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (classify-image-batch wrapped-image-classifier images nil))
-  (classify-image-batch [wrapped-image-classifier images topk]
-    (util/validate! ::wrapped-image-classifier wrapped-image-classifier
-                    "Invalid classifier")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.classifyImageBatch (:image-classifier wrapped-image-classifier)
-                          images
-                          (util/->int-option topk)))))
+  (classify-image
+    ([wrapped-image-classifier image]
+     (classify-image wrapped-image-classifier image nil dtype/FLOAT32))
+    ([wrapped-image-classifier image topk]
+     (classify-image wrapped-image-classifier image topk dtype/FLOAT32))
+    ([wrapped-image-classifier image topk dtype]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/validate! ::dtype dtype "Invalid dtype")
+     (util/coerce-return-recursive
+      (.classifyImage (:image-classifier wrapped-image-classifier)
+                      image
+                      (util/->int-option topk)
+                      dtype))))
+  (classify-image-batch
+    ([wrapped-image-classifier images]
+     (classify-image-batch wrapped-image-classifier images nil dtype/FLOAT32))
+    ([wrapped-image-classifier images topk]
+         (classify-image-batch wrapped-image-classifier images topk dtype/FLOAT32))
+    ([wrapped-image-classifier images topk dtype]
+     (util/validate! ::wrapped-image-classifier wrapped-image-classifier
+                     "Invalid classifier")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/validate! ::dtype dtype "Invalid dtype")
+     (util/coerce-return-recursive
+      (.classifyImageBatch (:image-classifier wrapped-image-classifier)
+                           images
+                           (util/->int-option topk)
+                           dtype)))))
 
 (extend-protocol AObjectDetector
   WrappedObjectDetector
-  (detect-objects [wrapped-detector image]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::image image "Invalid image")
-    (detect-objects wrapped-detector image nil))
-  (detect-objects [wrapped-detector image topk]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::image image "Invalid image")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.imageObjectDetect (:object-detector wrapped-detector)
-                         image
-                         (util/->int-option topk))))
-  (detect-objects-batch [wrapped-detector images]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (detect-objects-batch wrapped-detector images nil))
-  (detect-objects-batch [wrapped-detector images topk]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.imageBatchObjectDetect (:object-detector wrapped-detector)
-                              images
-                              (util/->int-option topk))))
-  (detect-objects-with-ndarrays [wrapped-detector input-arrays]
-    (util/validate! ::wrapped-detector wrapped-detector
-                    "Invalid object detector")
-    (util/validate! ::vec-of-ndarrays input-arrays
-                    "Invalid inputs")
-    (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
-  (detect-objects-with-ndarrays [wrapped-detector input-arrays topk]
+  (detect-objects
+    ([wrapped-detector image]
+     (detect-objects wrapped-detector image nil))
+    ([wrapped-detector image topk]
     (util/validate! ::wrapped-detector wrapped-detector
                     "Invalid object detector")
-    (util/validate! ::vec-of-ndarrays input-arrays
-                    "Invalid inputs")
-    (util/validate! ::nil-or-int topk "Invalid top-K")
-    (util/coerce-return-recursive
-     (.objectDetectWithNDArray (:object-detector wrapped-detector)
-                               (util/vec->indexed-seq input-arrays)
+     (util/validate! ::image image "Invalid image")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageObjectDetect (:object-detector wrapped-detector)
+                          image
+                          (util/->int-option topk)))))
+  (detect-objects-batch
+    ([wrapped-detector images]
+     (detect-objects-batch wrapped-detector images nil))
+    ([wrapped-detector images topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.imageBatchObjectDetect (:object-detector wrapped-detector)
+                               images
                                (util/->int-option topk)))))
+  (detect-objects-with-ndarrays
+    ([wrapped-detector input-arrays]
+     (detect-objects-with-ndarrays wrapped-detector input-arrays nil))
+    ([wrapped-detector input-arrays topk]
+     (util/validate! ::wrapped-detector wrapped-detector
+                     "Invalid object detector")
+     (util/validate! ::vec-of-ndarrays input-arrays
+                     "Invalid inputs")
+     (util/validate! ::nil-or-int topk "Invalid top-K")
+     (util/coerce-return-recursive
+      (.objectDetectWithNDArray (:object-detector wrapped-detector)
+                                (util/vec->indexed-seq input-arrays)
+                                (util/->int-option topk))))))
 
 (defprotocol AInferenceFactory
   (create-predictor [factory] [factory opts])
@@ -335,7 +327,7 @@
   [image input-shape-vec]
   (util/validate! ::image image "Invalid image")
   (util/validate! (s/coll-of int?) input-shape-vec "Invalid shape vector")
-  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec)))
+  (ImageClassifier/bufferedImageToPixels image (shape/->shape input-shape-vec) dtype/FLOAT32))
 
 (s/def ::image-path string?)
 (s/def ::image-paths (s/coll-of ::image-path))
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
new file mode 100644
index 00000000000..0967df2289d
--- /dev/null
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/primitives.clj
@@ -0,0 +1,46 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives
+  (:import (org.apache.mxnet MX_PRIMITIVES$MX_FLOAT MX_PRIMITIVES$MX_Double
+                             MX_PRIMITIVES$MX_PRIMITIVE_TYPE)))
+
+
+;;; Defines custom mx primitives that can be used for mathematical computations
+;;; in NDArrays to control precision. Currently Float and Double are supported
+
+;;; For purposes of automatic conversion in ndarray functions, doubles are the default;
+;; to specify using floats you must use a Float
+
+(defn mx-float
+  "Creates a MXNet float primitive"
+  [num]
+  (new MX_PRIMITIVES$MX_FLOAT num))
+
+(defn mx-double
+  "Creates a MXNet double primitive"
+  [num]
+  (new MX_PRIMITIVES$MX_Double num))
+
+(defn ->num
+  "Returns the underlying number value"
+  [primitive]
+  (.data primitive))
+
+(defn primitive? [x]
+  (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE x))
+
diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
index 21e31baa3a9..43970c0abd7 100644
--- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
+++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
@@ -19,6 +19,7 @@
   (:require [clojure.spec.alpha :as s]
             [t6.from-scala.core :refer [$ $$] :as $]
             [clojure.string :as string]
+            [org.apache.clojure-mxnet.primitives :as primitives]
             [org.apache.clojure-mxnet.shape :as mx-shape])
   (:import (org.apache.mxnet NDArray)
            (scala Product Tuple2 Tuple3)
@@ -36,7 +37,8 @@
                            "byte<>" "byte-array"
                            "java.lang.String<>" "vec-or-strings"
                            "org.apache.mxnet.NDArray" "ndarray"
-                           "org.apache.mxnet.Symbol" "sym"})
+                           "org.apache.mxnet.Symbol" "sym"
+                           "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"})
 
 (def symbol-param-coerce {"java.lang.String" "sym-name"
                           "float" "num"
@@ -144,6 +146,8 @@
     (and (get targets "int<>") (vector? param)) (int-array param)
     (and (get targets "float<>") (vector? param)) (float-array param)
     (and (get targets "java.lang.String<>") (vector? param)) (into-array param)
+    (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param)
+    (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param)
     :else param))
 
 (defn nil-or-coerce-param [param targets]
@@ -177,6 +181,7 @@
     (instance? Map return-val) (scala-map->map return-val)
     (instance? Tuple2 return-val) (tuple->vec return-val)
     (instance? Tuple3 return-val) (tuple->vec return-val)
+    (primitives/primitive? return-val) (primitives/->num return-val)
     :else return-val))
 
 (defn coerce-return-recursive [return-val]
diff --git a/contrib/clojure-package/test/good-test-ndarray.clj b/contrib/clojure-package/test/good-test-ndarray.clj
index 3b53b190600..b048a819c64 100644
--- a/contrib/clojure-package/test/good-test-ndarray.clj
+++ b/contrib/clojure-package/test/good-test-ndarray.clj
@@ -27,11 +27,12 @@
 
 (defn
  div
- ([ndarray num-or-ndarray]
+ ([ndarray ndarray-or-double-or-float]
   (util/coerce-return
    (.$div
     ndarray
     (util/coerce-param
-     num-or-ndarray
-     #{"float" "org.apache.mxnet.NDArray"})))))
+     ndarray-or-double-or-float
+     #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"
+       "org.apache.mxnet.NDArray"})))))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
index 9badfed933a..b459b06132b 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/imageclassifier_test.clj
@@ -40,7 +40,11 @@
 (deftest test-single-classification
   (let [classifier (create-classifier)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
-        [predictions] (infer/classify-image classifier image 5)]
+        [predictions-all] (infer/classify-image classifier image)
+        [predictions-with-default-dtype] (infer/classify-image classifier image 10)
+        [predictions] (infer/classify-image classifier image 5 dtype/FLOAT32)]
+    (is (= 1000 (count predictions-all)))
+    (is (= 10 (count predictions-with-default-dtype)))
     (is (some? predictions))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
@@ -58,8 +62,12 @@
   (let [classifier (create-classifier)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
-        batch-predictions (infer/classify-image-batch classifier image-batch 5)
+        batch-predictions-all (infer/classify-image-batch classifier image-batch)
+        batch-predictions-with-default-dtype (infer/classify-image-batch classifier image-batch 10)
+        batch-predictions (infer/classify-image-batch classifier image-batch 5 dtype/FLOAT32)
         predictions (first batch-predictions)]
+    (is (= 1000 (count (first batch-predictions-all))))
+    (is (= 10 (count (first batch-predictions-with-default-dtype))))
     (is (some? batch-predictions))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
index 788a5949109..3a0e3d30a1d 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/infer/objectdetector_test.clj
@@ -40,9 +40,11 @@
 (deftest test-single-detection
   (let [detector (create-detector)
         image (infer/load-image-from-file "test/test-images/kitten.jpg")
+        [predictions-all] (infer/detect-objects detector image)
         [predictions] (infer/detect-objects detector image 5)]
     (is (some? predictions))
     (is (= 5 (count predictions)))
+    (is (= 13 (count predictions-all)))
     (is (every? #(= 2 (count %)) predictions))
     (is (every? #(string? (first %)) predictions))
     (is (every? #(= 5 (count (second %))) predictions))
@@ -53,9 +55,11 @@
   (let [detector (create-detector)
         image-batch (infer/load-image-paths ["test/test-images/kitten.jpg"
                                              "test/test-images/Pug-Cookie.jpg"])
+        batch-predictions-all (infer/detect-objects-batch detector image-batch)
         batch-predictions (infer/detect-objects-batch detector image-batch 5)
         predictions (first batch-predictions)]
     (is (some? batch-predictions))
+    (is (= 13 (count (first batch-predictions-all))))
     (is (= 5 (count predictions)))
     (is (every? #(= 2 (count %)) predictions))
     (is (every? #(string? (first %)) predictions))
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
index 79e94412d0d..9ffd3abed2f 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_test.clj
@@ -97,7 +97,7 @@
     (is (= [1.0 1.0] (->vec ndhalves)))))
 
 (deftest test-full
-  (let [nda (full [1 2] 3)]
+  (let [nda (full [1 2] 3.0)]
     (is (= (shape nda) (mx-shape/->shape [1 2])))
     (is (= [3.0 3.0] (->vec nda)))))
 
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
new file mode 100644
index 00000000000..1a538e537b8
--- /dev/null
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/primitives_test.clj
@@ -0,0 +1,45 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements.  See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.  You may obtain a copy of the License at
+;;
+;;    http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns org.apache.clojure-mxnet.primitives-test
+  (:require [org.apache.clojure-mxnet.primitives :as primitives]
+            [clojure.test :refer :all])
+  (:import (org.apache.mxnet MX_PRIMITIVES$MX_PRIMITIVE_TYPE
+                             MX_PRIMITIVES$MX_FLOAT
+                             MX_PRIMITIVES$MX_Double)))
+
+(deftest test-primitive-types
+  (is (not (primitives/primitive? 3)))
+  (is (primitives/primitive? (primitives/mx-float 3)))
+  (is (primitives/primitive? (primitives/mx-double 3))))
+
+(deftest test-float-primitives
+  (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-float 3)))
+  (is (instance? MX_PRIMITIVES$MX_FLOAT (primitives/mx-float 3)))
+  (is (instance? Float (-> (primitives/mx-float 3)
+                           (primitives/->num))))
+  (is (= 3.0 (-> (primitives/mx-float 3)
+                 (primitives/->num)))))
+
+(deftest test-double-primitives
+  (is (instance? MX_PRIMITIVES$MX_PRIMITIVE_TYPE (primitives/mx-double 2)))
+  (is (instance? MX_PRIMITIVES$MX_Double (primitives/mx-double 2)))
+  (is (instance? Double (-> (primitives/mx-double 2)
+                            (primitives/->num))))
+  (is (= 2.0 (-> (primitives/mx-double 2)
+                 (primitives/->num)))))
+
diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
index bd77a8a0edc..c26f83d5aa4 100644
--- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
+++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj
@@ -20,6 +20,7 @@
             [org.apache.clojure-mxnet.shape :as mx-shape]
             [org.apache.clojure-mxnet.util :as util]
             [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.primitives :as primitives]
             [org.apache.clojure-mxnet.symbol :as sym]
             [org.apache.clojure-mxnet.test-util :as test-util]
             [clojure.spec.alpha :as s])
@@ -133,6 +134,9 @@
   (is (= "[F"  (->> (util/coerce-param [1 2] #{"float<>"}) str (take 2) (apply str))))
   (is (= "[L"  (->> (util/coerce-param [1 2] #{"java.lang.String<>"}) str (take 2) (apply str))))
 
+  (is (primitives/primitive? (util/coerce-param 1.0 #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+  (is (primitives/primitive? (util/coerce-param (float 1.0) #{"org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE"})))
+
   (is (= 1 (util/coerce-param 1 #{"unknown"}))))
 
 (deftest test-nil-or-coerce-param
@@ -171,6 +175,12 @@
                 (util/convert-tuple [1 2]))))
   (is (= [1 2 3] (util/coerce-return
                   (util/convert-tuple [1 2 3]))))
+
+  (is (instance? Double (util/coerce-return (primitives/mx-double 3))))
+  (is (= 3.0 (util/coerce-return (primitives/mx-double 3))))
+  (is (instance? Float (util/coerce-return (primitives/mx-float 2))))
+  (is (= 2.0 (util/coerce-return (primitives/mx-float 2))))
+
   (is (= "foo" (util/coerce-return "foo"))))
 
 (deftest test-translate-keyword-shape
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
index ed7aff602f6..001bd04d2c9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
@@ -18,7 +18,9 @@
 package org.apache.mxnet
 
 import org.apache.mxnet.util.NativeLibraryLoader
-import org.slf4j.{LoggerFactory, Logger}
+import org.slf4j.{Logger, LoggerFactory}
+
+import scala.Specializable.Group
 
 private[mxnet] object Base {
   private val logger: Logger = LoggerFactory.getLogger("MXNetJVM")
@@ -57,6 +59,9 @@ private[mxnet] object Base {
 
   val MX_REAL_TYPE = DType.Float32
 
+  // The primitives currently supported for NDArray operations
+  val MX_PRIMITIVES = new Group ((Double, Float))
+
   try {
     try {
       tryLoadLibraryOS("mxnet-scala")
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 0a5683aa7ab..20b6ed9fc80 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -93,6 +93,9 @@ private[mxnet] class LibInfo {
   @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
                                        source: Array[MXFloat],
                                        size: Int): Int
+  @native def mxFloat64NDArraySyncCopyFromCPU(handle: NDArrayHandle,
+                                       source: Array[Double],
+                                       size: Int): Int
   @native def mxNDArrayLoad(fname: String,
                             outSize: MXUintRef,
                             handles: ArrayBuffer[NDArrayHandle],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
new file mode 100644
index 00000000000..cb978856963
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/MX_PRIMITIVES.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+object MX_PRIMITIVES {
+
+  /**
+    * This defines the basic primitives we can use in Scala for mathematical
+    * computations in NDArrays. This gives us the flexibility to expand to
+    * more supported primitives in the future. Currently Float and Double
+    * are supported. The functions which accept MX_PRIMITIVE_TYPE as input can also accept
+    * plain old Float and Double data as inputs because of the underlying
+    * implicit conversion from primitives to MX_PRIMITIVE_TYPE.
+    */
+  trait MX_PRIMITIVE_TYPE extends Ordered[MX_PRIMITIVE_TYPE]{
+
+    def toString: String
+
+    def unary_- : MX_PRIMITIVE_TYPE
+  }
+
+  trait MXPrimitiveOrdering extends Ordering[MX_PRIMITIVE_TYPE] {
+
+    def compare(x: MX_PRIMITIVE_TYPE, y: MX_PRIMITIVE_TYPE): Int = x.compare(y)
+
+  }
+
+  implicit object MX_PRIMITIVE_TYPE extends MXPrimitiveOrdering
+
+  /**
+    * Wrapper over Float in Scala.
+    * @param data
+    */
+  class MX_FLOAT(val data: Float) extends MX_PRIMITIVE_TYPE {
+
+    override def toString: String = data.toString
+
+    override def unary_- : MX_PRIMITIVE_TYPE = new MX_FLOAT(data.unary_-)
+
+    override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+      this.data.compareTo(that.asInstanceOf[MX_FLOAT].data)
+    }
+  }
+
+  implicit def FloatToMX_Float(d : Float): MX_FLOAT = new MX_FLOAT(d)
+
+  implicit def MX_FloatToFloat(d: MX_FLOAT) : Float = d.data
+
+  implicit def IntToMX_Float(d: Int): MX_FLOAT = new MX_FLOAT(d.toFloat)
+
+  /**
+    * Wrapper over Double in Scala.
+    * @param data
+    */
+  class MX_Double(val data: Double) extends MX_PRIMITIVE_TYPE {
+
+    override def toString: String = data.toString
+
+    override def unary_- : MX_PRIMITIVE_TYPE = new MX_Double(data.unary_-)
+
+    override def compare(that: MX_PRIMITIVE_TYPE): Int = {
+      this.data.compareTo(that.asInstanceOf[MX_Double].data)
+    }
+  }
+
+  implicit def DoubleToMX_Double(d : Double): MX_Double = new MX_Double(d)
+
+  implicit def MX_DoubleToDouble(d: MX_Double) : Double = d.data
+
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 125958150b7..163ed268253 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
+import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
 import org.slf4j.LoggerFactory
 
 import scala.collection.mutable
@@ -262,16 +263,46 @@ object NDArray extends NDArrayBase {
     arr
   }
 
-  // Perform power operator
+  def full(shape: Shape, value: Double, ctx: Context): NDArray = {
+    val arr = empty(shape, ctx, DType.Float64)
+    arr.set(value)
+    arr
+  }
+
+  /**
+    * Create a new NDArray filled with given value, with specified shape.
+    * @param shape shape of the NDArray.
+    * @param value value to be filled with
+    */
+  def full(shape: Shape, value: Double): NDArray = {
+    full(shape, value, null)
+  }
+
+
+  /**
+    * Perform power operation on NDArray. Returns result as NDArray
+    * @param lhs
+    * @param rhs
+    */
   def power(lhs: NDArray, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power", Seq(lhs, rhs))
   }
 
-  def power(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform scalar power operation on NDArray. Returns result as NDArray
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def power(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(lhs, rhs))
   }
 
-  def power(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform scalar power operation on NDArray. Returns result as NDArray
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def power(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_rpower_scalar", Seq(lhs, rhs))
   }
 
@@ -280,11 +311,21 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("_maximum", Seq(lhs, rhs))
   }
 
-  def maximum(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform the max operation on NDArray. Returns the result as NDArray.
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def maximum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
   }
 
-  def maximum(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform the max operation on NDArray. Returns the result as NDArray.
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def maximum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_maximum_scalar", Seq(lhs, rhs))
   }
 
@@ -293,11 +334,21 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("_minimum", Seq(lhs, rhs))
   }
 
-  def minimum(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Perform the min operation on NDArray. Returns the result as NDArray.
+    * @param lhs NDArray on which to perform the operation on.
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def minimum(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
   }
 
-  def minimum(lhs: Float, rhs: NDArray): NDArray = {
+  /**
+    * Perform the min operation on NDArray. Returns the result as NDArray.
+    * @param lhs The scalar input. Can be of type Float/Double
+    * @param rhs NDArray on which to perform the operation on.
+    */
+  def minimum(lhs: MX_PRIMITIVE_TYPE, rhs: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minimum_scalar", Seq(lhs, rhs))
   }
 
@@ -310,7 +361,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_equal", Seq(lhs, rhs))
   }
 
-  def equal(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **equal to** (==) comparison operation with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are the same,
+    * otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def equal(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -324,7 +383,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_not_equal", Seq(lhs, rhs))
   }
 
-  def notEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **not equal to** (!=) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if corresponding elements are different,
+    * otherwise return 0(false).
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def notEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_not_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -338,7 +405,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_greater", Seq(lhs, rhs))
   }
 
-  def greater(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **greater than** (>) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than rhs,
+    * otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def greater(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_greater_scalar", Seq(lhs, rhs))
   }
 
@@ -352,7 +428,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_greater_equal", Seq(lhs, rhs))
   }
 
-  def greaterEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **greater than or equal to** (>=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are greater than or equal to
+    * rhs, otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def greaterEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_greater_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -366,7 +451,15 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_lesser", Seq(lhs, rhs))
   }
 
-  def lesser(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **lesser than** (<) comparison operation
+    * with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are less than rhs,
+    * otherwise return 0(false).
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def lesser(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_lesser_scalar", Seq(lhs, rhs))
   }
 
@@ -380,7 +473,16 @@ object NDArray extends NDArrayBase {
     NDArray.genericNDArrayFunctionInvoke("broadcast_lesser_equal", Seq(lhs, rhs))
   }
 
-  def lesserEqual(lhs: NDArray, rhs: Float): NDArray = {
+  /**
+    * Returns the result of element-wise **lesser than or equal to** (<=) comparison
+    * operation with broadcasting.
+    * For each element in input arrays, return 1(true) if lhs elements are
+    * less than or equal to rhs, otherwise return 0(false).
+    *
+    * @param lhs NDArray
+    * @param rhs The scalar input. Can be of type Float/Double
+    */
+  def lesserEqual(lhs: NDArray, rhs: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_lesser_equal_scalar", Seq(lhs, rhs))
   }
 
@@ -397,6 +499,16 @@ object NDArray extends NDArrayBase {
     arr
   }
 
+  def array(sourceArr: Array[Double], shape: Shape, ctx: Context): NDArray = {
+    val arr = empty(shape, ctx, dtype = DType.Float64)
+    arr.set(sourceArr)
+    arr
+  }
+
+  def array(sourceArr: Array[Double], shape: Shape): NDArray = {
+    array(sourceArr, shape, null)
+  }
+
   /**
    * Returns evenly spaced values within a given interval.
    * Values are generated within the half-open interval [`start`, `stop`). In other
@@ -645,6 +757,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     checkCall(_LIB.mxNDArraySyncCopyFromCPU(handle, source, source.length))
   }
 
+  private def syncCopyfrom(source: Array[Double]): Unit = {
+    require(source.length == size,
+      s"array size (${source.length}) do not match the size of NDArray ($size)")
+    checkCall(_LIB.mxFloat64NDArraySyncCopyFromCPU(handle, source, source.length))
+  }
+
   /**
    * Return a sliced NDArray that shares memory with current one.
    * NDArray only support continuous slicing on axis 0
@@ -759,7 +877,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @param value Value to set
    * @return Current NDArray
    */
-  def set(value: Float): NDArray = {
+  def set(value: MX_PRIMITIVE_TYPE): NDArray = {
     require(writable, "trying to assign to a readonly NDArray")
     NDArray.genericNDArrayFunctionInvoke("_set_value", Seq(value), Map("out" -> this))
     this
@@ -776,11 +894,17 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
+  def set(other: Array[Double]): NDArray = {
+    require(writable, "trying to assign to a readonly NDArray")
+    syncCopyfrom(other)
+    this
+  }
+
   def +(other: NDArray): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_plus", Seq(this, other))
   }
 
-  def +(other: Float): NDArray = {
+  def +(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_plus_scalar", Seq(this, other))
   }
 
@@ -792,7 +916,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def +=(other: Float): NDArray = {
+  def +=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to add to a readonly NDArray")
     }
@@ -804,7 +928,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_minus", Seq(this, other))
   }
 
-  def -(other: Float): NDArray = {
+  def -(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_minus_scalar", Seq(this, other))
   }
 
@@ -816,7 +940,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def -=(other: Float): NDArray = {
+  def -=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to subtract from a readonly NDArray")
     }
@@ -828,7 +952,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_mul", Seq(this, other))
   }
 
-  def *(other: Float): NDArray = {
+  def *(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_mul_scalar", Seq(this, other))
   }
 
@@ -844,7 +968,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def *=(other: Float): NDArray = {
+  def *=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to multiply to a readonly NDArray")
     }
@@ -856,7 +980,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_div", Seq(this, other))
   }
 
-  def /(other: Float): NDArray = {
+  def /(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_div_scalar", Seq(this, other))
   }
 
@@ -868,7 +992,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def /=(other: Float): NDArray = {
+  def /=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to divide from a readonly NDArray")
     }
@@ -880,7 +1004,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.power(this, other)
   }
 
-  def **(other: Float): NDArray = {
+  def **(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.power(this, other)
   }
 
@@ -888,7 +1012,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_power", Seq(this, other), Map("out" -> this))
   }
 
-  def **=(other: Float): NDArray = {
+  def **=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_power_scalar", Seq(this, other), Map("out" -> this))
   }
 
@@ -896,7 +1020,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.greater(this, other)
   }
 
-  def >(other: Float): NDArray = {
+  def >(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.greater(this, other)
   }
 
@@ -904,7 +1028,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.greaterEqual(this, other)
   }
 
-  def >=(other: Float): NDArray = {
+  def >=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.greaterEqual(this, other)
   }
 
@@ -912,7 +1036,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.lesser(this, other)
   }
 
-  def <(other: Float): NDArray = {
+  def <(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.lesser(this, other)
   }
 
@@ -920,7 +1044,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.lesserEqual(this, other)
   }
 
-  def <=(other: Float): NDArray = {
+  def <=(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.lesserEqual(this, other)
   }
 
@@ -928,7 +1052,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other))
   }
 
-  def %(other: Float): NDArray = {
+  def %(other: MX_PRIMITIVE_TYPE): NDArray = {
     NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other))
   }
 
@@ -940,7 +1064,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this
   }
 
-  def %=(other: Float): NDArray = {
+  def %=(other: MX_PRIMITIVE_TYPE): NDArray = {
     if (!writable) {
       throw new IllegalArgumentException("trying to take modulo from a readonly NDArray")
     }
@@ -956,6 +1080,14 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     internal.toFloatArray
   }
 
+  /**
+    * Return a copied flat java array of current array (row-major) with datatype as Float64/Double.
+    * @return  A copy of array content.
+    */
+  def toFloat64Array: Array[Double] = {
+    internal.toDoubleArray
+  }
+
   def internal: NDArrayInternal = {
     val myType = dtype
     val arrLength = DType.numOfBytes(myType) * size
@@ -975,6 +1107,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     this.toArray(0)
   }
 
+  def toFloat64Scalar: Double = {
+    require(shape == Shape(1), "The current array is not a scalar")
+    this.toFloat64Array(0)
+  }
+
   /**
    * Copy the content of current array to other.
    *
@@ -997,7 +1134,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @return The copy target NDArray
    */
   def copyTo(ctx: Context): NDArray = {
-    val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true))
+    val ret = new NDArray(NDArray.newAllocHandle(shape, ctx, delayAlloc = true, dtype = dtype))
     copyTo(ret)
   }
 
@@ -1047,11 +1184,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
 
 private[mxnet] object NDArrayConversions {
   implicit def int2Scalar(x: Int): NDArrayConversions = new NDArrayConversions(x.toFloat)
-  implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x.toFloat)
+  implicit def double2Scalar(x: Double): NDArrayConversions = new NDArrayConversions(x)
   implicit def float2Scalar(x: Float): NDArrayConversions = new NDArrayConversions(x)
 }
 
-private[mxnet] class NDArrayConversions(val value: Float) {
+private[mxnet] class NDArrayConversions(val value: MX_PRIMITIVE_TYPE) {
   def +(other: NDArray): NDArray = {
     other + value
   }
@@ -1145,34 +1282,39 @@ private[mxnet] class NDArrayFuncReturn(private[mxnet] val arr: Array[NDArray]) {
   def waitToRead(): Unit = head.waitToRead()
   def context: Context = head.context
   def set(value: Float): NDArray = head.set(value)
+  def set(value: Double): NDArray = head.set(value)
   def set(other: NDArray): NDArray = head.set(other)
   def set(other: Array[Float]): NDArray = head.set(other)
+  def set(other: Array[Double]): NDArray = head.set(other)
   def +(other: NDArray): NDArray = head + other
-  def +(other: Float): NDArray = head + other
+  def +(other: MX_PRIMITIVE_TYPE): NDArray = head + other
   def +=(other: NDArray): NDArray = head += other
-  def +=(other: Float): NDArray = head += other
+  def +=(other: MX_PRIMITIVE_TYPE): NDArray = head += other
   def -(other: NDArray): NDArray = head - other
-  def -(other: Float): NDArray = head - other
+  def -(other: MX_PRIMITIVE_TYPE): NDArray = head - other
   def -=(other: NDArray): NDArray = head -= other
-  def -=(other: Float): NDArray = head -= other
+  def -=(other: MX_PRIMITIVE_TYPE): NDArray = head -= other
   def *(other: NDArray): NDArray = head * other
-  def *(other: Float): NDArray = head * other
+  def *(other: MX_PRIMITIVE_TYPE): NDArray = head * other
   def unary_-(): NDArray = -head
   def *=(other: NDArray): NDArray = head *= other
-  def *=(other: Float): NDArray = head *= other
+  def *=(other: MX_PRIMITIVE_TYPE): NDArray = head *= other
   def /(other: NDArray): NDArray = head / other
+  def /(other: MX_PRIMITIVE_TYPE): NDArray = head / other
   def **(other: NDArray): NDArray = head ** other
-  def **(other: Float): NDArray = head ** other
+  def **(other: MX_PRIMITIVE_TYPE): NDArray = head ** other
   def >(other: NDArray): NDArray = head > other
-  def >(other: Float): NDArray = head > other
+  def >(other: MX_PRIMITIVE_TYPE): NDArray = head > other
   def >=(other: NDArray): NDArray = head >= other
-  def >=(other: Float): NDArray = head >= other
+  def >=(other: MX_PRIMITIVE_TYPE): NDArray = head >= other
   def <(other: NDArray): NDArray = head < other
-  def <(other: Float): NDArray = head < other
+  def <(other: MX_PRIMITIVE_TYPE): NDArray = head < other
   def <=(other: NDArray): NDArray = head <= other
-  def <=(other: Float): NDArray = head <= other
+  def <=(other: MX_PRIMITIVE_TYPE): NDArray = head <= other
   def toArray: Array[Float] = head.toArray
+  def toFloat64Array: Array[Double] = head.toFloat64Array
   def toScalar: Float = head.toScalar
+  def toFloat64Scalar: Double = head.toFloat64Scalar
   def copyTo(other: NDArray): NDArray = head.copyTo(other)
   def copyTo(ctx: Context): NDArray = head.copyTo(ctx)
   def copy(): NDArray = head.copy()
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index a84bd106b76..e30098c3088 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -53,9 +53,9 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
       val label = currentBatch.label(0)
       // properties
       val res = (
-        // TODO: need to allow user to specify DType and Layout
-        IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, Layout.UNDEFINED)),
-        IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, Layout.UNDEFINED)),
+        // TODO: need to allow user to specify Layout
+        IndexedSeq(new DataDesc(dataName, data.shape, data.dtype, Layout.UNDEFINED)),
+        IndexedSeq(new DataDesc(labelName, label.shape, label.dtype, Layout.UNDEFINED)),
         ListMap(dataName -> data.shape),
         ListMap(labelName -> label.shape),
         data.shape(0))
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 0032a54dd80..e690abba0d1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -61,7 +61,8 @@ class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
            dataBatchSize: Int = 1, shuffle: Boolean = false,
            lastBatchHandle: String = "pad",
            dataName: String = "data", labelName: String = "label") {
-    this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, Layout.UNDEFINED),
+    this(IO.initDataDesc(data, allowEmpty = false, dataName,
+      if (data == null || data.isEmpty)  MX_REAL_TYPE else data(0).dtype, Layout.UNDEFINED),
       IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
   }
@@ -272,7 +273,7 @@ object NDArrayIter {
      */
     def addData(name: String, data: NDArray): Builder = {
       this.data = this.data ++ IndexedSeq((new DataDesc(name,
-        data.shape, DType.Float32, Layout.UNDEFINED), data))
+        data.shape, data.dtype, Layout.UNDEFINED), data))
       this
     }
 
@@ -284,7 +285,7 @@ object NDArrayIter {
      */
     def addLabel(name: String, label: NDArray): Builder = {
       this.label = this.label ++ IndexedSeq((new DataDesc(name,
-        label.shape, DType.Float32, Layout.UNDEFINED), label))
+        label.shape, label.dtype, Layout.UNDEFINED), label))
       this
     }
 
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
index 198102d2377..67809c158af 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/javaapi/NDArray.scala
@@ -91,17 +91,26 @@ object NDArray extends NDArrayBase {
   def full(shape: Shape, value: Float, ctx: Context): NDArray
   = org.apache.mxnet.NDArray.full(shape, value, ctx)
 
+  def full(shape: Shape, value: Double, ctx: Context): NDArray
+  = org.apache.mxnet.NDArray.full(shape, value, ctx)
+
   def power(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
   def power(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
   def power(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
+  def power(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.power(lhs, rhs)
 
   def maximum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
   def maximum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
   def maximum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
+  def maximum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.maximum(lhs, rhs)
 
   def minimum(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
   def minimum(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
   def minimum(lhs: Float, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
+  def minimum(lhs: Double, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.minimum(lhs, rhs)
 
 
   /**
@@ -111,6 +120,7 @@ object NDArray extends NDArrayBase {
     */
   def equal(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
   def equal(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
+  def equal(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.equal(lhs, rhs)
 
   /**
     * Returns the result of element-wise **not equal to** (!=) comparison operation
@@ -120,6 +130,7 @@ object NDArray extends NDArrayBase {
     */
   def notEqual(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
   def notEqual(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
+  def notEqual(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.notEqual(lhs, rhs)
 
   /**
     * Returns the result of element-wise **greater than** (>) comparison operation
@@ -129,6 +140,7 @@ object NDArray extends NDArrayBase {
     */
   def greater(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
   def greater(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
+  def greater(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.greater(lhs, rhs)
 
   /**
     * Returns the result of element-wise **greater than or equal to** (>=) comparison
@@ -140,6 +152,8 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
   def greaterEqual(lhs: NDArray, rhs: Float): NDArray
   = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
+  def greaterEqual(lhs: NDArray, rhs: Double): NDArray
+  = org.apache.mxnet.NDArray.greaterEqual(lhs, rhs)
 
   /**
     * Returns the result of element-wise **lesser than** (<) comparison operation
@@ -149,6 +163,7 @@ object NDArray extends NDArrayBase {
     */
   def lesser(lhs: NDArray, rhs: NDArray): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
   def lesser(lhs: NDArray, rhs: Float): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
+  def lesser(lhs: NDArray, rhs: Double): NDArray = org.apache.mxnet.NDArray.lesser(lhs, rhs)
 
   /**
     * Returns the result of element-wise **lesser than or equal to** (<=) comparison
@@ -160,6 +175,8 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
   def lesserEqual(lhs: NDArray, rhs: Float): NDArray
   = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
+  def lesserEqual(lhs: NDArray, rhs: Double): NDArray
+  = org.apache.mxnet.NDArray.lesserEqual(lhs, rhs)
 
   /**
     * Create a new NDArray that copies content from source_array.
@@ -172,6 +189,18 @@ object NDArray extends NDArrayBase {
   = org.apache.mxnet.NDArray.array(
     sourceArr.asScala.map(ele => Float.unbox(ele)).toArray, shape, ctx)
 
+  /**
+    * Create a new NDArray that copies content from source_array.
+    * @param sourceArr Source data (list of Doubles) to create NDArray from.
+    * @param shape shape of the NDArray
+    * @param ctx The context of the NDArray, default to current default context.
+    * @return The created NDArray.
+    */
+  def arrayWithDouble(sourceArr: java.util.List[java.lang.Double], shape: Shape,
+                      ctx: Context = null): NDArray
+  = org.apache.mxnet.NDArray.array(
+    sourceArr.asScala.map(ele => Double.unbox(ele)).toArray, shape)
+
   /**
     * Returns evenly spaced values within a given interval.
     * Values are generated within the half-open interval [`start`, `stop`). In other
@@ -205,6 +234,10 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
   }
 
+  def this(arr: Array[Double], shape: Shape, ctx: Context) = {
+    this(org.apache.mxnet.NDArray.array(arr, shape, ctx))
+  }
+
   def this(arr: java.util.List[java.lang.Float], shape: Shape, ctx: Context) = {
     this(NDArray.array(arr, shape, ctx))
   }
@@ -304,41 +337,59 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     * @return Current NDArray
     */
   def set(value: Float): NDArray = nd.set(value)
+  def set(value: Double): NDArray = nd.set(value)
   def set(other: NDArray): NDArray = nd.set(other)
   def set(other: Array[Float]): NDArray = nd.set(other)
+  def set(other: Array[Double]): NDArray = nd.set(other)
 
   def add(other: NDArray): NDArray = this.nd + other.nd
   def add(other: Float): NDArray = this.nd + other
+  def add(other: Double): NDArray = this.nd + other
   def addInplace(other: NDArray): NDArray = this.nd += other
   def addInplace(other: Float): NDArray = this.nd += other
+  def addInplace(other: Double): NDArray = this.nd += other
   def subtract(other: NDArray): NDArray = this.nd - other
   def subtract(other: Float): NDArray = this.nd - other
+  def subtract(other: Double): NDArray = this.nd - other
   def subtractInplace(other: NDArray): NDArray = this.nd -= other
   def subtractInplace(other: Float): NDArray = this.nd -= other
+  def subtractInplace(other: Double): NDArray = this.nd -= other
   def multiply(other: NDArray): NDArray = this.nd * other
   def multiply(other: Float): NDArray = this.nd * other
+  def multiply(other: Double): NDArray = this.nd * other
   def multiplyInplace(other: NDArray): NDArray = this.nd *= other
   def multiplyInplace(other: Float): NDArray = this.nd *= other
+  def multiplyInplace(other: Double): NDArray = this.nd *= other
   def div(other: NDArray): NDArray = this.nd / other
   def div(other: Float): NDArray = this.nd / other
+  def div(other: Double): NDArray = this.nd / other
   def divInplace(other: NDArray): NDArray = this.nd /= other
   def divInplace(other: Float): NDArray = this.nd /= other
+  def divInplace(other: Double): NDArray = this.nd /= other
   def pow(other: NDArray): NDArray = this.nd ** other
   def pow(other: Float): NDArray = this.nd ** other
+  def pow(other: Double): NDArray = this.nd ** other
   def powInplace(other: NDArray): NDArray = this.nd **= other
   def powInplace(other: Float): NDArray = this.nd **= other
+  def powInplace(other: Double): NDArray = this.nd **= other
   def mod(other: NDArray): NDArray = this.nd % other
   def mod(other: Float): NDArray = this.nd % other
+  def mod(other: Double): NDArray = this.nd % other
   def modInplace(other: NDArray): NDArray = this.nd %= other
   def modInplace(other: Float): NDArray = this.nd %= other
+  def modInplace(other: Double): NDArray = this.nd %= other
   def greater(other: NDArray): NDArray = this.nd > other
   def greater(other: Float): NDArray = this.nd > other
+  def greater(other: Double): NDArray = this.nd > other
   def greaterEqual(other: NDArray): NDArray = this.nd >= other
   def greaterEqual(other: Float): NDArray = this.nd >= other
+  def greaterEqual(other: Double): NDArray = this.nd >= other
   def lesser(other: NDArray): NDArray = this.nd < other
   def lesser(other: Float): NDArray = this.nd < other
+  def lesser(other: Double): NDArray = this.nd < other
   def lesserEqual(other: NDArray): NDArray = this.nd <= other
   def lesserEqual(other: Float): NDArray = this.nd <= other
+  def lesserEqual(other: Double): NDArray = this.nd <= other
 
   /**
     * Return a copied flat java array of current array (row-major).
@@ -346,6 +397,12 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     */
   def toArray: Array[Float] = nd.toArray
 
+  /**
+    * Return a copied flat java array of current array (row-major).
+    * @return  A copy of array content.
+    */
+  def toFloat64Array: Array[Double] = nd.toFloat64Array
+
   /**
     * Return a CPU scalar(float) of current ndarray.
     * This ndarray must have shape (1,)
@@ -354,6 +411,14 @@ class NDArray private[mxnet] (val nd: org.apache.mxnet.NDArray ) {
     */
   def toScalar: Float = nd.toScalar
 
+  /**
+    * Return a CPU scalar(double) of current ndarray.
+    * This ndarray must have shape (1,)
+    *
+    * @return The scalar representation of the ndarray.
+    */
+  def toFloat64Scalar: Double = nd.toFloat64Scalar
+
   /**
     * Copy the content of current array to other.
     *
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b7848bc..86c7eb29d2e 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -40,6 +40,15 @@ public void testCreateNDArray() {
                 new Shape(new int[]{1, 3}),
                 new Context("cpu", 0));
         assertTrue(Arrays.equals(nd.shape().toArray(), arr));
+
+        List<Double> list2 = Arrays.asList(1d, 1d, 1d);
+        nd = NDArray.arrayWithDouble(list2,
+                new Shape(new int[]{1, 3}),
+                new Context("cpu", 0));
+
+        // Float64 assertion
+        assertTrue(nd.dtype() == DType.Float64());
+
     }
 
     @Test
@@ -64,6 +73,12 @@ public void testComparison(){
         nd = nd.subtract(nd2);
         float[] lesser = new float[]{0, 0, 0};
         assertTrue(Arrays.equals(nd.greater(nd2).toArray(), lesser));
+
+        NDArray nd3 = new NDArray(new double[]{1.0, 2.0, 3.0}, new Shape(new int[]{3}), new Context("cpu", 0));
+        nd3 = nd3.add(1.0);
+        double[] smaller = new double[] {2, 3, 4};
+        assertTrue(Arrays.equals(smaller, nd3.toFloat64Array()));
+
     }
 
     @Test
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 2ec6f668dbc..d3969b0ce77 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -303,5 +303,32 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(dataDesc(0).layout == Layout.NTC)
     assert(labelDesc(0).dtype == DType.Int32)
     assert(labelDesc(0).layout == Layout.NT)
+
+
+    // Test with Float64 passed explicitly as the DType of the data
+    val dataIter4 = new NDArrayIter(
+      IO.initDataDesc(data, false, "data", DType.Float64, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc4 = dataIter4.provideDataDesc
+    val labelDesc4 = dataIter4.provideLabelDesc
+    assert(dataDesc4(0).dtype == DType.Float64)
+    assert(dataDesc4(0).layout == Layout.NTC)
+    assert(labelDesc4(0).dtype == DType.Int32)
+    assert(labelDesc4(0).layout == Layout.NT)
+
+    // Test with Float64 coming from the data itself
+    val dataF64 = IndexedSeq(NDArray.ones(shape0, dtype = DType.Float64),
+      NDArray.zeros(shape0, dtype = DType.Float64))
+
+    val dataIter5 = new NDArrayIter(
+      IO.initDataDesc(dataF64, false, "data", DType.Float64, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc5 = dataIter5.provideDataDesc
+    assert(dataDesc5(0).dtype == DType.Float64)
+    assert(dataDesc5(0).dtype != DType.Float32)
+    assert(dataDesc5(0).layout == Layout.NTC)
+
   }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 2f3b1676d27..bc7a0a026bc 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -21,7 +21,7 @@ import java.io.File
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.mxnet.NDArrayConversions._
-import org.scalatest.{Matchers, BeforeAndAfterAll, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
 
 class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   private val sequence: AtomicInteger = new AtomicInteger(0)
@@ -29,6 +29,9 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("to java array") {
     val ndarray = NDArray.zeros(2, 2)
     assert(ndarray.toArray === Array(0f, 0f, 0f, 0f))
+
+    val float64Array = NDArray.zeros(Shape(2, 2), dtype = DType.Float64)
+    assert(float64Array.toFloat64Array === Array(0d, 0d, 0d, 0d))
   }
 
   test("to scalar") {
@@ -38,8 +41,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(ndones.toScalar === 1f)
   }
 
+  test("to float 64 scalar") {
+    val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64)
+    assert(ndzeros.toFloat64Scalar === 0d)
+    val ndones = NDArray.ones(Shape(1), dtype = DType.Float64)
+    assert(ndones.toFloat64Scalar === 1d)
+  }
+
   test ("call toScalar on an ndarray which is not a scalar") {
     intercept[Exception] { NDArray.zeros(1, 1).toScalar }
+    intercept[Exception] { NDArray.zeros(shape = Shape (1, 1),
+      dtype = DType.Float64).toFloat64Scalar }
   }
 
   test("size and shape") {
@@ -51,12 +63,20 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("dtype") {
     val arr = NDArray.zeros(3, 2)
     assert(arr.dtype === DType.Float32)
+
+    val float64Array = NDArray.zeros(shape = Shape(3, 2), dtype = DType.Float64)
+    assert(float64Array.dtype === DType.Float64)
   }
 
   test("set scalar value") {
     val ndarray = NDArray.empty(2, 1)
     ndarray.set(10f)
     assert(ndarray.toArray === Array(10f, 10f))
+
+    val float64array = NDArray.empty(shape = Shape(2, 1), dtype = DType.Float64)
+    float64array.set(10d)
+    assert(float64array.toFloat64Array === Array(10d, 10d))
+
   }
 
   test("copy from java array") {
@@ -66,19 +86,29 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("plus") {
-    val ndzeros = NDArray.zeros(2, 1)
-    val ndones = ndzeros + 1f
+    var ndzeros = NDArray.zeros(2, 1)
+    var ndones = ndzeros + 1f
     assert(ndones.toArray === Array(1f, 1f))
     assert((ndones + ndzeros).toArray === Array(1f, 1f))
     assert((1 + ndones).toArray === Array(2f, 2f))
     // in-place
     ndones += ndones
     assert(ndones.toArray === Array(2f, 2f))
+
+    // Float64 method test
+    ndzeros = NDArray.zeros(shape = Shape(2, 1), dtype = DType.Float64)
+    ndones = ndzeros + 1d
+    assert(ndones.toFloat64Array === Array(1d, 1d))
+    assert((ndones + ndzeros).toFloat64Array === Array(1d, 1d))
+    assert((1d + ndones).toArray === Array(2d, 2d))
+    // in-place
+    ndones += ndones
+    assert(ndones.toFloat64Array === Array(2d, 2d))
   }
 
   test("minus") {
-    val ndones = NDArray.ones(2, 1)
-    val ndzeros = ndones - 1f
+    var ndones = NDArray.ones(2, 1)
+    var ndzeros = ndones - 1f
     assert(ndzeros.toArray === Array(0f, 0f))
     assert((ndones - ndzeros).toArray === Array(1f, 1f))
     assert((ndzeros - ndones).toArray === Array(-1f, -1f))
@@ -86,23 +116,46 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     // in-place
     ndones -= ndones
     assert(ndones.toArray === Array(0f, 0f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape(2, 1))
+    ndzeros = ndones - 1d
+    assert(ndzeros.toFloat64Array === Array(0d, 0d))
+    assert((ndones - ndzeros).toFloat64Array === Array(1d , 1d))
+    assert((ndzeros - ndones).toFloat64Array === Array(-1d , -1d))
+    assert((ndones - 1).toFloat64Array === Array(0d, 0d))
+    // in-place
+    ndones -= ndones
+    assert(ndones.toArray === Array(0d, 0d))
+
   }
 
   test("multiplication") {
-    val ndones = NDArray.ones(2, 1)
-    val ndtwos = ndones * 2
+    var ndones = NDArray.ones(2, 1)
+    var ndtwos = ndones * 2
     assert(ndtwos.toArray === Array(2f, 2f))
     assert((ndones * ndones).toArray === Array(1f, 1f))
     assert((ndtwos * ndtwos).toArray === Array(4f, 4f))
     ndtwos *= ndtwos
     // in-place
     assert(ndtwos.toArray === Array(4f, 4f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape(2, 1), dtype = DType.Float64)
+    ndtwos = ndones * 2d
+    assert(ndtwos.toFloat64Array === Array(2d, 2d))
+    assert((ndones * ndones).toFloat64Array === Array(1d, 1d))
+    assert((ndtwos * ndtwos).toFloat64Array === Array(4d, 4d))
+    ndtwos *= ndtwos
+    // in-place
+    assert(ndtwos.toFloat64Array === Array(4d, 4d))
+
   }
 
   test("division") {
-    val ndones = NDArray.ones(2, 1)
-    val ndzeros = ndones - 1f
-    val ndhalves = ndones / 2
+    var ndones = NDArray.ones(2, 1)
+    var ndzeros = ndones - 1f
+    var ndhalves = ndones / 2
     assert(ndhalves.toArray === Array(0.5f, 0.5f))
     assert((ndhalves / ndhalves).toArray === Array(1f, 1f))
     assert((ndones / ndones).toArray === Array(1f, 1f))
@@ -110,37 +163,75 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     ndhalves /= ndhalves
     // in-place
     assert(ndhalves.toArray === Array(1f, 1f))
+
+    // Float64 methods test
+    ndones = NDArray.ones(shape = Shape (2, 1), dtype = DType.Float64)
+    ndzeros = ndones - 1d
+    ndhalves = ndones / 2d
+    assert(ndhalves.toFloat64Array === Array(0.5d, 0.5d))
+    assert((ndhalves / ndhalves).toFloat64Array === Array(1d, 1d))
+    assert((ndones / ndones).toFloat64Array === Array(1d, 1d))
+    assert((ndzeros / ndones).toFloat64Array === Array(0d, 0d))
+    ndhalves /= ndhalves
+    // in-place
+    assert(ndhalves.toFloat64Array === Array(1d, 1d))
   }
 
   test("full") {
-    val arr = NDArray.full(Shape(1, 2), 3f)
+    var arr = NDArray.full(Shape(1, 2), 3f)
     assert(arr.shape === Shape(1, 2))
     assert(arr.toArray === Array(3f, 3f))
+
+    // Float64 methods test
+    arr = NDArray.full(Shape(1, 2), value = 5d, Context.cpu())
+    assert(arr.toFloat64Array === Array (5d, 5d))
   }
 
   test("clip") {
-    val ndarray = NDArray.empty(3, 2)
+    var ndarray = NDArray.empty(3, 2)
     ndarray.set(Array(1f, 2f, 3f, 4f, 5f, 6f))
     assert(NDArray.clip(ndarray, 2f, 5f).toArray === Array(2f, 2f, 3f, 4f, 5f, 5f))
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(3, 2), dtype = DType.Float64)
+    ndarray.set(Array(1d, 2d, 3d, 4d, 5d, 6d))
+    assert(NDArray.clip(ndarray, 2d, 5d).toFloat64Array === Array(2d, 2d, 3d, 4d, 5d, 5d))
   }
 
   test("sqrt") {
-    val ndarray = NDArray.empty(4, 1)
+    var ndarray = NDArray.empty(4, 1)
     ndarray.set(Array(0f, 1f, 4f, 9f))
     assert(NDArray.sqrt(ndarray).toArray === Array(0f, 1f, 2f, 3f))
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(4, 1), dtype = DType.Float64)
+    ndarray.set(Array(0d, 1d, 4d, 9d))
+    assert(NDArray.sqrt(ndarray).toFloat64Array === Array(0d, 1d, 2d, 3d))
   }
 
   test("rsqrt") {
-    val ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
+    var ndarray = NDArray.array(Array(1f, 4f), shape = Shape(2, 1))
     assert(NDArray.rsqrt(ndarray).toArray === Array(1f, 0.5f))
+
+    // Float64 methods test
+    ndarray = NDArray.array(Array(1d, 4d, 25d), shape = Shape(3, 1), Context.cpu())
+    assert(NDArray.rsqrt(ndarray).toFloat64Array === Array(1d, 0.5d, 0.2d))
   }
 
   test("norm") {
-    val ndarray = NDArray.empty(3, 1)
+    var ndarray = NDArray.empty(3, 1)
     ndarray.set(Array(1f, 2f, 3f))
-    val normed = NDArray.norm(ndarray)
+    var normed = NDArray.norm(ndarray)
     assert(normed.shape === Shape(1))
     assert(normed.toScalar === math.sqrt(14.0).toFloat +- 1e-3f)
+
+    // Float64 methods test
+    ndarray = NDArray.empty(shape = Shape(3, 1), dtype = DType.Float64)
+    ndarray.set(Array(1d, 2d, 3d))
+    normed = NDArray.norm(ndarray)
+    assert(normed.get.dtype === DType.Float64)
+    assert(normed.shape === Shape(1))
+    assert(normed.toFloat64Scalar === math.sqrt(14.0) +- 1e-3d)
   }
 
   test("one hot encode") {
@@ -176,25 +267,26 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("power") {
-    val arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
+    var arr = NDArray.array(Array(3f, 5f), shape = Shape(2, 1))
 
-    val arrPower1 = NDArray.power(2f, arr)
+    var arrPower1 = NDArray.power(2f, arr)
     assert(arrPower1.shape === Shape(2, 1))
     assert(arrPower1.toArray === Array(8f, 32f))
 
-    val arrPower2 = NDArray.power(arr, 2f)
+    var arrPower2 = NDArray.power(arr, 2f)
     assert(arrPower2.shape === Shape(2, 1))
     assert(arrPower2.toArray === Array(9f, 25f))
 
-    val arrPower3 = NDArray.power(arr, arr)
+    var arrPower3 = NDArray.power(arr, arr)
     assert(arrPower3.shape === Shape(2, 1))
     assert(arrPower3.toArray === Array(27f, 3125f))
 
-    val arrPower4 = arr ** 2f
+    var arrPower4 = arr ** 2f
+
     assert(arrPower4.shape === Shape(2, 1))
     assert(arrPower4.toArray === Array(9f, 25f))
 
-    val arrPower5 = arr ** arr
+    var arrPower5 = arr ** arr
     assert(arrPower5.shape === Shape(2, 1))
     assert(arrPower5.toArray === Array(27f, 3125f))
 
@@ -206,84 +298,211 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     arr **= arr
     assert(arr.shape === Shape(2, 1))
     assert(arr.toArray === Array(27f, 3125f))
+
+    // Float64 tests
+    arr = NDArray.array(Array(3d, 5d), shape = Shape(2, 1))
+
+    arrPower1 = NDArray.power(2d, arr)
+    assert(arrPower1.shape === Shape(2, 1))
+    assert(arrPower1.dtype === DType.Float64)
+    assert(arrPower1.toFloat64Array === Array(8d, 32d))
+
+    arrPower2 = NDArray.power(arr, 2d)
+    assert(arrPower2.shape === Shape(2, 1))
+    assert(arrPower2.dtype === DType.Float64)
+    assert(arrPower2.toFloat64Array === Array(9d, 25d))
+
+    arrPower3 = NDArray.power(arr, arr)
+    assert(arrPower3.shape === Shape(2, 1))
+    assert(arrPower3.dtype === DType.Float64)
+    assert(arrPower3.toFloat64Array === Array(27d, 3125d))
+
+    arrPower4 = arr ** 2f
+    assert(arrPower4.shape === Shape(2, 1))
+    assert(arrPower4.dtype === DType.Float64)
+    assert(arrPower4.toFloat64Array === Array(9d, 25d))
+
+    arrPower5 = arr ** arr
+    assert(arrPower5.shape === Shape(2, 1))
+    assert(arrPower5.dtype === DType.Float64)
+    assert(arrPower5.toFloat64Array === Array(27d, 3125d))
+
+    arr **= 2d
+    assert(arr.shape === Shape(2, 1))
+    assert(arr.dtype === DType.Float64)
+    assert(arr.toFloat64Array === Array(9d, 25d))
+
+    arr.set(Array(3d, 5d))
+    arr **= arr
+    assert(arr.shape === Shape(2, 1))
+    assert(arr.dtype === DType.Float64)
+    assert(arr.toFloat64Array === Array(27d, 3125d))
   }
 
   test("equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = NDArray.equal(arr1, arr2)
+    var arrEqual1 = NDArray.equal(arr1, arr2)
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
 
-    val arrEqual2 = NDArray.equal(arr1, 3f)
+    var arrEqual2 = NDArray.equal(arr1, 3f)
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 0f, 1f, 0f))
+
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = NDArray.equal(arr1, arr2)
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+    arrEqual2 = NDArray.equal(arr1, 3d)
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 0d))
   }
 
   test("not_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 3f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = NDArray.notEqual(arr1, arr2)
+    var arrEqual1 = NDArray.notEqual(arr1, arr2)
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
 
-    val arrEqual2 = NDArray.notEqual(arr1, 3f)
+    var arrEqual2 = NDArray.notEqual(arr1, 3f)
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 1f, 0f, 1f))
+
+    // Float64 methods test
+
+    arr1 = NDArray.array(Array(1d, 2d, 3d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = NDArray.notEqual(arr1, arr2)
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+    arrEqual2 = NDArray.notEqual(arr1, 3d)
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
   }
 
   test("greater") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 > arr2
+    var arrEqual1 = arr1 > arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 0f, 1f, 0f))
 
-    val arrEqual2 = arr1 > 2f
+    var arrEqual2 = arr1 > 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 0f, 1f, 1f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 > arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 0d, 1d, 0d))
+
+    arrEqual2 = arr1 > 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 0d, 1d, 1d))
   }
 
   test("greater_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 >= arr2
+    var arrEqual1 = arr1 >= arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 0f, 1f, 0f))
 
-    val arrEqual2 = arr1 >= 2f
+    var arrEqual2 = arr1 >= 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(0f, 1f, 1f, 1f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 >= arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 0d, 1d, 0d))
+
+    arrEqual2 = arr1 >= 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(0d, 1d, 1d, 1d))
   }
 
   test("lesser") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 < arr2
+    var arrEqual1 = arr1 < arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(0f, 1f, 0f, 1f))
 
-    val arrEqual2 = arr1 < 2f
+    var arrEqual2 = arr1 < 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 0f, 0f, 0f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 < arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(0d, 1d, 0d, 1d))
+
+    arrEqual2 = arr1 < 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 0d, 0d, 0d))
+
   }
 
   test("lesser_equal") {
-    val arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
-    val arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
+    var arr1 = NDArray.array(Array(1f, 2f, 4f, 5f), shape = Shape(2, 2))
+    var arr2 = NDArray.array(Array(1f, 4f, 3f, 6f), shape = Shape(2, 2))
 
-    val arrEqual1 = arr1 <= arr2
+    var arrEqual1 = arr1 <= arr2
     assert(arrEqual1.shape === Shape(2, 2))
     assert(arrEqual1.toArray === Array(1f, 1f, 0f, 1f))
 
-    val arrEqual2 = arr1 <= 2f
+    var arrEqual2 = arr1 <= 2f
     assert(arrEqual2.shape === Shape(2, 2))
     assert(arrEqual2.toArray === Array(1f, 1f, 0f, 0f))
+
+    // Float64 methods test
+    arr1 = NDArray.array(Array(1d, 2d, 4d, 5d), shape = Shape(2, 2))
+    arr2 = NDArray.array(Array(1d, 4d, 3d, 6d), shape = Shape(2, 2))
+
+    arrEqual1 = arr1 <= arr2
+    assert(arrEqual1.shape === Shape(2, 2))
+    assert(arrEqual1.dtype === DType.Float64)
+    assert(arrEqual1.toFloat64Array === Array(1d, 1d, 0d, 1d))
+
+    arrEqual2 = arr1 <= 2d
+    assert(arrEqual2.shape === Shape(2, 2))
+    assert(arrEqual2.dtype === DType.Float64)
+    assert(arrEqual2.toFloat64Array === Array(1d, 1d, 0d, 0d))
   }
 
   test("choose_element_0index") {
@@ -294,11 +513,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("copy to") {
-    val source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
-    val dest = NDArray.empty(1, 3)
+    var source = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
+    var dest = NDArray.empty(1, 3)
     source.copyTo(dest)
     assert(dest.shape === Shape(1, 3))
     assert(dest.toArray === Array(1f, 2f, 3f))
+
+    // Float64 methods test
+    source = NDArray.array(Array(1d, 2d, 3d), shape = Shape(1, 3))
+    dest = NDArray.empty(shape = Shape(1, 3), dtype = DType.Float64)
+    source.copyTo(dest)
+    assert(dest.dtype === DType.Float64)
+    assert(dest.toFloat64Array === Array(1d, 2d, 3d))
   }
 
   test("abs") {
@@ -365,6 +591,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.maximum(arr1, arr2)
     assert(arr.shape === Shape(3, 1))
     assert(arr.toArray === Array(4f, 2.1f, 3.7f))
+
+    // Float64 methods test
+    val arr3 = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1))
+    val maxArr = NDArray.maximum(arr3, 10d)
+    assert(maxArr.shape === Shape(3, 1))
+    assert(maxArr.toArray === Array(10d, 10d, 10d))
   }
 
   test("min") {
@@ -378,11 +610,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.minimum(arr1, arr2)
     assert(arr.shape === Shape(3, 1))
     assert(arr.toArray === Array(1.5f, 1f, 3.5f))
+
+    // Float64 methods test
+    val arr3 = NDArray.array(Array(4d, 5d, 6d), shape = Shape(3, 1))
+    val minArr = NDArray.minimum(arr3, 5d)
+    assert(minArr.shape === Shape(3, 1))
+    assert(minArr.toFloat64Array === Array(4d, 5d, 5d))
   }
 
   test("sum") {
-    val arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2))
+    var arr = NDArray.array(Array(1f, 2f, 3f, 4f), shape = Shape(2, 2))
     assert(NDArray.sum(arr).toScalar === 10f +- 1e-3f)
+
   }
 
   test("argmaxChannel") {
@@ -398,6 +637,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.concatenate(arr1, arr2)
     assert(arr.shape === Shape(3, 3))
     assert(arr.toArray === Array(1f, 2f, 4f, 3f, 3f, 3f, 8f, 7f, 6f))
+
+    // Try concatenating float32 arr with float64 arr. Should get exception
+    intercept[Exception] {
+      val arr3 = NDArray.array(Array (5d, 6d, 7d), shape = Shape(1, 3))
+      NDArray.concatenate(Array(arr1, arr3))
+    }
   }
 
   test("concatenate axis-1") {
@@ -406,6 +651,12 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val arr = NDArray.concatenate(Array(arr1, arr2), axis = 1)
     assert(arr.shape === Shape(2, 3))
     assert(arr.toArray === Array(1f, 2f, 5f, 3f, 4f, 6f))
+
+    // Try concatenating float32 arr with float64 arr. Should get exception
+    intercept[Exception] {
+      val arr3 = NDArray.array(Array (5d, 6d), shape = Shape(2, 1))
+      NDArray.concatenate(Array(arr1, arr3), axis = 1)
+    }
   }
 
   test("transpose") {
@@ -428,6 +679,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val loadedArray = arrays(0)
       assert(loadedArray.shape === Shape(3, 1))
       assert(loadedArray.toArray === Array(1f, 2f, 3f))
+      assert(loadedArray.dtype === DType.Float32)
+    } finally {
+      val file = new File(filename)
+      file.delete()
+    }
+
+    // Try the same for Float64 array
+    try {
+      val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+      NDArray.save(filename, Map("local" -> ndarray))
+      val (keys, arrays) = NDArray.load(filename)
+      assert(keys.length === 1)
+      assert(keys(0) === "local")
+      assert(arrays.length === 1)
+      val loadedArray = arrays(0)
+      assert(loadedArray.shape === Shape(3, 1))
+      assert(loadedArray.toArray === Array(1d, 2d, 3d))
+      assert(loadedArray.dtype === DType.Float64)
     } finally {
       val file = new File(filename)
       file.delete()
@@ -446,6 +715,24 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
       val loadedArray = arrays(0)
       assert(loadedArray.shape === Shape(3, 1))
       assert(loadedArray.toArray === Array(1f, 2f, 3f))
+      assert(loadedArray.dtype === DType.Float32)
+    } finally {
+      val file = new File(filename)
+      file.delete()
+    }
+
+    // Try the same thing for Float64 array :
+
+    try {
+      val ndarray = NDArray.array(Array(1d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
+      NDArray.save(filename, Array(ndarray))
+      val (keys, arrays) = NDArray.load(filename)
+      assert(keys.length === 0)
+      assert(arrays.length === 1)
+      val loadedArray = arrays(0)
+      assert(loadedArray.shape === Shape(3, 1))
+      assert(loadedArray.toArray === Array(1d, 2d, 3d))
+      assert(loadedArray.dtype === DType.Float64)
     } finally {
       val file = new File(filename)
       file.delete()
@@ -464,9 +751,11 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val ndarray2 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(3, 1))
     val ndarray3 = NDArray.array(Array(1f, 2f, 3f), shape = Shape(1, 3))
     val ndarray4 = NDArray.array(Array(3f, 2f, 3f), shape = Shape(3, 1))
+    val ndarray5 = NDArray.array(Array(3d, 2d, 3d), shape = Shape(3, 1), ctx = Context.cpu())
     ndarray1 shouldEqual ndarray2
     ndarray1 shouldNot equal(ndarray3)
     ndarray1 shouldNot equal(ndarray4)
+    ndarray5 shouldNot equal(ndarray3)
   }
 
   test("slice") {
@@ -545,6 +834,7 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     val bytes = arr.serialize()
     val arrCopy = NDArray.deserialize(bytes)
     assert(arr === arrCopy)
+    assert(arrCopy.dtype === DType.Float32)
   }
 
   test("dtype int32") {
@@ -580,18 +870,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   test("NDArray random module is generated properly") {
     val lam = NDArray.ones(1, 2)
     val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
-    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)),
+      dtype = Some("float64"))
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
+    assert(rnd2.head.dtype === DType.Float64)
   }
 
   test("NDArray random module is generated properly - special case of 'normal'") {
     val mu = NDArray.ones(1, 2)
     val sigma = NDArray.ones(1, 2) * 2
     val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
-    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)))
+    val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)),
+      dtype = Some("float64"))
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
+    assert(rnd2.head.dtype === DType.Float64)
   }
 
   test("Generated api") {
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
index f6c283c3dfb..9f0430eaada 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainModel.scala
@@ -19,6 +19,7 @@ package org.apache.mxnetexamples.imclassification
 
 import java.util.concurrent._
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnetexamples.imclassification.models._
 import org.apache.mxnetexamples.imclassification.util.Trainer
 import org.apache.mxnet._
@@ -42,12 +43,13 @@ object TrainModel {
     * @return The final validation accuracy
     */
   def test(model: String, dataPath: String, numExamples: Int = 60000,
-           numEpochs: Int = 10, benchmark: Boolean = false): Float = {
+           numEpochs: Int = 10, benchmark: Boolean = false,
+           dtype: DType = DType.Float32): Float = {
     ResourceScope.using() {
       val devs = Array(Context.cpu(0))
       val envs: mutable.Map[String, String] = mutable.HashMap.empty[String, String]
       val (dataLoader, net) = dataLoaderAndModel("mnist", model, dataPath,
-        numExamples = numExamples, benchmark = benchmark)
+        numExamples = numExamples, benchmark = benchmark, dtype = dtype)
       val Acc = Trainer.fit(batchSize = 128, numExamples, devs = devs,
         network = net, dataLoader = dataLoader,
         kvStore = "local", numEpochs = numEpochs)
@@ -69,7 +71,7 @@ object TrainModel {
     */
   def dataLoaderAndModel(dataset: String, model: String, dataDir: String = "",
                          numLayers: Int = 50, numExamples: Int = 60000,
-                         benchmark: Boolean = false
+                         benchmark: Boolean = false, dtype: DType = DType.Float32
                         ): ((Int, KVStore) => (DataIter, DataIter), Symbol) = {
     val (imageShape, numClasses) = dataset match {
       case "mnist" => (List(1, 28, 28), 10)
@@ -80,16 +82,17 @@ object TrainModel {
     val List(channels, height, width) = imageShape
     val dataSize: Int = channels * height * width
     val (datumShape, net) = model match {
-      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses))
-      case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses))
+      case "mlp" => (List(dataSize), MultiLayerPerceptron.getSymbol(numClasses, dtype = dtype))
+      case "lenet" => (List(channels, height, width), Lenet.getSymbol(numClasses, dtype = dtype))
       case "resnet" => (List(channels, height, width), Resnet.getSymbol(numClasses,
-        numLayers, imageShape))
+        numLayers, imageShape, dtype = dtype))
       case _ => throw new Exception("Invalid model name")
     }
 
     val dataLoader: (Int, KVStore) => (DataIter, DataIter) = if (benchmark) {
       (batchSize: Int, kv: KVStore) => {
-        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples)
+        val iter = new SyntheticDataIter(numClasses, batchSize, datumShape, List(), numExamples,
+          dtype)
         (iter, iter)
       }
     } else {
@@ -116,8 +119,10 @@ object TrainModel {
         val dataPath = if (inst.dataDir == null) System.getenv("MXNET_HOME")
         else inst.dataDir
 
+        val dtype = DType.withName(inst.dType)
+
         val (dataLoader, net) = dataLoaderAndModel(inst.dataset, inst.network, dataPath,
-          inst.numLayers, inst.numExamples, inst.benchmark)
+          inst.numLayers, inst.numExamples, inst.benchmark, dtype)
 
         val devs =
           if (inst.gpus != null) inst.gpus.split(',').map(id => Context.gpu(id.trim.toInt))
@@ -210,5 +215,8 @@ class TrainModel {
   private val numWorker: Int = 1
   @Option(name = "--num-server", usage = "# of servers")
   private val numServer: Int = 1
+  @Option(name = "--dtype", usage = "data type of the model to train. " +
+    "Can be float32/float64. Works only with synthetic data currently")
+  private val dType: String = "float32"
 }
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
index 9421f102161..e4d3b2ae7c3 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/datasets/SyntheticDataIter.scala
@@ -24,7 +24,7 @@ import scala.collection.immutable.ListMap
 import scala.util.Random
 
 class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[Int],
-                        labelShape: List[Int], maxIter: Int, dtype: DType = DType.Float32
+                        labelShape: List[Int], maxIter: Int, dType: DType = DType.Float32
                        ) extends DataIter {
   var curIter = 0
   val random = new Random()
@@ -35,12 +35,12 @@ class SyntheticDataIter(numClasses: Int, val batchSize: Int, datumShape: List[In
   var label: IndexedSeq[NDArray] = IndexedSeq(
     NDArray.api.random_uniform(Some(0f), Some(maxLabel), shape = Some(batchLabelShape)))
   var data: IndexedSeq[NDArray] = IndexedSeq(
-    NDArray.api.random_uniform(shape = Some(shape)))
+    NDArray.api.random_uniform(shape = Some(shape), dtype = Some(dType.toString)))
 
   val provideDataDesc: IndexedSeq[DataDesc] = IndexedSeq(
-    new DataDesc("data", shape, dtype, Layout.UNDEFINED))
+    new DataDesc("data", shape, data(0).dtype, Layout.UNDEFINED))
   val provideLabelDesc: IndexedSeq[DataDesc] = IndexedSeq(
-    new DataDesc("softmax_label", batchLabelShape, dtype, Layout.UNDEFINED))
+    new DataDesc("softmax_label", batchLabelShape, label(0).dtype, Layout.UNDEFINED))
   val getPad: Int = 0
 
   override def getData(): IndexedSeq[NDArray] = data
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
index 76fb7bb6602..6f8b138d5cc 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Lenet.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object Lenet {
@@ -26,8 +27,8 @@ object Lenet {
     * @param numClasses Number of classes to classify into
     * @return model symbol
     */
-  def getSymbol(numClasses: Int): Symbol = {
-    val data = Symbol.Variable("data")
+  def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+    val data = Symbol.Variable("data", dType = dtype)
     // first conv
     val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5), num_filter = 20)
     val tanh1 = Symbol.api.tanh(data = Some(conv1))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
index 5d880bbe061..089b65f24a6 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/MultiLayerPerceptron.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object MultiLayerPerceptron {
@@ -26,8 +27,8 @@ object MultiLayerPerceptron {
     * @param numClasses Number of classes to classify into
     * @return model symbol
     */
-  def getSymbol(numClasses: Int): Symbol = {
-    val data = Symbol.Variable("data")
+  def getSymbol(numClasses: Int, dtype: DType = DType.Float32): Symbol = {
+    val data = Symbol.Variable("data", dType = dtype)
 
     val fc1 = Symbol.api.FullyConnected(data = Some(data), num_hidden = 128, name = "fc1")
     val act1 = Symbol.api.Activation(data = Some(fc1), "relu", name = "relu")
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
index c3f43d97e89..e5f597680f9 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/models/Resnet.scala
@@ -17,6 +17,7 @@
 
 package org.apache.mxnetexamples.imclassification.models
 
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 
 object Resnet {
@@ -77,13 +78,14 @@ object Resnet {
     */
   def resnet(units: List[Int], numStages: Int, filterList: List[Int], numClasses: Int,
              imageShape: List[Int], bottleNeck: Boolean = true, bnMom: Float = 0.9f,
-             workspace: Int = 256, dtype: String = "float32", memonger: Boolean = false): Symbol = {
+             workspace: Int = 256, dtype: DType = DType.Float32,
+             memonger: Boolean = false): Symbol = {
     assert(units.size == numStages)
     var data = Symbol.Variable("data", shape = Shape(List(4) ::: imageShape), dType = DType.Float32)
-    if (dtype == "float32") {
+    if (dtype == DType.Float32) {
       data = Symbol.api.identity(Some(data), "id")
-    } else if (dtype == "float16") {
-      data = Symbol.api.cast(Some(data), "float16")
+    } else if (dtype == DType.Float16) {
+      data = Symbol.api.cast(Some(data), DType.Float16.toString)
     }
     data = Symbol.api.BatchNorm(Some(data), fix_gamma = Some(true), eps = Some(2e-5),
       momentum = Some(bnMom), name = "bn_data")
@@ -118,8 +120,8 @@ object Resnet {
       kernel = Some(Shape(7, 7)), pool_type = Some("avg"), name = "pool1")
     val flat = Symbol.api.Flatten(Some(pool1))
     var fc1 = Symbol.api.FullyConnected(Some(flat), num_hidden = numClasses, name = "fc1")
-    if (dtype == "float16") {
-      fc1 = Symbol.api.cast(Some(fc1), "float32")
+    if (dtype == DType.Float16) {
+      fc1 = Symbol.api.cast(Some(fc1), DType.Float32.toString)
     }
     Symbol.api.SoftmaxOutput(Some(fc1), name = "softmax")
   }
@@ -134,7 +136,7 @@ object Resnet {
     * @return Model symbol
     */
   def getSymbol(numClasses: Int, numLayers: Int, imageShape: List[Int], convWorkspace: Int = 256,
-                dtype: String = "float32"): Symbol = {
+                dtype: DType = DType.Float32): Symbol = {
     val List(channels, height, width) = imageShape
     val (numStages, units, filterList, bottleNeck): (Int, List[Int], List[Int], Boolean) =
       if (height <= 28) {
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
index 6e9667abe9c..0daba5a97d7 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/IMClassificationExampleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.imclassification
 
 import java.io.File
 
-import org.apache.mxnet.Context
+import org.apache.mxnet.{Context, DType}
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
@@ -55,9 +55,15 @@ class IMClassificationExampleSuite extends FunSuite with BeforeAndAfterAll {
 
   for(model <- List("mlp", "lenet", "resnet")) {
     test(s"Example CI: Test Image Classification Model ${model}") {
-      var context = Context.cpu()
       val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true)
     }
   }
 
+  for(model <- List("mlp", "lenet", "resnet")) {
+    test(s"Example CI: Test Image Classification Model ${model} with Float64 input") {
+      val valAccuracy = TrainModel.test(model, "", 10, 1, benchmark = true,
+        dtype = DType.Float64)
+    }
+  }
+
 }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index 5208923275f..bf658158811 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -17,9 +17,10 @@
 
 package org.apache.mxnet.infer
 
-import org.apache.mxnet.{Context, DataDesc, NDArray}
+import org.apache.mxnet._
 import java.io.File
 
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
 import org.slf4j.LoggerFactory
 
 import scala.io
@@ -30,13 +31,13 @@ trait ClassifierBase {
 
   /**
     * Takes an array of floats and returns corresponding (Label, Score) tuples
-    * @param input            Indexed sequence one-dimensional array of floats
+    * @param input            Indexed sequence one-dimensional array of floats/doubles
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
     * @return                 Indexed sequence of (Label, Score) tuples
     */
-  def classify(input: IndexedSeq[Array[Float]],
-               topK: Option[Int] = None): IndexedSeq[(String, Float)]
+  def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+               topK: Option[Int] = None): IndexedSeq[(String, T)]
 
   /**
     * Takes a sequence of NDArrays and returns (Label, Score) tuples
@@ -78,17 +79,35 @@ class Classifier(modelPathPrefix: String,
 
   /**
     * Takes flat arrays as input and returns (Label, Score) tuples.
-    * @param input            Indexed sequence one-dimensional array of floats
+    * @param input            Indexed sequence one-dimensional array of floats/doubles
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
     * @return                 Indexed sequence of (Label, Score) tuples
     */
-  override def classify(input: IndexedSeq[Array[Float]],
-                        topK: Option[Int] = None): IndexedSeq[(String, Float)] = {
+  override def classify[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]],
+                        topK: Option[Int] = None): IndexedSeq[(String, T)] = {
+
+    // considering only the first output
+    val result = input(0)(0) match {
+      case d: Double => {
+        classifyImpl(input.asInstanceOf[IndexedSeq[Array[Double]]], topK)
+      }
+      case _ => {
+        classifyImpl(input.asInstanceOf[IndexedSeq[Array[Float]]], topK)
+      }
+    }
+
+    result.asInstanceOf[IndexedSeq[(String, T)]]
+  }
+
+  private def classifyImpl[B, A <: MX_PRIMITIVE_TYPE]
+  (input: IndexedSeq[Array[B]], topK: Option[Int] = None)(implicit ev: B => A)
+  : IndexedSeq[(String, B)] = {
 
     // considering only the first output
     val predictResult = predictor.predict(input)(0)
-    var result: IndexedSeq[(String, Float)] = IndexedSeq.empty
+
+    var result: IndexedSeq[(String, B)] = IndexedSeq.empty
 
     if (topK.isDefined) {
       val sortedIndex = predictResult.zipWithIndex.sortBy(-_._1).map(_._2).take(topK.get)
@@ -105,7 +124,7 @@ class Classifier(modelPathPrefix: String,
     * @param input            Indexed sequence of NDArrays
     * @param topK             (Optional) How many result (sorting based on the last axis)
     *                         elements to return. Default returns unsorted output.
-    * @return                 Traversable sequence of (Label, Score) tuples
+    * @return                 Traversable sequence of (Label, Score) tuples.
     */
   override def classifyWithNDArray(input: IndexedSeq[NDArray], topK: Option[Int] = None)
   : IndexedSeq[IndexedSeq[(String, Float)]] = {
@@ -113,7 +132,7 @@ class Classifier(modelPathPrefix: String,
     // considering only the first output
     // Copy NDArray to CPU to avoid frequent GPU to CPU copying
     val predictResultND: NDArray =
-      predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
+    predictor.predictWithNDArray(input)(0).asInContext(Context.cpu())
     // Parallel Execution with ParArray for better performance
     val predictResultPar: ParArray[Array[Float]] =
       new ParArray[Array[Float]](predictResultND.shape(0))
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
index 96be12179d4..3c80f922639 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.infer
 
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 
 import scala.collection.mutable.ListBuffer
 
@@ -70,14 +71,18 @@ class ImageClassifier(modelPathPrefix: String,
     *
     * @param inputImage       Path prefix of the input image
     * @param topK             Number of result elements to return, sorted by probability
+    * @param dType            The precision at which to run the inference.
+    *                         Specify the DType as DType.Float64 for Double precision.
+    *                         Defaults to DType.Float32
     * @return                 List of list of tuples of (Label, Probability)
     */
-  def classifyImage(inputImage: BufferedImage,
-                    topK: Option[Int] = None): IndexedSeq[IndexedSeq[(String, Float)]] = {
+  def classifyImage
+  (inputImage: BufferedImage, topK: Option[Int] = None, dType: DType = DType.Float32):
+  IndexedSeq[IndexedSeq[(String, Float)]] = {
 
     val scaledImage = ImageClassifier.reshapeImage(inputImage, width, height)
     val imageShape = inputShape.drop(1)
-    val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+    val pixelsNDArray = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
     val imgWithBatchNum = NDArray.api.expand_dims(pixelsNDArray, 0)
     inputImage.flush()
     scaledImage.flush()
@@ -95,16 +100,19 @@ class ImageClassifier(modelPathPrefix: String,
     *
     * @param inputBatch       Input array of buffered images
     * @param topK             Number of result elements to return, sorted by probability
+    * @param dType            The precision at which to run the inference.
+    *                         Specify the DType as DType.Float64 for Double precision.
+    *                         Defaults to DType.Float32
     * @return                 List of list of tuples of (Label, Probability)
     */
-  def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None):
-  IndexedSeq[IndexedSeq[(String, Float)]] = {
+  def classifyImageBatch(inputBatch: Traversable[BufferedImage], topK: Option[Int] = None,
+   dType: DType = DType.Float32): IndexedSeq[IndexedSeq[(String, Float)]] = {
 
     val inputBatchSeq = inputBatch.toIndexedSeq
     val imageBatch = inputBatchSeq.indices.par.map(idx => {
       val scaledImage = ImageClassifier.reshapeImage(inputBatchSeq(idx), width, height)
       val imageShape = inputShape.drop(1)
-      val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape)
+      val imgND = ImageClassifier.bufferedImageToPixels(scaledImage, imageShape, dType)
       val imgWithBatch = NDArray.api.expand_dims(imgND, 0).get
       handler.execute(imgND.dispose())
       imgWithBatch
@@ -152,11 +160,29 @@ object ImageClassifier {
     * returned by this method after the use.
     * </p>
     * @param resizedImage     BufferedImage to get pixels from
+    *
     * @param inputImageShape  Input shape; for example for resnet it is (3,224,224).
                               Should be same as inputDescriptor shape.
+    * @param dType            The DataType of the NDArray created from the image
+    *                         that should be returned.
+    *                         Currently it defaults to DType.Float32
     * @return                 NDArray pixels array with shape (3, 224, 224) in CHW format
     */
-  def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape): NDArray = {
+  def bufferedImageToPixels(resizedImage: BufferedImage, inputImageShape: Shape,
+                            dType : DType = DType.Float32): NDArray = {
+
+      if (dType == DType.Float64) {
+        val result = getFloatPixelsArray(resizedImage)
+        NDArray.array(result.map(_.toDouble), shape = inputImageShape)
+      }
+      else {
+        val result = getFloatPixelsArray(resizedImage)
+        NDArray.array(result, shape = inputImageShape)
+      }
+  }
+
+  private def getFloatPixelsArray(resizedImage: BufferedImage): Array[Float] = {
+
     // Get height and width of the image
     val w = resizedImage.getWidth()
     val h = resizedImage.getHeight()
@@ -166,7 +192,6 @@ object ImageClassifier {
 
     // 3 times height and width for R,G,B channels
     val result = new Array[Float](3 * h * w)
-
     var row = 0
     // copy pixels to array vertically
     while (row < h) {
@@ -184,11 +209,10 @@ object ImageClassifier {
       }
       row += 1
     }
+
     resizedImage.flush()
 
-    // creating NDArray according to the input shape
-    val pixelsArray = NDArray.array(result, shape = inputImageShape)
-    pixelsArray
+    result
   }
 
   /**
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index d4bce9f0d71..67692a316cc 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -17,8 +17,9 @@
 
 package org.apache.mxnet.infer
 
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
 import org.apache.mxnet.io.NDArrayIter
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet._
 import org.apache.mxnet.module.Module
 
 import scala.collection.mutable.ListBuffer
@@ -36,11 +37,13 @@ private[infer] trait PredictBase {
    * <p>
    * This method will take input as IndexedSeq one dimensional arrays and creates the
    * NDArray needed for inference. The array will be reshaped based on the input descriptors.
-   * @param input:            An IndexedSequence of a one-dimensional array.
+   * @param input:            An Indexed Sequence of a one-dimensional array of datatype
+    *                         Float or Double
                               An IndexedSequence is needed when the model has more than one input.
    * @return                  Indexed sequence array of outputs
    */
-  def predict(input: IndexedSeq[Array[Float]]): IndexedSeq[Array[Float]]
+  def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+  : IndexedSeq[Array[T]]
 
   /**
    * Predict using NDArray as input.
@@ -123,13 +126,13 @@ class Predictor(modelPathPrefix: String,
    * Takes input as IndexedSeq one dimensional arrays and creates the NDArray needed for inference
    * The array will be reshaped based on the input descriptors.
    *
-   * @param input:            An IndexedSequence of a one-dimensional array.
+   * @param input:            An IndexedSequence of a one-dimensional array
+    *                         of data type Float or Double.
                               An IndexedSequence is needed when the model has more than one input.
    * @return                  Indexed sequence array of outputs
    */
-  override def predict(input: IndexedSeq[Array[Float]])
-  : IndexedSeq[Array[Float]] = {
-
+  override def predict[@specialized (Base.MX_PRIMITIVES) T](input: IndexedSeq[Array[T]])
+  : IndexedSeq[Array[T]] = {
     require(input.length == inputDescriptors.length,
       s"number of inputs provided: ${input.length} does not match number of inputs " +
         s"in inputDescriptors: ${inputDescriptors.length}")
@@ -139,12 +142,30 @@ class Predictor(modelPathPrefix: String,
         s"number of elements:${i.length} in the input does not match the shape:" +
           s"${d.shape.toString()}")
     }
+
+    // Infer the dtype of input and call relevant method
+    val result = input(0)(0) match {
+      case d: Double => predictImpl(input.asInstanceOf[IndexedSeq[Array[Double]]])
+      case _ => predictImpl(input.asInstanceOf[IndexedSeq[Array[Float]]])
+    }
+
+    result.asInstanceOf[IndexedSeq[Array[T]]]
+  }
+
+  private def predictImpl[B, A <: MX_PRIMITIVE_TYPE]
+  (input: IndexedSeq[Array[B]])(implicit ev: B => A)
+  : IndexedSeq[Array[B]] = {
+
     var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]
 
     for((i, d) <- input.zip(inputDescriptors)) {
       val shape = d.shape.toVector.patch(from = batchIndex, patch = Vector(1), replaced = 1)
-
-      inputND += mxNetHandler.execute(NDArray.array(i, Shape(shape)))
+      if (d.dtype == DType.Float64) {
+        inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Double]], Shape(shape)))
+      }
+      else {
+        inputND += mxNetHandler.execute(NDArray.array(i.asInstanceOf[Array[Float]], Shape(shape)))
+      }
     }
 
     // rebind with batchsize 1
@@ -158,7 +179,8 @@ class Predictor(modelPathPrefix: String,
     val resultND = mxNetHandler.execute(mod.predict(new NDArrayIter(
       inputND.toIndexedSeq, dataBatchSize = 1)))
 
-    val result = resultND.map((f : NDArray) => f.toArray)
+    val result =
+      resultND.map((f : NDArray) => if (f.dtype == DType.Float64) f.toFloat64Array else f.toArray)
 
     mxNetHandler.execute(inputND.foreach(_.dispose))
     mxNetHandler.execute(resultND.foreach(_.dispose))
@@ -168,9 +190,11 @@ class Predictor(modelPathPrefix: String,
       mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false, forceRebind = true))
     }
 
-    result
+    result.asInstanceOf[IndexedSeq[Array[B]]]
   }
 
+
+
   /**
    * Predict using NDArray as input
    * This method is useful when the input is a batch of data
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index 0466693be9b..146fe93105e 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -72,6 +72,30 @@ class Predictor private[mxnet] (val predictor: org.apache.mxnet.infer.Predictor)
     predictor.predict(input).toArray
   }
 
+  /**
+    * Takes input as Array of one dimensional arrays and creates the NDArray needed for inference
+    * The array will be reshaped based on the input descriptors. Example of calling in Java:
+    *
+    * <pre>
+    * {@code
+    * double tmp[][] = new double[1][224];
+    * for (int x = 0; x < 1; x++)
+    *   for (int y = 0; y < 224; y++)
+    *     tmp[x][y] = (int)(Math.random()*10);
+    * predictor.predict(tmp);
+    * }
+    * </pre>
+    *
+    * @param input:            An Array of a one-dimensional array.
+                              An extra Array is needed when the model has more than one input.
+    * @return                  Indexed sequence array of outputs
+    */
+
+  def predict(input: Array[Array[Double]]):
+  Array[Array[Double]] = {
+    predictor.predict(input).toArray
+  }
+
   /**
     * Takes input as List of one dimensional arrays and creates the NDArray needed for inference
     * The array will be reshaped based on the input descriptors.
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
index b28aeba1dee..d9ccec46879 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ClassifierSuite.scala
@@ -22,7 +22,7 @@ import java.nio.file.{Files, Paths}
 import java.util
 
 import org.apache.mxnet.module.Module
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet.{Context, DType, DataDesc, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -127,6 +127,29 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
 
   }
 
+  test("ClassifierSuite-flatFloat64Array-topK") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Double](12)(1d)
+
+    val predictResult : IndexedSeq[Array[Double]] =
+      IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+    val result: IndexedSeq[(String, Double)] = testClassifier.
+      classify(IndexedSeq(inputData), topK = Some(10))
+
+    assert((result(0)_2).getClass == 1d.getClass)
+
+    assertResult(predictResult(0).sortBy(-_)) {
+      result.map(_._2).toArray
+    }
+
+  }
+
   test("ClassifierSuite-flatArrayInput") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
     val inputData = Array.fill[Float](12)(1)
@@ -147,6 +170,28 @@ class ClassifierSuite extends FunSuite with BeforeAndAfterAll {
     }
   }
 
+  test("ClassifierSuite-flatArrayFloat64Input") {
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputData = Array.fill[Double](12)(1d)
+
+    val predictResult : IndexedSeq[Array[Double]] =
+      IndexedSeq[Array[Double]](Array(.98d, 0.97d, 0.96d, 0.99d))
+
+    val testClassifier = new MyClassifier(modelPath, inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testClassifier.predictor)
+      .predict(any(classOf[IndexedSeq[Array[Double]]]))
+
+    val result: IndexedSeq[(String, Double)] = testClassifier.
+      classify(IndexedSeq(inputData))
+
+    assert((result(0)_2).getClass == 1d.getClass)
+
+    assertResult(predictResult(0)) {
+      result.map(_._2).toArray
+    }
+  }
+
   test("ClassifierSuite-NDArray1InputWithoutTopK") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2)))
     val inputDataShape = Shape(1, 3, 2, 2)
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
index 1c291e1e7b3..5198c4a1f30 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
@@ -68,6 +68,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
     val result = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2))
 
     assert(result.shape == inputDescriptor(0).shape.drop(1))
+    assert(result.dtype == DType.Float32)
+
+    val resultFloat64 = ImageClassifier.bufferedImageToPixels(image2, Shape(3, 2, 2), DType.Float64)
+    assert(resultFloat64.dtype == DType.Float64)
   }
 
   test("ImageClassifierSuite-testWithInputImage") {
@@ -106,8 +110,10 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
         predictResult(i).map(_._2).toArray
       }
     }
+
   }
 
+
   test("ImageClassifierSuite-testWithInputBatchImage") {
     val dType = DType.Float32
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
@@ -152,4 +158,5 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
       }
     }
   }
+
 }
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 509ffb35db8..9afbc9b3d4a 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
+import org.apache.mxnet._
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -91,6 +91,36 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
       , any[Option[BaseModule]], any[String])
   }
 
+  test("PredictorSuite-testWithFlatFloat64Arrays") {
+
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
+      layout = Layout.NCHW, dtype = DType.Float64))
+    val inputData = Array.fill[Double](12)(1d)
+
+    // this will disposed at the end of the predict call on Predictor.
+    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))
+
+    val testPredictor = new MyPredictor("xyz", inputDescriptor)
+
+    Mockito.doReturn(predictResult).when(testPredictor.mockModule)
+      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
+
+    val testFun = testPredictor.predict(IndexedSeq(inputData))
+
+    assert(testFun.size == 1, "output size should be 1 ")
+
+    assert(testFun(0)(0).getClass == 1d.getClass)
+
+    assert(Array.fill[Double](12)(1d).mkString == testFun(0).mkString)
+
+    // Verify that the module was bound with batch size 1 and rebound back to the original
+    // input descriptor. the number of times is twice here because loadModule overrides the
+    // initial bind.
+    Mockito.verify(testPredictor.mockModule, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
+      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean]
+      , any[Option[BaseModule]], any[String])
+  }
+
   test("PredictorSuite-testWithNDArray") {
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
       layout = Layout.NCHW))
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index d684c6d1356..ea6e9c8f5ba 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -424,6 +424,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jdoubleArray sourceArr, jint arrSize) {
+  jdouble *sourcePtr = env->GetDoubleArrayElements(sourceArr, NULL);
+  int ret = MXNDArraySyncCopyFromCPU(reinterpret_cast<NDArrayHandle>(arrayPtr),
+                                     static_cast<const double *>(sourcePtr), arrSize);
+  env->ReleaseDoubleArrayElements(sourceArr, sourcePtr, 0);
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
   (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
   int outDevType;
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 40230ac6daa..7e8e03de912 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -175,6 +175,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromCPU
   (JNIEnv *, jobject, jlong, jfloatArray, jint);
 
+/*
+ * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxFloat64NDArraySyncCopyFromCPU
+ * Signature: (J[DI)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFromCPU
+  (JNIEnv *, jobject, jlong, jdoubleArray, jint);
+
 /*
  * Class:     org_apache_mxnet_LibInfo
  * Method:    mxNDArrayLoad


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services