You are viewing a plain-text version of this content. The canonical link for it was a "here" hyperlink, which was not preserved in this plain-text copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/03 17:39:36 UTC

[GitHub] lanking520 closed pull request #13619: [MXNET-1231] Allow not using Some in the Scala operators

lanking520 closed pull request #13619: [MXNET-1231] Allow not using Some in the Scala operators
URL: https://github.com/apache/incubator-mxnet/pull/13619
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a pull request from a fork, the diff is included
below; GitHub would not otherwise display it:

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
new file mode 100644
index 00000000000..2cf453ac3d1
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/util/OptionConversion.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet.util
+
+object OptionConversion {
+  implicit def someWrapper[A](noSome : A) : Option[A] = Option(noSome)
+}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 7992a0ed867..2db9ff11b37 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -593,4 +593,17 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
     assert(rnd.shape === Shape(1, 2, 3, 4))
     assert(rnd2.shape === Shape(3, 4))
   }
+
+  test("Generated api") {
+    // Without SomeConversion
+    val arr3 = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+    val arr4 = NDArray.ones(Shape(1), dtype = DType.Float64)
+    val arr5 = NDArray.api.norm(arr3, ord = Some(1), out = Some(arr4))
+    // With SomeConversion
+    import org.apache.mxnet.util.OptionConversion._
+    val arr = NDArray.ones(Shape(1, 2), dtype = DType.Float64)
+    val arr2 = NDArray.ones(Shape(1), dtype = DType.Float64)
+    NDArray.api.norm(arr, ord = 1, out = arr2)
+    val result = NDArray.api.dot(arr2, arr2)
+  }
 }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index cf55bc10d97..5208923275f 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -126,7 +126,6 @@ class Classifier(modelPathPrefix: String,
     })
 
     val predictResult = predictResultPar.toArray
-
     var result: ListBuffer[IndexedSeq[(String, Float)]] =
       ListBuffer.empty[IndexedSeq[(String, Float)]]
 
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
index 1c2c4fd704b..498c4e94366 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
@@ -96,7 +96,7 @@ private[mxnet] abstract class GeneratorBase {
       else if (isSymbol) "org.apache.mxnet.Symbol"
       else "org.apache.mxnet.NDArray"
       val typeAndOption =
-        CToScalaUtils.argumentCleaner(argName, argType, family)
+        CToScalaUtils.argumentCleaner(argName, argType, family, isJava)
       Arg(argName, typeAndOption._1, argDesc, typeAndOption._2)
     }
     val returnType =
@@ -191,7 +191,7 @@ private[mxnet] trait RandomHelpers {
   // unify call targets (random_xyz and sample_xyz) and unify their argument types
   private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
     var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
-      "java.lang.Float", "java.lang.Integer")
+      "Float", "Int")
 
     func.copy(
       name = func.name.replaceAll("(random|sample)_", ""),
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 2fd8b2e73c7..57c4cfba10b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -18,23 +18,35 @@ package org.apache.mxnet.utils
 
 private[mxnet] object CToScalaUtils {
 
-
+  private val javaType = Map(
+    "float" -> "java.lang.Float",
+    "int" -> "java.lang.Integer",
+    "long" -> "java.lang.Long",
+    "double" -> "java.lang.Double",
+    "bool" -> "java.lang.Boolean")
+  private val scalaType = Map(
+    "float" -> "Float",
+    "int" -> "Int",
+    "long" -> "Long",
+    "double" -> "Double",
+    "bool" -> "Boolean")
 
   // Convert C++ Types to Scala Types
   def typeConversion(in : String, argType : String = "", argName : String,
-                     returnType : String) : String = {
+                     returnType : String, isJava : Boolean) : String = {
     val header = returnType.split("\\.").dropRight(1)
+    val types = if (isJava) javaType else scalaType
     in match {
       case "Shape(tuple)" | "ShapeorNone" => s"${header.mkString(".")}.Shape"
       case "Symbol" | "NDArray" | "NDArray-or-Symbol" => returnType
       case "Symbol[]" | "NDArray[]" | "NDArray-or-Symbol[]" | "SymbolorSymbol[]"
       => s"Array[$returnType]"
-      case "float" | "real_t" | "floatorNone" => "java.lang.Float"
-      case "int" | "intorNone" | "int(non-negative)" => "java.lang.Integer"
-      case "long" | "long(non-negative)" => "java.lang.Long"
-      case "double" | "doubleorNone" => "java.lang.Double"
+      case "float" | "real_t" | "floatorNone" => types("float")
+      case "int" | "intorNone" | "int(non-negative)" => types("int")
+      case "long" | "long(non-negative)" => types("long")
+      case "double" | "doubleorNone" => types("double")
       case "string" => "String"
-      case "boolean" | "booleanorNone" => "java.lang.Boolean"
+      case "boolean" | "booleanorNone" => types("bool")
       case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => "Any"
       case default => throw new IllegalArgumentException(
         s"Invalid type for args: $default\nString argType: $argType\nargName: $argName")
@@ -54,7 +66,7 @@ private[mxnet] object CToScalaUtils {
     * @return (Scala_Type, isOptional)
     */
   def argumentCleaner(argName: String, argType : String,
-                      returnType : String) : (String, Boolean) = {
+                      returnType : String, isJava : Boolean) : (String, Boolean) = {
     val spaceRemoved = argType.replaceAll("\\s+", "")
     var commaRemoved : Array[String] = new Array[String](0)
     // Deal with the case e.g: stype : {'csr', 'default', 'row_sparse'}
@@ -72,9 +84,9 @@ private[mxnet] object CToScalaUtils {
         s"""expected "optional" got ${commaRemoved(1)}""")
       require(commaRemoved(2).startsWith("default="),
         s"""expected "default=..." got ${commaRemoved(2)}""")
-      (typeConversion(commaRemoved(0), argType, argName, returnType), true)
+      (typeConversion(commaRemoved(0), argType, argName, returnType, isJava), true)
     } else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
-      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
+      val tempType = typeConversion(commaRemoved(0), argType, argName, returnType, isJava)
       val tempOptional = tempType.equals("org.apache.mxnet.Symbol")
       (tempType, tempOptional)
     } else {
diff --git a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
index 4404b0885d5..4069bba2522 100644
--- a/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
+++ b/scala-package/macros/src/test/scala/org/apache/mxnet/MacrosSuite.scala
@@ -36,14 +36,15 @@ class MacrosSuite extends FunSuite with BeforeAndAfterAll {
     )
     val output = List(
       ("org.apache.mxnet.Symbol", true),
-      ("java.lang.Integer", false),
+      ("Int", false),
       ("org.apache.mxnet.Shape", true),
       ("String", true),
       ("Any", false)
     )
 
     for (idx <- input.indices) {
-      val result = CToScalaUtils.argumentCleaner("Sample", input(idx), "org.apache.mxnet.Symbol")
+      val result = CToScalaUtils.argumentCleaner("Sample", input(idx),
+        "org.apache.mxnet.Symbol", false)
       assert(result._1 === output(idx)._1 && result._2 === output(idx)._2)
     }
   }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services