You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/05/07 18:14:18 UTC

[GitHub] Hodapp87 commented on issue #10338: NDArray saved in Python cannot be loaded in Scala

Hodapp87 commented on issue #10338: NDArray saved in Python cannot be loaded in Scala
URL: https://github.com/apache/incubator-mxnet/issues/10338#issuecomment-387154390
 
 
   I made a quick fix (below) that seems to resolve it. However, I have tested it only with @gigasquid's Clojure bindings, and I still need to port it to upstream master and check the resulting values more carefully. An actual error message, as she suggests, would also be very helpful.
   
   ```patch
   diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
   index bfe757d5..0fc53918 100644
   --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
   +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/DType.scala
   @@ -24,6 +24,8 @@ object DType extends Enumeration {
      val Float16 = Value(2, "float16")
      val UInt8 = Value(3, "uint8")
      val Int32 = Value(4, "int32")
   +  val Int8 = Value(5, "int8")
   +  val Int64 = Value(6, "int64")
      private[mxnet] def numOfBytes(dtype: DType): Int = {
        dtype match {
          case DType.UInt8 => 1
   @@ -31,6 +33,8 @@ object DType extends Enumeration {
          case DType.Float16 => 2
          case DType.Float32 => 4
          case DType.Float64 => 8
   +      case DType.Int8 => 1
   +      case DType.Int64 => 8
        }
      }
    }
   diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
   index 7cfc059c..423040fd 100644
   --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
   +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala
   @@ -1167,8 +1167,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
        dtype match {
          case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
          case DType.Float64 => units.map(wrapBytes(_).getDouble)
   -      case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
   -      case DType.UInt8 => internal.map(_.toDouble)
   +      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toDouble)
   +      case DType.UInt8 | DType.Int8 => internal.map(_.toDouble)
        }
      }
      def toFloatArray: Array[Float] = {
   @@ -1176,8 +1176,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
        dtype match {
          case DType.Float32 => units.map(wrapBytes(_).getFloat)
          case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
   -      case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
   -      case DType.UInt8 => internal.map(_.toFloat)
   +      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toFloat)
   +      case DType.UInt8 | DType.Int8 => internal.map(_.toFloat)
        }
      }
      def toIntArray: Array[Int] = {
   @@ -1185,8 +1185,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
        dtype match {
          case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
          case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
   -      case DType.Int32 => units.map(wrapBytes(_).getInt)
   -      case DType.UInt8 => internal.map(_.toInt)
   +      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt)
   +      case DType.UInt8 | DType.Int8 => internal.map(_.toInt)
        }
      }
      def toByteArray: Array[Byte] = {
   @@ -1194,8 +1194,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
        dtype match {
          case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
          case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
   -      case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
   -      case DType.UInt8 => internal.clone()
   +      case DType.Int32 | DType.Int64 => units.map(wrapBytes(_).getInt.toByte)
   +      case DType.UInt8 | DType.Int8 => internal.clone()
        }
      }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services