You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ns...@apache.org on 2018/09/20 15:28:08 UTC
[incubator-mxnet] branch master updated: review require() usages to add meaningful messages. (#12570)
This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 3401e6e review require() usages to add meaningful messages. (#12570)
3401e6e is described below
commit 3401e6e116615730882b9bede5b6abbb51b9a547
Author: mathieu <ma...@gmail.com>
AuthorDate: Thu Sep 20 17:27:53 2018 +0200
review require() usages to add meaningful messages. (#12570)
* review require() usages to add meaningful messages.
---
.../main/scala/org/apache/mxnet/EvalMetric.scala | 47 ++++++++++++-----
.../src/main/scala/org/apache/mxnet/Executor.scala | 15 +++---
.../scala/org/apache/mxnet/ExecutorManager.scala | 35 +++++++++----
.../main/scala/org/apache/mxnet/FeedForward.scala | 11 ++--
.../core/src/main/scala/org/apache/mxnet/IO.scala | 9 ++--
.../main/scala/org/apache/mxnet/Initializer.scala | 3 +-
.../src/main/scala/org/apache/mxnet/Model.scala | 3 +-
.../src/main/scala/org/apache/mxnet/NDArray.scala | 18 ++++---
.../src/main/scala/org/apache/mxnet/Operator.scala | 19 ++++---
.../src/main/scala/org/apache/mxnet/Profiler.scala | 2 +-
.../src/main/scala/org/apache/mxnet/Symbol.scala | 38 +++++++++-----
.../scala/org/apache/mxnet/Visualization.scala | 4 +-
.../scala/org/apache/mxnet/module/BaseModule.scala | 25 ++++-----
.../org/apache/mxnet/module/BucketingModule.scala | 42 ++++++++-------
.../mxnet/module/DataParallelExecutorGroup.scala | 61 ++++++++++++++--------
.../scala/org/apache/mxnet/module/Module.scala | 46 +++++++++-------
.../org/apache/mxnet/module/SequentialModule.scala | 39 ++++++++------
.../scala/org/apache/mxnet/optimizer/DCASGD.scala | 3 +-
.../scala/org/apache/mxnet/optimizer/NAG.scala | 3 +-
.../scala/org/apache/mxnet/optimizer/SGD.scala | 3 +-
.../scala/org/apache/mxnet/infer/Classifier.scala | 6 +--
.../org/apache/mxnet/infer/MXNetHandler.scala | 3 +-
.../scala/org/apache/mxnet/infer/Predictor.scala | 21 ++++----
.../org/apache/mxnet/utils/CToScalaUtils.scala | 6 ++-
.../scala/org/apache/mxnet/spark/MXNDArray.scala | 2 +-
.../scala/org/apache/mxnet/spark/MXNetModel.scala | 6 +--
26 files changed, 286 insertions(+), 184 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/EvalMetric.scala b/scala-package/core/src/main/scala/org/apache/mxnet/EvalMetric.scala
index de2881a..aedf7c8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/EvalMetric.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/EvalMetric.scala
@@ -133,17 +133,20 @@ class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
require(labels.length == preds.length,
- "labels and predictions should have the same length.")
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((pred, label) <- preds zip labels) {
val predShape = pred.shape
val dims = predShape.length
- require(dims <= 2, "Predictions should be no more than 2 dims.")
+ require(dims <= 2, s"Predictions should be no more than 2 dims (got $predShape).")
val labelArray = label.toArray
val numSamples = predShape(0)
if (dims == 1) {
val predArray = pred.toArray.zipWithIndex.sortBy(_._1).reverse.map(_._2)
- require(predArray.length == labelArray.length)
+ require(predArray.length == labelArray.length,
+ s"Each label and prediction array should have the same length " +
+ s"(got ${labelArray.length} and ${predArray.length}).")
this.sumMetric +=
labelArray.zip(predArray).map { case (l, p) => if (l == p) 1 else 0 }.sum
} else if (dims == 2) {
@@ -151,7 +154,9 @@ class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") {
val predArray = pred.toArray.grouped(numclasses).map { a =>
a.zipWithIndex.sortBy(_._1).reverse.map(_._2)
}.toArray
- require(predArray.length == labelArray.length)
+ require(predArray.length == labelArray.length,
+ s"Each label and prediction array should have the same length " +
+ s"(got ${labelArray.length} and ${predArray.length}).")
val topK = Math.max(this.topK, numclasses)
for (j <- 0 until topK) {
this.sumMetric +=
@@ -169,7 +174,8 @@ class TopKAccuracy(topK: Int) extends EvalMetric("top_k_accuracy") {
class F1 extends EvalMetric("f1") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
require(labels.length == preds.length,
- "labels and predictions should have the same length.")
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((pred, label) <- preds zip labels) {
val predLabel = NDArray.argmax_channel(pred)
@@ -223,7 +229,8 @@ class F1 extends EvalMetric("f1") {
class Perplexity(ignoreLabel: Option[Int] = None, axis: Int = -1) extends EvalMetric("Perplexity") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
require(labels.length == preds.length,
- "labels and predictions should have the same length.")
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
var loss = 0d
var num = 0
val probs = ArrayBuffer[NDArray]()
@@ -261,12 +268,16 @@ class Perplexity(ignoreLabel: Option[Int] = None, axis: Int = -1) extends EvalMe
*/
class MAE extends EvalMetric("mae") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
- require(labels.size == preds.size, "labels and predictions should have the same length.")
+ require(labels.size == preds.size,
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((label, pred) <- labels zip preds) {
val labelArr = label.toArray
val predArr = pred.toArray
- require(labelArr.length == predArr.length)
+ require(labelArr.length == predArr.length,
+ s"Each label and prediction array should have the same length " +
+ s"(got ${labelArr.length} and ${predArr.length}).")
this.sumMetric +=
(labelArr zip predArr).map { case (l, p) => Math.abs(l - p) }.sum / labelArr.length
this.numInst += 1
@@ -277,12 +288,16 @@ class MAE extends EvalMetric("mae") {
// Calculate Mean Squared Error loss
class MSE extends EvalMetric("mse") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
- require(labels.size == preds.size, "labels and predictions should have the same length.")
+ require(labels.size == preds.size,
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((label, pred) <- labels zip preds) {
val labelArr = label.toArray
val predArr = pred.toArray
- require(labelArr.length == predArr.length)
+ require(labelArr.length == predArr.length,
+ s"Each label and prediction array should have the same length " +
+ s"(got ${labelArr.length} and ${predArr.length}).")
this.sumMetric +=
(labelArr zip predArr).map { case (l, p) => (l - p) * (l - p) }.sum / labelArr.length
this.numInst += 1
@@ -295,12 +310,16 @@ class MSE extends EvalMetric("mse") {
*/
class RMSE extends EvalMetric("rmse") {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
- require(labels.size == preds.size, "labels and predictions should have the same length.")
+ require(labels.size == preds.size,
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((label, pred) <- labels zip preds) {
val labelArr = label.toArray
val predArr = pred.toArray
- require(labelArr.length == predArr.length)
+ require(labelArr.length == predArr.length,
+ s"Each label and prediction array should have the same length " +
+ s"(got ${labelArr.length} and ${predArr.length}).")
val metric: Double = Math.sqrt(
(labelArr zip predArr).map { case (l, p) => (l - p) * (l - p) }.sum / labelArr.length)
this.sumMetric += metric.toFloat
@@ -318,7 +337,9 @@ class RMSE extends EvalMetric("rmse") {
class CustomMetric(fEval: (NDArray, NDArray) => Float,
name: String) extends EvalMetric(name) {
override def update(labels: IndexedSeq[NDArray], preds: IndexedSeq[NDArray]): Unit = {
- require(labels.size == preds.size, "labels and predictions should have the same length.")
+ require(labels.size == preds.size,
+ s"labels and predictions should have the same length " +
+ s"(got ${labels.length} and ${preds.length}).")
for ((label, pred) <- labels zip preds) {
this.sumMetric += fEval(label, pred)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index 181b232..fc791d5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -26,7 +26,7 @@ object Executor {
// Get the dictionary given name and ndarray pairs.
private[mxnet] def getDict(names: Seq[String],
ndarrays: Seq[NDArray]): Map[String, NDArray] = {
- require(names.toSet.size == names.length, "Duplicate names detected")
+ require(names.toSet.size == names.length, s"Duplicate names detected in ($names)")
(names zip ndarrays).toMap
}
}
@@ -86,7 +86,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
- require(argShapes != null, "Insufficient argument shapes provided.")
+ // TODO: more precise error message should be provided by backend
+ require(argShapes != null, "Shape inference failed." +
+ s"Known shapes are $kwargs for symbol arguments ${symbol.listArguments()} " +
+ s"and aux states ${symbol.listAuxiliaryStates()}")
var newArgDict = Map[String, NDArray]()
var newGradDict = Map[String, NDArray]()
@@ -194,13 +197,13 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
* on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray]): Unit = {
- require(outGrads != null)
+ require(outGrads != null, "outGrads must not be null")
val ndArrayPtrs = outGrads.map(_.handle)
checkCall(_LIB.mxExecutorBackward(handle, ndArrayPtrs))
}
def backward(outGrad: NDArray): Unit = {
- require(outGrad != null)
+ require(outGrad != null, "outGrads must not be null")
backward(Array(outGrad))
}
@@ -271,7 +274,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
if (argDict.contains(name)) {
array.copyTo(argDict(name))
} else {
- require(allowExtraParams, s"Find name $name that is not in the arguments")
+ require(allowExtraParams, s"Provided name $name is not in the arguments")
}
}
if (auxParams != null) {
@@ -279,7 +282,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
if (auxDict.contains(name)) {
array.copyTo(auxDict(name))
} else {
- require(allowExtraParams, s"Find name $name that is not in the auxiliary states")
+ require(allowExtraParams, s"Provided name $name is not in the auxiliary states")
}
}
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
index 22914a5..b13741b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ExecutorManager.scala
@@ -54,7 +54,8 @@ private[mxnet] class DataParallelExecutorManager(private val symbol: Symbol,
if (workLoadList == null) {
workLoadList = Seq.fill(numDevice)(1f)
}
- require(workLoadList.size == numDevice, "Invalid settings for work load.")
+ require(workLoadList.size == numDevice, "Invalid settings for work load. " +
+ s"Size (${workLoadList.size}) should match num devices ($numDevice)")
private val slices = ExecutorManager.splitInputSlice(trainData.batchSize, workLoadList)
@@ -212,13 +213,13 @@ private[mxnet] object ExecutorManager {
private[mxnet] def checkArguments(symbol: Symbol): Unit = {
val argNames = symbol.listArguments()
require(argNames.toSet.size == argNames.length,
- "Find duplicated argument name," +
+ "Found duplicated argument name," +
"please make the weight name non-duplicated(using name arguments)," +
s"arguments are $argNames")
val auxNames = symbol.listAuxiliaryStates()
require(auxNames.toSet.size == auxNames.length,
- "Find duplicated auxiliary param name," +
+ "Found duplicated auxiliary param name," +
"please make the weight name non-duplicated(using name arguments)," +
s"arguments are $auxNames")
}
@@ -272,7 +273,11 @@ private[mxnet] object ExecutorManager {
sharedDataArrays: mutable.Map[String, NDArray] = null,
inputTypes: ListMap[String, DType] = null) = {
val (argShape, _, auxShape) = sym.inferShape(inputShapes)
- require(argShape != null)
+ // TODO: more precise error message should be provided by backend
+ require(argShape != null, "Shape inference failed." +
+ s"Known shapes are $inputShapes for symbol arguments ${sym.listArguments()} " +
+ s"and aux states ${sym.listAuxiliaryStates()}")
+
val inputTypesUpdate =
if (inputTypes == null) {
inputShapes.map { case (key, _) => (key, Base.MX_REAL_TYPE) }
@@ -280,7 +285,9 @@ private[mxnet] object ExecutorManager {
inputTypes
}
val (argTypes, _, auxTypes) = sym.inferType(inputTypesUpdate)
- require(argTypes != null)
+ require(argTypes != null, "Type inference failed." +
+ s"Known types as $inputTypes for symbol arguments ${sym.listArguments()} " +
+ s"and aux states ${sym.listAuxiliaryStates()}")
val argArrays = ArrayBuffer.empty[NDArray]
val gradArrays: mutable.Map[String, NDArray] =
@@ -311,7 +318,8 @@ private[mxnet] object ExecutorManager {
val arr = sharedDataArrays(name)
if (arr.shape.product >= argShape(i).product) {
// good, we can share this memory
- require(argTypes(i) == arr.dtype)
+ require(argTypes(i) == arr.dtype,
+ s"Type ${arr.dtype} of argument $name does not match inferred type ${argTypes(i)}")
arr.reshape(argShape(i))
} else {
DataParallelExecutorManager.logger.warn(
@@ -345,8 +353,10 @@ private[mxnet] object ExecutorManager {
NDArray.zeros(argShape(i), ctx, dtype = argTypes(i))
} else {
val arr = baseExec.argDict(name)
- require(arr.shape == argShape(i))
- require(arr.dtype == argTypes(i))
+ require(arr.shape == argShape(i),
+ s"Shape ${arr.shape} of argument $name does not match inferred shape ${argShape(i)}")
+ require(arr.dtype == argTypes(i),
+ s"Type ${arr.dtype} of argument $name does not match inferred type ${argTypes(i)}")
if (gradSet.contains(name)) {
gradArrays.put(name, baseExec.gradDict(name))
}
@@ -356,6 +366,7 @@ private[mxnet] object ExecutorManager {
}
}
// create or borrow aux variables
+ val auxNames = sym.listAuxiliaryStates()
val auxArrays =
if (baseExec == null) {
(auxShape zip auxTypes) map { case (s, t) =>
@@ -363,8 +374,12 @@ private[mxnet] object ExecutorManager {
}
} else {
baseExec.auxArrays.zipWithIndex.map { case (a, i) =>
- require(auxShape(i) == a.shape)
- require(auxTypes(i) == a.dtype)
+ require(auxShape(i) == a.shape,
+ s"Shape ${a.shape} of aux variable ${auxNames(i)} does not match " +
+ s"inferred shape ${auxShape(i)}")
+ require(auxTypes(i) == a.dtype,
+ s"Type ${a.dtype} of aux variable ${auxNames(i)} does not match " +
+ s"inferred type ${auxTypes(i)}")
a
}.toSeq
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
index 87c9bc7..00a1450 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/FeedForward.scala
@@ -99,7 +99,7 @@ class FeedForward private(
// verify the argument of the default symbol and user provided parameters
def checkArguments(): Unit = {
if (!argumentChecked) {
- require(symbol != null)
+ require(symbol != null, "Symbol must not be null")
// check if symbol contain duplicated names.
ExecutorManager.checkArguments(symbol)
// rematch parameters to delete useless ones
@@ -169,7 +169,9 @@ class FeedForward private(
private def initPredictor(inputShapes: Map[String, Shape]): Unit = {
if (this.predExec != null) {
val (argShapes, _, _) = symbol.inferShape(inputShapes)
- require(argShapes != null, "Incomplete input shapes")
+ require(argShapes != null, "Shape inference failed." +
+ s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
+ s"and aux states ${symbol.listAuxiliaryStates()}")
val predShapes = this.predExec.argArrays.map(_.shape)
if (argShapes.sameElements(predShapes)) {
return
@@ -187,7 +189,8 @@ class FeedForward private(
require(y != null || !isTrain, "y must be specified")
val label = if (y == null) NDArray.zeros(X.shape(0)) else y
require(label.shape.length == 1, "Label must be 1D")
- require(X.shape(0) == label.shape(0), "The numbers of data points and labels not equal")
+ require(X.shape(0) == label.shape(0),
+ s"The numbers of data points (${X.shape(0)}) and labels (${label.shape(0)}) are not equal")
if (isTrain) {
new NDArrayIter(IndexedSeq(X), IndexedSeq(label), batchSize,
shuffle = isTrain, lastBatchHandle = "roll_over")
@@ -402,7 +405,7 @@ class FeedForward private(
* - ``prefix-epoch.params`` will be saved for parameters.
*/
def save(prefix: String, epoch: Int = this.numEpoch): Unit = {
- require(epoch >= 0)
+ require(epoch >= 0, s"epoch must be >=0 (got $epoch)")
Model.saveCheckpoint(prefix, epoch, this.symbol, getArgParams, getAuxParams)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index a1095cf..e835142 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -114,8 +114,9 @@ object IO {
defaultName: String,
defaultDType: DType,
defaultLayout: String): IndexedSeq[(DataDesc, NDArray)] = {
- require(data != null)
- require(data != IndexedSeq.empty || allowEmpty)
+ require(data != null, "data is required.")
+ require(data != IndexedSeq.empty || allowEmpty,
+ s"data should not be empty when allowEmpty is false")
if (data == IndexedSeq.empty) {
IndexedSeq()
} else if (data.length == 1) {
@@ -372,9 +373,7 @@ abstract class DataPack() extends Iterable[DataBatch] {
case class DataDesc(name: String, shape: Shape,
dtype: DType = DType.Float32, layout: String = Layout.UNDEFINED) {
require(layout == Layout.UNDEFINED || shape.length == layout.length,
- ("number of dimensions in shape :%d with" +
- " shape: %s should match the length of the layout: %d with layout: %s").
- format(shape.length, shape.toString, layout.length, layout))
+ s"number of dimensions in $shape should match the layout $layout")
override def toString(): String = {
s"DataDesc[$name,$shape,$dtype,$layout]"
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Initializer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Initializer.scala
index e26690c..8531285 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Initializer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Initializer.scala
@@ -98,7 +98,8 @@ abstract class Initializer {
*/
class Mixed(protected val patterns: List[String],
protected val initializers: List[Initializer]) extends Initializer {
- require(patterns.length == initializers.length)
+ require(patterns.length == initializers.length,
+ "Should provide a pattern for each initializer")
private val map = patterns.map(_.r).zip(initializers)
override def apply(name: String, arr: NDArray): Unit = {
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
index ad6fae5..4bb9cdd 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
@@ -160,7 +160,8 @@ object Model {
argParams: Map[String, NDArray],
paramNames: IndexedSeq[String],
updateOnKVStore: Boolean): Unit = {
- require(paramArrays.length == paramNames.length)
+ require(paramArrays.length == paramNames.length,
+ s"Provided parameter arrays does not match parameter names")
for (idx <- 0 until paramArrays.length) {
val paramOnDevs = paramArrays(idx)
val name = paramNames(idx)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 8b5e1e0..9b6a7dc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -427,7 +427,7 @@ object NDArray extends NDArrayBase {
* @return An `NDArray` that lives on the same context as `arrays[0].context`.
*/
def concatenate(arrays: Seq[NDArray], axis: Int = 0, alwaysCopy: Boolean = true): NDArray = {
- require(arrays.size > 0)
+ require(arrays.size > 0, "Provide at least one array")
val array0 = arrays(0)
if (!alwaysCopy && arrays.size == 1) {
@@ -439,9 +439,12 @@ object NDArray extends NDArrayBase {
val shapeAxis =
arrays.map(arr => {
- require(shapeRest1 == arr.shape.slice(0, axis))
- require(shapeRest2 == arr.shape.slice(axis + 1, arr.shape.length))
- require(dtype == arr.dtype)
+ require(shapeRest1 == arr.shape.slice(0, axis),
+ s"Mismatch between shape $shapeRest1 and ${arr.shape}")
+ require(shapeRest2 == arr.shape.slice(axis + 1, arr.shape.length),
+ s"Mismatch between shape $shapeRest2 and ${arr.shape}")
+ require(dtype == arr.dtype,
+ s"All arrays must have the same type (got ${dtype} and ${arr.dtype})")
arr.shape(axis)
}).sum
val retShape = shapeRest1 ++ Shape(shapeAxis) ++ shapeRest2
@@ -484,7 +487,7 @@ object NDArray extends NDArrayBase {
* - `s3://my-bucket/path/my-s3-ndarray`
* - `hdfs://my-bucket/path/my-hdfs-ndarray`
* - `/path-to/my-local-ndarray`
- * @return dict of str->NDArray to be saved
+ * @return dict of str->NDArray
*/
def load(fname: String): (Array[String], Array[NDArray]) = {
val outSize = new MXUintRef
@@ -492,7 +495,8 @@ object NDArray extends NDArrayBase {
val handles = ArrayBuffer.empty[NDArrayHandle]
val names = ArrayBuffer.empty[String]
checkCall(_LIB.mxNDArrayLoad(fname, outSize, handles, outNameSize, names))
- require(outNameSize.value == 0 || outNameSize.value == outSize.value)
+ require(outNameSize.value == 0 || outNameSize.value == outSize.value,
+ s"Mismatch between names and arrays in file $fname")
(names.toArray, handles.map(new NDArray(_)).toArray)
}
@@ -1003,7 +1007,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
val ndim = new MXUintRef
val data = ArrayBuffer[Int]()
checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
- require(ndim.value == data.length, s"ndim=$ndim, while len(pdata)=${data.length}")
+ require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}")
Shape(data)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
index f2abe5e..a521702 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
@@ -134,7 +134,8 @@ abstract class CustomOpProp(needTopGrad: Boolean = false) {
protected var kwargs: Map[String, String] = Map[String, String]()
private[mxnet] def init(keys: Array[String], vals: Array[String]): Unit = {
- require(keys.length == vals.length)
+ require(keys.length == vals.length,
+ s"Number of keys (${keys.length}) does not match arrays (${vals.length})")
kwargs = keys.zip(vals).toMap
}
@@ -166,11 +167,13 @@ abstract class CustomOpProp(needTopGrad: Boolean = false) {
val tmp = this.listAuxiliaryStates()
if (tmp == null) 0 else tmp.length
}
- require(numTensor == (nIn + nOut + nAux))
+ require(numTensor == (nIn + nOut + nAux),
+ s"Shape inference failed. $numTensor tensors expected, but got " +
+ s"$nIn args, $nOut ouputs and $nAux aux states")
val (inShapes, outShapes, auxShapes) =
inferShape(intputShapes.map(Shape(_)))
- require(inShapes != null && inShapes.length != 0)
- require(outShapes != null && outShapes.length != 0)
+ require(inShapes != null && inShapes.length != 0, "InputShape is undefined or empty")
+ require(outShapes != null && outShapes.length != 0, "OutputShape is undefined or empty")
if (auxShapes != null && auxShapes.length != 0) {
inShapes.map(_.toArray) ++ outShapes.map(_.toArray) ++ auxShapes.map(_.toArray)
} else inShapes.map(_.toArray) ++ outShapes.map(_.toArray)
@@ -206,11 +209,13 @@ abstract class CustomOpProp(needTopGrad: Boolean = false) {
val tmp = this.listAuxiliaryStates()
if (tmp == null) 0 else tmp.length
}
- require(numTensor == (nIn + nOut + nAux))
+ require(numTensor == (nIn + nOut + nAux),
+ s"Type inference failed. $numTensor tensors expected, but got " +
+ s"$nIn args, $nOut ouputs and $nAux aux states")
val (inTypes, outTypes, auxTypes) =
inferType(intputTypes.map(DType(_)))
- require(inTypes != null && inTypes.length != 0)
- require(outTypes != null && outTypes.length != 0)
+ require(inTypes != null && inTypes.length != 0, "InputType is undefined or empty")
+ require(outTypes != null && outTypes.length != 0, "OutputType is undefined or empty")
if (auxTypes != null && auxTypes.length != 0) {
inTypes.map(_.id) ++ outTypes.map(_.id) ++ auxTypes.map(_.id)
} else inTypes.map(_.id) ++ outTypes.map(_.id)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Profiler.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Profiler.scala
index df10b34..a917377 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Profiler.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Profiler.scala
@@ -44,7 +44,7 @@ object Profiler {
* be "stop" or "run". Default is "stop".
*/
def profilerSetState(state: String = "stop"): Unit = {
- require(state2Int.contains(state))
+ require(state2Int.contains(state), s"Invalid state $state")
checkCall(_LIB.mxSetProfilerState(state2Int(state)))
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index e3e1a32..b1a3e39 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -423,7 +423,13 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
}
val (argShapes, _, auxShapes) = inferShape(shapeDict)
val (argTypes, _, auxTypes) = inferType(types)
- require(argShapes != null && argTypes != null, "Input node is not complete")
+ require(argShapes != null, "Shape inference failed." +
+ s"Known shapes are $shapeDict for symbol arguments ${listArguments()} " +
+ s"and aux states ${listAuxiliaryStates()}")
+ require(argTypes != null, "Type inference failed." +
+ s"Known types as $typeDict for symbol arguments ${listArguments()} " +
+ s"and aux states ${listAuxiliaryStates()}")
+
// alloc space
val argNDArrays = (argShapes zip argTypes) map { case (shape, t) =>
NDArray.zeros(shape, ctx, dtype = t)
@@ -715,10 +721,14 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
args: Iterable[_], argsGrad: Iterable[_],
gradsReq: Iterable[_], auxStates: Iterable[_],
group2ctx: Map[String, Context], sharedExec: Executor): Executor = {
- require(args != null && !args.isInstanceOf[Set[_]])
- require(argsGrad == null || !argsGrad.isInstanceOf[Set[_]])
- require(auxStates == null || !auxStates.isInstanceOf[Set[_]])
- require(gradsReq != null && !gradsReq.isInstanceOf[Set[_]])
+ require(args != null && !args.isInstanceOf[Set[_]],
+ s"args must be provided (Set is not supported)")
+ require(argsGrad == null || !argsGrad.isInstanceOf[Set[_]],
+ s"argsGrad must be provided (Set is not supported)")
+ require(auxStates == null || !auxStates.isInstanceOf[Set[_]],
+ s"auxStates must be provided (Set is not supported)")
+ require(gradsReq != null && !gradsReq.isInstanceOf[Set[_]],
+ s"gradsReq must be provided (Set is not supported)")
val (argsHandle, argsNDArray) =
if (args.isInstanceOf[Seq[_]]) {
@@ -756,14 +766,16 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
val reqsArray =
if (gradsReq.isInstanceOf[Seq[_]]) {
gradsReq.asInstanceOf[Seq[String]].map { req =>
- require(Symbol.bindReqMap.contains(req), s"grad_req must be in ${Symbol.bindReqMap}")
+ require(Symbol.bindReqMap.contains(req),
+ s"grad_req $req must be in ${Symbol.bindReqMap}")
Symbol.bindReqMap(req)
}.toArray
} else {
val gradsReqMap = gradsReq.asInstanceOf[Map[String, String]]
symbolArguments.map { req =>
val value = gradsReqMap.getOrElse(req, "null")
- require(Symbol.bindReqMap.contains(value), s"grad_req must be in ${Symbol.bindReqMap}")
+ require(Symbol.bindReqMap.contains(value),
+ s"grad_req $req must be in ${Symbol.bindReqMap}")
Symbol.bindReqMap(value)
}.toArray
}
@@ -1083,9 +1095,9 @@ object Symbol extends SymbolBase {
(key, value.toString)
}
}
- require(symbols.isEmpty || symbolKwargs.isEmpty, String.format(
- "%s can only accept input Symbols either as positional or keyword arguments, not both",
- operator))
+ require(symbols.isEmpty || symbolKwargs.isEmpty,
+ s"$operator can only accept input Symbols either as positional or keyword arguments, " +
+ s"not both")
if (symbols.isEmpty) {
createFromNamedSymbols(operator, name, attr)(symbolKwargs, strKwargs)
} else {
@@ -1217,7 +1229,8 @@ object Symbol extends SymbolBase {
*/
private def getNDArrayInputs(argKey: String, args: Seq[NDArray], argNames: Seq[String],
allowMissing: Boolean): (Array[NDArrayHandle], Array[NDArray]) = {
- require(args.length == argNames.length, s"Length of $argKey do not match number of arguments")
+ require(args.length == argNames.length,
+ s"Length of $argKey do not match number of arguments")
val argHandles = args.map(_.handle)
(argHandles.toArray, args.toArray)
}
@@ -1232,7 +1245,8 @@ object Symbol extends SymbolBase {
argArrays += narr.get
argHandles += narr.get.handle
case None =>
- require(allowMissing, s"Must specify all the arguments in $argKey")
+ require(allowMissing,
+ s"Must specify all the arguments in $argKey. $name is unknown")
argArrays += null
argHandles += 0L
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Visualization.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Visualization.scala
index 6ecc3ca..2a7b7a8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Visualization.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Visualization.scala
@@ -198,9 +198,9 @@ object Visualization {
case None => null
case Some(map) => map.asInstanceOf[Map[String, Any]]
}
- require(conf != null)
+ require(conf != null, "Invalid json")
- require(conf.contains("nodes"))
+ require(conf.contains("nodes"), "Invalid json")
val nodes = conf("nodes").asInstanceOf[List[Any]]
// default attributes of node
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 60b80f2..30e57c5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -171,8 +171,8 @@ abstract class BaseModule {
batchEndCallback: Option[BatchEndCallback] = None,
scoreEndCallback: Option[BatchEndCallback] = None,
reset: Boolean = true, epoch: Int = 0): EvalMetric = {
- require(evalData != null && evalMetric != null)
- require(binded && paramsInitialized)
+ require(evalData != null && evalMetric != null, "evalData and evalMetric must be defined")
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
if (reset) {
evalData.reset()
@@ -216,7 +216,7 @@ abstract class BaseModule {
*/
def predictEveryBatch(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
: IndexedSeq[IndexedSeq[NDArray]] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
if (reset) {
evalData.reset()
}
@@ -234,7 +234,7 @@ abstract class BaseModule {
}
def predict(batch: DataBatch): IndexedSeq[NDArray] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
forward(batch, isTrain = Option(false))
val pad = batch.pad
getOutputsMerged().map(out => {
@@ -260,7 +260,8 @@ abstract class BaseModule {
val numOutputs = outputBatches.head.size
outputBatches.foreach(out =>
require(out.size == numOutputs,
- "Cannot merge batches, as num of outputs is not the same in mini-batches." +
+ s"Cannot merge batches, as num of outputs $numOutputs is not the same " +
+ s"in mini-batches (${out.size}). " +
"Maybe bucketing is used?")
)
val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
@@ -395,8 +396,8 @@ abstract class BaseModule {
*/
def fit(trainData: DataIter, evalData: Option[DataIter] = None, numEpoch: Int = 1,
fitParams: FitParams = new FitParams): Unit = {
- require(fitParams != null)
- require(numEpoch > 0, "please specify number of epochs")
+ require(fitParams != null, "Undefined fitParams")
+ require(numEpoch > 0, s"Invalid number of epochs $numEpoch")
import org.apache.mxnet.DataDesc._
bind(dataShapes = trainData.provideData, labelShapes = Option(trainData.provideLabel),
forTraining = true, forceRebind = fitParams.forceRebind)
@@ -604,7 +605,7 @@ class FitParams {
// The performance measure used to display during training.
def setEvalMetric(evalMetric: EvalMetric): FitParams = {
- require(evalMetric != null)
+ require(evalMetric != null, "Undefined evalMetric")
this.evalMetric = evalMetric
this
}
@@ -623,13 +624,13 @@ class FitParams {
}
def setKVStore(kvStore: String): FitParams = {
- require(kvStore != null)
+ require(kvStore != null, "Undefined kvStore")
this.kvstore = kvstore
this
}
def setOptimizer(optimizer: Optimizer): FitParams = {
- require(optimizer != null)
+ require(optimizer != null, "Undefined optimizer")
this.optimizer = optimizer
this
}
@@ -649,7 +650,7 @@ class FitParams {
// Will be called to initialize the module parameters if not already initialized.
def setInitializer(initializer: Initializer): FitParams = {
- require(initializer != null)
+ require(initializer != null, "Undefined initializer")
this.initializer = initializer
this
}
@@ -697,7 +698,7 @@ class FitParams {
// checkpoint saved at a previous training phase at epoch N,
// then we should specify this value as N+1.
def setBeginEpoch(beginEpoch: Int): FitParams = {
- require(beginEpoch >= 0)
+ require(beginEpoch >= 0, s"Invalid epoch $beginEpoch")
this.beginEpoch = beginEpoch
this
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
index 2823818..2262f5c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BucketingModule.scala
@@ -52,7 +52,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
}
private val workLoads = workLoadList.getOrElse(contexts.map(_ => 1f).toIndexedSeq)
- require(workLoads.size == contexts.length)
+ require(workLoads.size == contexts.length,
+ s"workloads size (${workLoads.size}) does not match number of contexts (${contexts.length})")
private val _buckets = scala.collection.mutable.Map[AnyRef, Module]()
private var _currModule: Module = null
@@ -84,7 +85,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
// Input/Output information
// A list of (name, shape) pairs specifying the data inputs to this module.
override def dataShapes: IndexedSeq[DataDesc] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this._currModule.dataShapes
}
@@ -95,13 +96,13 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* list `[]`.
*/
override def labelShapes: IndexedSeq[DataDesc] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this._currModule.labelShapes
}
// A list of (name, shape) pairs specifying the outputs of this module.
override def outputShapes: IndexedSeq[(String, Shape)] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this._currModule.outputShapes
}
@@ -111,7 +112,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* `NDArray`) mapping.
*/
override def getParams: (Map[String, NDArray], Map[String, NDArray]) = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this._currModule.paramsDirty = this.paramsDirty
val params = this._currModule.getParams
this.paramsDirty = false
@@ -220,8 +221,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
return
}
- require(sharedModule == None,
- "shared_module for BucketingModule is not supported")
+ require(sharedModule.isEmpty,
+ "sharedModule for BucketingModule is not supported")
this.forTraining = forTraining
this.inputsNeedGrad = inputsNeedGrad
@@ -276,7 +277,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
*/
override def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
if (optimizerInitialized && !forceInit) {
logger.warn("optimizer already initialized, ignoring ...")
} else {
@@ -294,7 +295,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
*/
def prepare(dataBatch: DataBatch): Unit = {
// perform bind if haven't done so
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
val bucketKey = dataBatch.bucketKey
val originalBucketKey = this._currBucketKey
this.switchBucket(bucketKey, dataBatch.provideData, Option(dataBatch.provideLabel))
@@ -308,7 +309,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* @param isTrain Default is `None`, which means `is_train` takes the value of `for_training`.
*/
override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this.switchBucket(dataBatch.bucketKey, dataBatch.provideData,
Option(dataBatch.provideLabel))
this._currModule.forward(dataBatch, isTrain)
@@ -321,14 +322,15 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* on outputs that are not a loss function.
*/
override def backward(outGrads: Array[NDArray] = null): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this._currModule.backward(outGrads)
}
// Update parameters according to the installed optimizer and the gradients computed
// in the previous forward-backward cycle.
override def update(): Unit = {
- require(binded && paramsInitialized && optimizerInitialized)
+ require(binded && paramsInitialized && optimizerInitialized,
+ "bind(), initParams() and initOptimizer() must be called first.")
this.paramsDirty = true
this._currModule.update()
}
@@ -341,7 +343,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* those `NDArray` might live on different devices.
*/
override def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this._currModule.getOutputs()
}
@@ -353,7 +355,7 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* The results will look like `[out1, out2]`
*/
override def getOutputsMerged(): IndexedSeq[NDArray] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this._currModule.getOutputsMerged()
}
@@ -365,7 +367,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* those `NDArray` might live on different devices.
*/
override def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(binded && paramsInitialized && inputsNeedGrad)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGrads() but inputsNeedGrad is false")
this._currModule.getInputGrads()
}
@@ -377,7 +380,8 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* The results will look like `[grad1, grad2]`
*/
override def getInputGradsMerged(): IndexedSeq[NDArray] = {
- require(binded && paramsInitialized && inputsNeedGrad)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGradsMerged() but inputsNeedGrad is false")
this._currModule.getInputGradsMerged()
}
@@ -387,18 +391,18 @@ class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[
* @param labels
*/
override def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
this._currModule.updateMetric(evalMetric, labels)
}
override def getSymbol: Symbol = {
- require(binded)
+ require(binded, "bind() must be called first.")
this._currModule.symbol
}
// Install monitor on all executors
override def installMonitor(monitor: Monitor): Unit = {
- require(binded)
+ require(binded, "bind() must be called first.")
for (mod <- this._buckets.values) mod.installMonitor(monitor)
}
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index 1494dc8..5c567fe 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -101,10 +101,10 @@ private object DataParallelExecutorGroup {
gradReq: String, argNames: IndexedSeq[String], paramNames: IndexedSeq[String],
fixedParamNames: Set[String], dataNames: Seq[String], inputsNeedGrad: Boolean)
: Map[String, String] = {
- require(argNames != null)
- require(paramNames != null)
- require(fixedParamNames != null)
- require(dataNames != null)
+ require(argNames != null, "Invalid argNames")
+ require(paramNames != null, "Invalid paramNames")
+ require(fixedParamNames != null, "Invalid fixedParamNames")
+ require(dataNames != null, "Invalid dataNames")
argNames.map(k => {
if (paramNames.contains(k)) {
(k, if (fixedParamNames.contains(k)) "null" else gradReq)
@@ -139,13 +139,13 @@ private object DataParallelExecutorGroup {
}
def setDataShapes(shapes: IndexedSeq[DataDesc]): Builder = {
- require(shapes != null)
+ require(shapes != null, "Invalid shapes")
this.dataShapes = shapes
this
}
def setDataShapesByName(shapes: IndexedSeq[(String, Shape)]): Builder = {
- require(shapes != null)
+ require(shapes != null, "Invalid shapes")
this.dataShapes = shapes.map { case (k, s) => new DataDesc(k, s) }
this
}
@@ -188,7 +188,7 @@ private object DataParallelExecutorGroup {
}
def setGradReq(gradReq: Map[String, String]): Builder = {
- require(dataShapes != null)
+ require(dataShapes != null, "dataShapes must be set first")
val gradReqTmp = mutable.HashMap.empty[String, String]
val dataNames = dataShapes.map(_.name)
for (k <- argNames) {
@@ -206,7 +206,7 @@ private object DataParallelExecutorGroup {
}
def setGradReq(gradReq: String): Builder = {
- require(dataShapes != null)
+ require(dataShapes != null, "dataShapes must be set first")
val dataNames = dataShapes.map(_.name)
this.gradReqs = Builder.convertGradReq(
gradReq, argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
@@ -214,7 +214,9 @@ private object DataParallelExecutorGroup {
}
def setGradReq(gradReq: Seq[(String, String)]): Builder = {
- require(gradReq.size == argNames.size)
+ require(gradReq.size == argNames.size,
+ s"provided number of gradReq (${gradReq.size}) does not match number of args " +
+ s"(${argNames.size})")
this.gradReqs = gradReq.toMap
this
}
@@ -276,8 +278,8 @@ class DataParallelExecutorGroup private[module](
fixedParamNames: Set[String] = Set.empty[String],
gradReq: Map[String, String] = null) {
- require(symbol != null)
- require(contexts != null)
+ require(symbol != null, "Undefined symbol")
+ require(contexts != null, "Undefined context")
private val argNames = symbol.listArguments()
private val auxNames = symbol.listAuxiliaryStates()
@@ -329,7 +331,7 @@ class DataParallelExecutorGroup private[module](
* the shapes for the input data or label.
*/
private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = {
- require(dataShapes.size > 0)
+ require(dataShapes.size > 0, "dataShapes must be non empty")
val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout)))
for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) {
@@ -341,7 +343,7 @@ class DataParallelExecutorGroup private[module](
s"but ${dataDesc.name} has shape ${dataDesc.shape}")
} else {
this.batchSize = batchSize
- require(this.workLoadList != null)
+ require(this.workLoadList != null, "Undefined workLoadList")
this.slices = ExecutorManager.splitInputSlice(this.batchSize, this.workLoadList)
}
}
@@ -489,9 +491,9 @@ class DataParallelExecutorGroup private[module](
DataParallelExecutorGroup.loadData(dataBatch, dataArrays, dataLayouts)
val isTrainOpt = isTrain.getOrElse(this.forTraining)
labelArrays.foreach(labels => {
- require(!isTrainOpt || dataBatch.label != null)
+ require(!isTrainOpt || dataBatch.label != null, "label must be defined if in training phase")
if (dataBatch.label != null) {
- require(labelLayouts != null)
+ require(labelLayouts != null, "label layouts are undefined")
DataParallelExecutorGroup.loadLabel(dataBatch, labels, labelLayouts)
}
})
@@ -541,7 +543,7 @@ class DataParallelExecutorGroup private[module](
* those `NDArray` might live on different devices.
*/
def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(inputsNeedGrad)
+ require(inputsNeedGrad, "Cannot get InputGrads when inputsNeedGrad is set to false")
inputGradArrays.map(_.toIndexedSeq)
}
@@ -632,13 +634,17 @@ class DataParallelExecutorGroup private[module](
= dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
- require(argShapes != null, "shape inference failed")
+ require(argShapes != null, "Shape inference failed. " +
+ s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
+ s"and aux states ${symbol.listAuxiliaryStates()}")
val inputTypesGot = inputTypes.getOrElse(inputShapes.map { case (k, v) =>
(k, Base.MX_REAL_TYPE)
})
val (argTypes, _, auxTypes) = symbol.inferType(inputTypesGot)
- require(argTypes != null, "type inference failed")
+ require(argTypes != null, "Type inference failed. " +
+ s"Known types are $inputTypesGot for symbol arguments ${symbol.listArguments()} " +
+ s"and aux states ${symbol.listAuxiliaryStates()}")
val argArrays = ArrayBuffer.empty[NDArray]
val gradArrayMap = mutable.HashMap.empty[String, NDArray]
@@ -659,8 +665,12 @@ class DataParallelExecutorGroup private[module](
argArr
case Some(sharedExecInst) =>
val argArr = sharedExecInst.argDict(name)
- require(argArr.shape == argShapes(j))
- require(argArr.dtype == argTypes(j))
+ require(argArr.shape == argShapes(j),
+ s"Shape ${argArr.shape} of argument $name does not match " +
+ s"inferred shape ${argShapes(j)}")
+ require(argArr.dtype == argTypes(j),
+ s"Type ${argArr.dtype} of argument $name does not match " +
+ s"inferred type ${argTypes(j)}")
if (gradReqRun(name) != "null") {
gradArrayMap.put(name, sharedExecInst.gradDict(name))
}
@@ -687,8 +697,12 @@ class DataParallelExecutorGroup private[module](
}.toArray
case Some(sharedExecInst) =>
for ((arr, j) <- sharedExecInst.auxArrays.zipWithIndex) {
- require(auxShapes(j) == arr.shape)
- require(auxTypes(j) == arr.dtype)
+ require(auxShapes(j) == arr.shape,
+ s"Shape ${arr.shape} of aux variable ${auxNames(j)} does not match " +
+ s"inferred shape ${auxShapes(j)}")
+ require(auxTypes(j) == arr.dtype,
+ s"Type ${arr.dtype} of aux variable ${auxNames(j)} does not match " +
+ s"inferred type ${auxTypes(j)}")
}
sharedExecInst.auxArrays.map(identity)
}
@@ -729,7 +743,8 @@ class DataParallelExecutorGroup private[module](
val argArr = sharedDataArrays(name)
if (argArr.shape.product >= argShape.product) {
// nice, we can directly re-use this data blob
- require(argArr.dtype == argType)
+ require(argArr.dtype == argType,
+ s"Type ${argArr.dtype} of argument $name does not match inferred type $argType")
argArr.reshape(argShape)
} else {
DataParallelExecutorGroup.logger.warn(s"bucketing: data $name has a shape $argShape," +
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index 9cf64b1..fec1ba0 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -45,11 +45,12 @@ class Module(symbolVar: Symbol,
fixedParamNames: Option[Set[String]] = None) extends BaseModule {
private val logger = LoggerFactory.getLogger(classOf[Module])
- require(symbolVar != null)
+ require(symbolVar != null, "Undefined symbol")
this.symbol = symbolVar
private val workLoads = workLoadList.getOrElse(contexts.map(_ => 1f).toIndexedSeq)
- require(workLoads.size == contexts.length)
+ require(workLoads.size == contexts.length,
+ s"workloads size (${workLoads.size}) does not match number of contexts (${contexts.length})")
private val labelNameList = if (labelNames == null) IndexedSeq.empty[String] else labelNames
@@ -71,17 +72,17 @@ class Module(symbolVar: Symbol,
private var labelShapesVar: Option[IndexedSeq[DataDesc]] = None
override def dataShapes: IndexedSeq[DataDesc] = {
- require(binded)
+ require(binded, "bind() must be called first.")
dataShapesVar
}
override def labelShapes: IndexedSeq[DataDesc] = {
- require(binded)
+ require(binded, "bind() must be called first.")
labelShapesVar.orNull
}
override def outputShapes: IndexedSeq[(String, Shape)] = {
- require(binded)
+ require(binded, "bind() must be called first.")
execGroup.getOutputShapes
}
@@ -93,7 +94,7 @@ class Module(symbolVar: Symbol,
* `NDArray`) mapping.
*/
override def getParams: (Map[String, NDArray], Map[String, NDArray]) = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
if (paramsDirty) {
syncParamsFromDevices()
}
@@ -253,7 +254,7 @@ class Module(symbolVar: Symbol,
this.binded = true
if (!forTraining) {
- require(!inputsNeedGrad)
+ require(!inputsNeedGrad, "Invalid inputsNeedGrad (cannot be true if not forTraining)")
} else {
// this is not True, as some module might not contains a loss function
// that consumes the labels
@@ -265,7 +266,8 @@ class Module(symbolVar: Symbol,
val sharedGroup =
sharedModule.map(sharedModuleInst => {
- require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized)
+ require(sharedModuleInst.binded && sharedModuleInst.paramsInitialized,
+ "bind() and initParams() must be called first on shared module.")
sharedModuleInst.execGroup
})
@@ -338,7 +340,7 @@ class Module(symbolVar: Symbol,
*/
def reshape(dataShapes: IndexedSeq[DataDesc],
labelShapes: Option[IndexedSeq[DataDesc]] = None): Unit = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
val (tdataShapes, tlabelShapes) = this._parseDataDesc(
this.dataNames, this.labelNames, dataShapes, labelShapes)
this.dataShapesVar = tdataShapes
@@ -357,7 +359,7 @@ class Module(symbolVar: Symbol,
*/
def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
if (optimizerInitialized && !forceInit) {
logger.warn("optimizer already initialized, ignoring ...")
} else {
@@ -414,7 +416,8 @@ class Module(symbolVar: Symbol,
* @param sharedModule
*/
def borrowOptimizer(sharedModule: Module): Unit = {
- require(sharedModule.optimizerInitialized)
+ require(sharedModule.optimizerInitialized,
+ "initOptimizer() must be called first for shared module")
optimizer = sharedModule.optimizer
kvstore = sharedModule.kvstore
updateOnKVStore = sharedModule.updateOnKVStore
@@ -428,7 +431,7 @@ class Module(symbolVar: Symbol,
* @param isTrain Default is `None`, which means `is_train` takes the value of `for_training`.
*/
def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
val currDataShapes = this.dataShapes.map(_.shape)
val newDataShapes = dataBatch.data.map(_.shape)
if (currDataShapes != newDataShapes) {
@@ -459,20 +462,21 @@ class Module(symbolVar: Symbol,
* on outputs that are not a loss function.
*/
def backward(outGrads: Array[NDArray] = null): Unit = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
execGroup.backward(outGrads)
}
// Update parameters according to the installed optimizer and the gradients computed
// in the previous forward-backward batch.
def update(): Unit = {
- require(binded && paramsInitialized && optimizerInitialized)
+ require(binded && paramsInitialized && optimizerInitialized,
+ "bind(), initParams() and initOptimizer() must be called first.")
paramsDirty = true
if (updateOnKVStore) {
Model.updateParamsOnKVStore(execGroup.paramArrays,
execGroup.gradArrays, kvstore, execGroup.paramNames)
} else {
- require(updater != None)
+ require(updater.isDefined, "Undefined updater")
Model.updateParams(execGroup.paramArrays,
execGroup.gradArrays, updater.orNull, contexts.length, execGroup.paramNames, kvstore)
}
@@ -486,7 +490,7 @@ class Module(symbolVar: Symbol,
* those `NDArray` might live on different devices.
*/
def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
execGroup.getOutputs()
}
@@ -498,7 +502,7 @@ class Module(symbolVar: Symbol,
* The results will look like `[out1, out2]`
*/
def getOutputsMerged(): IndexedSeq[NDArray] = {
- require(binded && paramsInitialized)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
execGroup.getOutputsMerged()
}
@@ -510,7 +514,8 @@ class Module(symbolVar: Symbol,
* those `NDArray` might live on different devices.
*/
def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(binded && paramsInitialized && inputsNeedGrad)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGrads() but inputsNeedGrad is false")
execGroup.getInputGrads()
}
@@ -522,7 +527,8 @@ class Module(symbolVar: Symbol,
* The results will look like `[grad1, grad2]`
*/
def getInputGradsMerged(): IndexedSeq[NDArray] = {
- require(binded && paramsInitialized && inputsNeedGrad)
+ require(binded && paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGradsMerged() but inputsNeedGrad is false")
execGroup.getInputGradsMerged()
}
@@ -544,7 +550,7 @@ class Module(symbolVar: Symbol,
// Install monitor on all executors
def installMonitor(monitor: Monitor): Unit = {
- require(binded)
+ require(binded, "bind() must be called first.")
execGroup.installMonitor(monitor)
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
index f376b54..e75550a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/SequentialModule.scala
@@ -96,7 +96,7 @@ class SequentialModule extends BaseModule {
* @return The data shapes of the first module is the data shape of a SequentialModule.
*/
override def dataShapes: IndexedSeq[DataDesc] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this.modules.head.dataShapes
}
@@ -107,7 +107,7 @@ class SequentialModule extends BaseModule {
* training (in this case, label information is not available).
*/
override def labelShapes: IndexedSeq[DataDesc] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this.labelShapesVar.orNull
}
@@ -117,7 +117,7 @@ class SequentialModule extends BaseModule {
* module is the output shape of a SequentialModule.
*/
override def outputShapes: IndexedSeq[(String, Shape)] = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this.modules.reverse.head.outputShapes
}
@@ -127,7 +127,7 @@ class SequentialModule extends BaseModule {
* each a Map of name to parameters (in NDArray) mapping.
*/
override def getParams: (Map[String, NDArray], Map[String, NDArray]) = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
((Map[String, NDArray](), Map[String, NDArray]()) /: this.modules){ (result, module) =>
val (arg, aux) = module.getParams
(result._1 ++ arg, result._2 ++ aux)
@@ -220,7 +220,7 @@ class SequentialModule extends BaseModule {
}
if (inputsNeedGrad) {
- require(forTraining == true)
+ require(forTraining, "inputsNeedGrad can be set only for training")
}
require(sharedModule == None, "Shared module is not supported")
@@ -246,7 +246,8 @@ class SequentialModule extends BaseModule {
val myInputsNeedGrad = if (inputsNeedGrad || (forTraining && iLayer > 0)) true else false
if (meta.contains(META_AUTO_WIRING) && meta(META_AUTO_WIRING)) {
val dataNames = module.dataNames
- require(dataNames.length == myDataShapes.length)
+ require(dataNames.length == myDataShapes.length,
+ s"dataNames $dataNames and dataShapes $myDataShapes do not match")
myDataShapes = dataNames.zip(myDataShapes).map { case (newName, dataDes) =>
DataDesc(newName, dataDes.shape)
}
@@ -276,7 +277,7 @@ class SequentialModule extends BaseModule {
*/
override def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
if (optimizerInitialized && !forceInit) {
logger.warn("optimizer already initialized, ignoring ...")
} else {
@@ -293,7 +294,7 @@ class SequentialModule extends BaseModule {
* @param isTrain Default is `None`, which means `isTrain` takes the value of `forTraining`.
*/
override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
var data = dataBatch
for ((module, iLayer) <- this.modules.zipWithIndex) {
@@ -304,7 +305,8 @@ class SequentialModule extends BaseModule {
// need to update this, in case the internal module is using bucketing
// or whatever
val dataNames = module.outputShapes.map(_._1)
- require(dataNames.length == data.data.length)
+ require(dataNames.length == data.data.length,
+ s"dataNames $dataNames do not match with number of arrays in batch (${data.data.length})")
var provideData = ListMap[String, Shape]()
for ((name, x) <- dataNames.zip(out.head)) {
provideData += name -> x.shape
@@ -322,7 +324,7 @@ class SequentialModule extends BaseModule {
* on outputs that are not a loss function.
*/
override def backward(outGrads: Array[NDArray] = null): Unit = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
var grad = outGrads
for ((module, iLayer) <- this.modules.zipWithIndex.reverse) {
module.backward(outGrads = grad)
@@ -335,7 +337,8 @@ class SequentialModule extends BaseModule {
// Update parameters according to the installed optimizer and the gradients computed
// in the previous forward-backward batch.
override def update(): Unit = {
- require(this.binded && this.paramsInitialized && this.optimizerInitialized)
+ require(this.binded && this.paramsInitialized && this.optimizerInitialized,
+ "bind(), initParams() and initOptimizer() must be called first.")
this.modules.foreach(_.update())
}
@@ -347,7 +350,7 @@ class SequentialModule extends BaseModule {
* those `NDArray` might live on different devices.
*/
def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
this.modules.reverse.head.getOutputs()
}
@@ -359,7 +362,7 @@ class SequentialModule extends BaseModule {
* The results will look like `[out1, out2]`
*/
def getOutputsMerged(): IndexedSeq[NDArray] = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
this.modules.reverse.head.getOutputsMerged()
}
@@ -371,7 +374,8 @@ class SequentialModule extends BaseModule {
* those `NDArray` might live on different devices.
*/
def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
- require(this.binded && this.paramsInitialized && inputsNeedGrad)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGrads() but inputsNeedGrad is false")
this.modules.head.getInputGrads()
}
@@ -383,7 +387,8 @@ class SequentialModule extends BaseModule {
* The results will look like `[grad1, grad2]`
*/
def getInputGradsMerged(): IndexedSeq[NDArray] = {
- require(this.binded && this.paramsInitialized && inputsNeedGrad)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
+ require(inputsNeedGrad, "Call to getInputGradsMerged() but inputsNeedGrad is false")
this.modules.head.getInputGradsMerged()
}
@@ -393,7 +398,7 @@ class SequentialModule extends BaseModule {
* @param labels
*/
def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
- require(this.binded && this.paramsInitialized)
+ require(this.binded && this.paramsInitialized, "bind() and initParams() must be called first.")
for ((meta, module) <- this.metas.zip(this.modules)) {
if (meta.contains(META_TAKE_LABELS) && meta(META_TAKE_LABELS)) {
module.updateMetric(evalMetric, labels)
@@ -403,7 +408,7 @@ class SequentialModule extends BaseModule {
// Install monitor on all executors
def installMonitor(monitor: Monitor): Unit = {
- require(this.binded)
+ require(this.binded, "bind() must be called first.")
this.modules.foreach(_.installMonitor(monitor))
}
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/DCASGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/DCASGD.scala
index 6b5053b..af804a5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/DCASGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/DCASGD.scala
@@ -71,7 +71,8 @@ class DCASGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
mon *= this.momentum
mon += monUpdated
} else {
- require(this.momentum == 0)
+ require(this.momentum == 0f,
+ s"momentum should be zero when state is provided.")
mon = monUpdated
}
previousWeight.set(weight)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/NAG.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/NAG.scala
index 47fe62d..5ed8954 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/NAG.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/NAG.scala
@@ -76,7 +76,8 @@ class NAG(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
resdGrad += momentum * mom
weight += -lr * resdGrad
} else {
- require(momentum == 0f)
+ require(momentum == 0f,
+ s"momentum should be zero when state is provided.")
// adder = -lr * (resdGrad + this.wd * weight)
// we write in this way to get rid of memory leak
val adder = this.wd * weight
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index e228e72..e20b433 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -73,7 +73,8 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
weight += mom
adder.dispose()
} else {
- require(momentum == 0f)
+ require(momentum == 0f,
+ s"momentum should be zero when state is provided.")
// adder = -lr * (resdGrad + this.wd * weight)
// we write in this way to get rid of memory leak
val adder = this.wd * weight
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
index aef4468..adeb33d 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Classifier.scala
@@ -146,11 +146,11 @@ class Classifier(modelPathPrefix: String,
private[infer] def getSynsetFilePath(modelPathPrefix: String): String = {
val dirPath = modelPathPrefix.substring(0, 1 + modelPathPrefix.lastIndexOf(File.separator))
val d = new File(dirPath)
- require(d.exists && d.isDirectory, "directory: %s not found".format(dirPath))
+ require(d.exists && d.isDirectory, s"directory: $dirPath not found")
val s = new File(dirPath + "synset.txt")
- require(s.exists() && s.isFile, "File synset.txt should exist inside modelPath: %s".format
- (dirPath + "synset.txt"))
+ require(s.exists() && s.isFile,
+ s"File synset.txt should exist inside modelPath: ${dirPath + "synset.txt"}")
s.getCanonicalPath
}
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/MXNetHandler.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/MXNetHandler.scala
index 0038045..d2bed3a 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/MXNetHandler.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/MXNetHandler.scala
@@ -39,8 +39,7 @@ private[infer] object MXNetHandlerType extends Enumeration {
private[infer] class MXNetThreadPoolHandler(numThreads: Int = 1)
extends MXNetHandler {
- require(numThreads > 0, "numThreads should be a positive number, you passed:%d".
- format(numThreads))
+ require(numThreads > 0, s"Invalid numThreads $numThreads")
private val logger = LoggerFactory.getLogger(classOf[MXNetThreadPoolHandler])
private var threadCount: Int = 0
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index 2a4f030..3987c64 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -113,13 +113,14 @@ class Predictor(modelPathPrefix: String,
override def predict(input: IndexedSeq[Array[Float]])
: IndexedSeq[Array[Float]] = {
- require(input.length == inputDescriptors.length, "number of inputs provided: %d" +
- " does not match number of inputs in inputDescriptors: %d".format(input.length,
- inputDescriptors.length))
+ require(input.length == inputDescriptors.length,
+ s"number of inputs provided: ${input.length} does not match number of inputs " +
+ s"in inputDescriptors: ${inputDescriptors.length}")
for((i, d) <- input.zip(inputDescriptors)) {
- require (i.length == d.shape.product/batchSize, "number of elements:" +
- " %d in the input does not match the shape:%s".format( i.length, d.shape.toString()))
+ require(i.length == d.shape.product / batchSize,
+ s"number of elements:${i.length} in the input does not match the shape:" +
+ s"${d.shape.toString()}")
}
var inputND: ListBuffer[NDArray] = ListBuffer.empty[NDArray]
@@ -163,17 +164,17 @@ class Predictor(modelPathPrefix: String,
*/
override def predictWithNDArray(inputBatch: IndexedSeq[NDArray]): IndexedSeq[NDArray] = {
- require(inputBatch.length == inputDescriptors.length, "number of inputs provided: %d" +
- " do not match number of inputs in inputDescriptors: %d".format(inputBatch.length,
- inputDescriptors.length))
+ require(inputBatch.length == inputDescriptors.length,
+ s"number of inputs provided: ${inputBatch.length} do not match number " +
+ s"of inputs in inputDescriptors: ${inputDescriptors.length}")
// Shape validation, remove this when backend throws better error messages.
for((i, d) <- inputBatch.zip(iDescriptors)) {
require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
"All inputs should be of same batch size")
require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
- "Input Data Shape: %s should match the inputDescriptor shape: %s except batchSize".format(
- i.shape.toString, d.shape.toString))
+ s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
+ s"shape: ${d.shape} except batchSize")
}
val inputBatchSize = inputBatch(0).shape(batchIndex)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index b07e6f9..d0ebe5b 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -67,8 +67,10 @@ private[mxnet] object CToScalaUtils {
// Optional Field
if (commaRemoved.length >= 3) {
// arg: Type, optional, default = Null
- require(commaRemoved(1).equals("optional"))
- require(commaRemoved(2).startsWith("default="))
+ require(commaRemoved(1).equals("optional"),
+ s"""expected "optional" got ${commaRemoved(1)}""")
+ require(commaRemoved(2).startsWith("default="),
+ s"""expected "default=..." got ${commaRemoved(2)}""")
(typeConversion(commaRemoved(0), argType, argName, returnType), true)
} else if (commaRemoved.length == 2 || commaRemoved.length == 1) {
val tempType = typeConversion(commaRemoved(0), argType, argName, returnType)
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNDArray.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNDArray.scala
index da078fc..a18c47d 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNDArray.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNDArray.scala
@@ -24,7 +24,7 @@ import org.apache.mxnet.NDArray
* @author Yizhi Liu
*/
class MXNDArray(@transient private var ndArray: NDArray) extends Serializable {
- require(ndArray != null)
+ require(ndArray != null, "Undefined ndArray")
private val arrayBytes: Array[Byte] = ndArray.serialize()
def get: NDArray = {
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetModel.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetModel.scala
index 2ecc2c8..2c4c8fe 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetModel.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/MXNetModel.scala
@@ -32,9 +32,9 @@ class MXNetModel private[mxnet](
private val batchSize: Int,
private val dataName: String = "data",
private val labelName: String = "label") extends Serializable {
- require(model != null, "try to serialize an empty FeedForward model")
- require(dimension != null, "unknown dimension")
- require(batchSize > 0, s"invalid batchSize: $batchSize")
+ require(model != null, "Undefined model")
+ require(dimension != null, "Undefined dimension")
+ require(batchSize > 0, s"Invalid batchSize: $batchSize")
val serializedModel = model.serialize()
/**