You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/27 05:04:49 UTC

[GitHub] lanking520 closed pull request #12110: [MXNET-730][WIP] Scala test in nightly

lanking520 closed pull request #12110: [MXNET-730][WIP] Scala test in nightly
URL: https://github.com/apache/incubator-mxnet/pull/12110
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this pull request comes from a fork, GitHub will not display its
diff once the pull request is closed, so the full diff is reproduced below:

diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
index c1ff10c6c8a..a275291bca2 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala
@@ -21,9 +21,15 @@ import java.io.File
 import java.net.URL
 
 import org.apache.commons.io.FileUtils
+import org.slf4j.LoggerFactory
 
 object Util {
-
+  /**
+    * A download wrapper with a retry scheme.
+    * @param url the URL of the file to download
+    * @param filePath the path where the downloaded file is stored
+    * @param maxRetry maximum number of retries; defaults to 3 when None
+    */
   def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
     val tmpFile = new File(filePath)
     var retry = maxRetry.getOrElse(3)
@@ -42,4 +48,30 @@ object Util {
     }
    if (!success) throw new Exception(s"$url Download failed!")
   }
+
+  /**
+    * Decides whether a named example test should run in CI and for how many epochs.
+    * GPU tests run only when the SCALA_TEST_INTEGRATION environment variable is set;
+    * a value of 1 limits them to a single epoch, any other integer uses the full count.
+    * @param name the name of the test, as registered in the tables below
+    * @return (runTest, epochs); epochs is 0 for CPU tests and for skipped GPU tests
+    */
+  def testManager(name: String) : (Boolean, Int) = {
+    val gpuTestEpochs = Map[String, Int]("CNN" -> 10, "GAN" -> 5, "MultiTask" -> 3,
+      "NSBoost" -> 10, "NSNeural" -> 80)
+    val cpuTests = Set("CustomOp", "MNIST", "Infer", "Profiler")
+    // Read the variable once; Option(...) maps an unset (null) value to None.
+    val integrationFlag = Option(System.getenv("SCALA_TEST_INTEGRATION"))
+    gpuTestEpochs.get(name) match {
+      case Some(fullEpochs) =>
+        // A non-integer value still fails fast with NumberFormatException from toInt.
+        integrationFlag.fold((false, 0)) { flag =>
+          (true, if (flag.toInt == 1) 1 else fullEpochs)
+        }
+      case None if cpuTests.contains(name) => (true, 0)
+      case None =>
+        throw new IllegalArgumentException(
+          s"Test '$name' not found, please register it in Util!")
+    }
+  }
 }
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
index 674c81459f0..a0c616a7457 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/cnntextclassification/CNNTextClassification.scala
@@ -103,9 +103,8 @@ object CNNTextClassification {
   def trainCNN(model: CNNModel, trainBatches: Array[Array[Array[Float]]],
                trainLabels: Array[Float], devBatches: Array[Array[Array[Float]]],
                devLabels: Array[Float], batchSize: Int, saveModelPath: String,
-               learningRate: Float = 0.001f): Float = {
+               learningRate: Float = 0.001f, epoch : Int = 10): Float = {
     val maxGradNorm = 0.5f
-    val epoch = 10
     val initializer = new Uniform(0.1f)
     val opt = new RMSProp(learningRate)
     val updater = Optimizer.getUpdater(opt)
@@ -236,7 +235,7 @@ object CNNTextClassification {
   }
 
   def test(w2vFilePath : String, mrDatasetPath: String,
-           ctx : Context, saveModelPath: String) : Float = {
+           ctx : Context, saveModelPath: String, epoch : Int) : Float = {
     val (numEmbed, word2vec) = DataHelper.loadGoogleModel(w2vFilePath)
     val (datas, labels) = DataHelper.loadMSDataWithWord2vec(
       mrDatasetPath, numEmbed, word2vec)
@@ -259,7 +258,7 @@ object CNNTextClassification {
     val lr = 0.001f
     val cnnModel = setupCnnModel(ctx, batchSize, sentenceSize, numEmbed)
     val result = trainCNN(cnnModel, trainDats, trainLabels, devDatas, devLabels, batchSize,
-      saveModelPath, learningRate = lr)
+      saveModelPath, learningRate = lr, epoch)
     result
   }
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
index 08b4c85d2c5..6189cc15109 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/neuralstyle/end2end/BoostTrain.scala
@@ -55,7 +55,8 @@ object BoostTrain {
   }
 
   def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
-                  styleImage : String, saveModelPath : String) : Unit = {
+                  styleImage : String, saveModelPath : String,
+                  startEpoch : Int = 0, endEpoch : Int = 3) : Unit = {
     // params
     val vggParams = NDArray.load2Map(vggModelPath)
     val styleWeight = 1.2f
@@ -106,9 +107,6 @@ object BoostTrain {
 
     val tvWeight = 1e-2f
 
-    val startEpoch = 0
-    val endEpoch = 3
-
     for (k <- 0 until gens.length) {
       val path = new File(s"${saveModelPath}/$k")
       if (!path.exists()) path.mkdir()
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
index 95c9823e3b2..dd1ac47da40 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala
@@ -58,7 +58,7 @@ class CNNClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
       val modelDirPath = tempDirPath + File.separator + "CNN"
 
       val output = CNNTextClassification.test(modelDirPath + File.separator + w2vModelName,
-        modelDirPath, context, modelDirPath)
+        modelDirPath, context, modelDirPath, 10)
 
       Process("rm -rf " + modelDirPath) !
 
diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
index dc8fc5b8c14..ef4ba10b31a 100644
--- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
+++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/neuralstyle/NeuralStyleSuite.scala
@@ -70,7 +70,7 @@ class NeuralStyleSuite extends FunSuite with BeforeAndAfterAll {
       System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
       val ctx = Context.gpu()
       BoostTrain.runTraining(tempDirPath + "/NS/images", tempDirPath + "/NS/vgg19.params", ctx,
-        tempDirPath + "/NS/starry_night.jpg", tempDirPath + "/NS")
+        tempDirPath + "/NS/starry_night.jpg", tempDirPath + "/NS", 0, 3)
     } else {
       logger.info("GPU test only, skip CPU...")
     }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services