Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/03/05 19:34:18 UTC

[GitHub] [incubator-mxnet] satyakrishnagorti opened a new pull request #14337: Optimizer MXKVStoreUpdater bug fix in serializeState method

URL: https://github.com/apache/incubator-mxnet/pull/14337
 
 
   PR addressing issue: https://github.com/apache/incubator-mxnet/issues/14265
   
   ## Description
   There is currently a bug in how the Optimizer serializes its state: deserializing the state of an Optimizer that has no `states` (such as SGD without momentum) fails.
   
   ## Issue
   Serialization is currently done as follows (pasted from `Optimizer.serializeState()`):
   ```scala
   override def serializeState(): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     try {
       val out = new ObjectOutputStream(bos)
       out.writeInt(states.size) // counts every entry, including null values
       states.foreach { case (k, v) =>
         if (v != null) { // null-valued entries are skipped entirely
           out.writeInt(k)
           val stateBytes = optimizer.serializeState(v)
           if (stateBytes == null) {
             out.writeInt(0)
           } else {
             out.writeInt(stateBytes.length)
             out.write(stateBytes)
           }
         }
       }
       out.flush()
       bos.toByteArray
     } finally {
       ...
     }
   }
   ```
   When an Optimizer without `states` is used, such as SGD with `momentum` set to `0`, the `states` map (`Map[Int, AnyRef]`) contains (key, value) pairs of the form (`some integer index`, `null`).
   
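   Because of the `if (v != null)` check, the serialize method above skips these entries completely: it writes neither the key `k` nor a length of `0` for them. It has, however, already written `states.size`, which counts the null-valued entries, so the header promises more records than the stream contains.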
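   The mismatch can be reproduced outside MXNet with a minimal sketch of the same loop. It assumes a hypothetical `stateToBytes` helper in place of `optimizer.serializeState` (not an MXNet API), and it additionally reports how many records were written after the header:

   ```scala
   import java.io.{ByteArrayOutputStream, ObjectOutputStream}

   object BuggyWriterSketch {
     // Hypothetical stand-in for optimizer.serializeState(v); not an MXNet API.
     def stateToBytes(v: AnyRef): Array[Byte] =
       if (v == null) null else v.toString.getBytes("UTF-8")

     // Same loop shape as serializeState() above. Returns the bytes, the
     // entry count written in the header, and the number of records written.
     def serialize(states: Map[Int, AnyRef]): (Array[Byte], Int, Int) = {
       val bos = new ByteArrayOutputStream()
       val out = new ObjectOutputStream(bos)
       var recordsWritten = 0
       try {
         out.writeInt(states.size) // header counts null-valued entries too
         states.foreach { case (k, v) =>
           if (v != null) { // null-valued entries are skipped entirely
             out.writeInt(k)
             val stateBytes = stateToBytes(v)
             if (stateBytes == null) {
               out.writeInt(0)
             } else {
               out.writeInt(stateBytes.length)
               out.write(stateBytes)
             }
             recordsWritten += 1
           }
         }
         out.flush()
         (bos.toByteArray, states.size, recordsWritten)
       } finally out.close()
     }
   }
   ```

   With two null-valued entries, the header says `2` while no records follow it.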
   
   Now consider deserialization (pasted from `Optimizer.deserializeState()`):
   ```scala
   override def deserializeState(bytes: Array[Byte]): Unit = {
     val bis = new ByteArrayInputStream(bytes)
     var in: ObjectInputStream = null
     try {
       in = new ObjectInputStream(bis)
       val size = in.readInt()
       // expects one (key, length, bytes) record for each of the `size` entries
       (0 until size).foreach(_ => {
         val key = in.readInt()
         val bytesLength = in.readInt()
         val value =
           if (bytesLength > 0) {
             val bytes = Array.fill[Byte](bytesLength)(0)
             in.readFully(bytes)
             optimizer.deserializeState(bytes)
           } else {
             null
           }
         states.update(key, value)
       })
     } finally {
       ...
     }
   }
   ```
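   The `foreach` loop runs `size` times and reads a key and a length for every entry, including the ones that were never written. Reading a key that was never serialized runs past the end of the stream, which causes a `java.io.EOFException`.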
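   A self-contained sketch of the read side, assuming the buggy writer's output for two null-valued entries is just the count `2` with nothing after it. The reader follows the same (key, length, bytes) pattern as `deserializeState()` but returns raw byte arrays instead of calling `optimizer.deserializeState`:

   ```scala
   import java.io._

   object EofReproSketch {
     // Bytes equivalent to what the buggy writer emits for two null-valued
     // states: the entry count 2 followed by no records.
     def buggyBytes(): Array[Byte] = {
       val bos = new ByteArrayOutputStream()
       val out = new ObjectOutputStream(bos)
       out.writeInt(2)
       out.flush()
       bos.toByteArray
     }

     // Same read pattern as deserializeState(): size, then one
     // (key, length, bytes) record per counted entry.
     def read(bytes: Array[Byte]): Map[Int, Array[Byte]] = {
       val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
       try {
         val size = in.readInt()
         (0 until size).map { _ =>
           val key = in.readInt() // throws EOFException: no record was written
           val len = in.readInt()
           val value =
             if (len > 0) {
               val b = new Array[Byte](len)
               in.readFully(b)
               b
             } else null
           key -> value
         }.toMap
       } finally in.close()
     }

     // True if reading the buggy stream fails the way the issue reports.
     def failsWithEof(): Boolean =
       try { read(buggyBytes()); false }
       catch { case _: EOFException => true }
   }
   ```

   `EofReproSketch.failsWithEof()` returns `true`, because the first `in.readInt()` inside the loop hits the end of the stream.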
   
   ## Solution
   
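   Remove the `if (v != null)` check and keep the rest of the method as is. Every entry counted by `states.size` is then written as a key followed by a length (`0` when the serialized state is `null`), which is exactly what `deserializeState()` expects; a length of `0` is already restored as `null` on the read side.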
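   A minimal round-trip sketch of the fix, assuming a hypothetical string-based codec (`stateToBytes`/`bytesToState`) standing in for the optimizer's own state (de)serialization. The writer is the method above without the `if (v != null)` check, and the reader is unchanged:

   ```scala
   import java.io._

   object FixedRoundTripSketch {
     // Hypothetical per-state codec standing in for optimizer.serializeState /
     // optimizer.deserializeState; a null state serializes to null.
     def stateToBytes(v: String): Array[Byte] =
       if (v == null) null else v.getBytes("UTF-8")
     def bytesToState(b: Array[Byte]): String = new String(b, "UTF-8")

     // Writer with the `if (v != null)` check removed: one record per entry.
     def serialize(states: Map[Int, String]): Array[Byte] = {
       val bos = new ByteArrayOutputStream()
       val out = new ObjectOutputStream(bos)
       try {
         out.writeInt(states.size)
         states.foreach { case (k, v) =>
           out.writeInt(k)
           val stateBytes = stateToBytes(v)
           if (stateBytes == null) {
             out.writeInt(0) // null state: key plus a zero length
           } else {
             out.writeInt(stateBytes.length)
             out.write(stateBytes)
           }
         }
         out.flush()
         bos.toByteArray
       } finally out.close()
     }

     // Reader with the same logic as deserializeState(): length 0 becomes null.
     def deserialize(bytes: Array[Byte]): Map[Int, String] = {
       val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
       try {
         val size = in.readInt()
         (0 until size).map { _ =>
           val key = in.readInt()
           val len = in.readInt()
           val value =
             if (len > 0) {
               val b = new Array[Byte](len)
               in.readFully(b)
               bytesToState(b)
             } else null
           key -> value
         }.toMap
       } finally in.close()
     }
   }
   ```

   One consequence of this encoding: a non-null state that serializes to zero bytes would also come back as `null`, since the reader only distinguishes lengths greater than `0`.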
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.


With regards,
Apache Git Services