You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by sk...@apache.org on 2019/01/25 18:36:49 UTC

[incubator-mxnet] branch master updated: Added optional parameters to BilinearResize2D to do relative scaling (#13985)

This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 577275d  Added optional parameters to BilinearResize2D to do relative scaling (#13985)
577275d is described below

commit 577275dc5bc379cc885eba6fb5ed03e00d1e72c0
Author: Istvan Fehervari <go...@gmail.com>
AuthorDate: Fri Jan 25 10:36:33 2019 -0800

    Added optional parameters to BilinearResize2D to do relative scaling (#13985)
    
    * Added optional parameters to BilinearResize2D to do relative scaling
    
    * Removed unnecessary params in unit tests.
    
    * Fixed deprecated casting style
---
 src/operator/contrib/bilinear_resize-inl.h | 28 ++++++++++++++++++++++------
 tests/python/unittest/test_operator.py     |  5 +++++
 2 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
index ff3f794..5a653d8 100644
--- a/src/operator/contrib/bilinear_resize-inl.h
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -50,11 +50,17 @@ namespace op {
 struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
   int height;
   int width;
+  dmlc::optional<float> scale_height;  // optional factor applied to the input height
+  dmlc::optional<float> scale_width;   // optional factor applied to the input width
   DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
-    DMLC_DECLARE_FIELD(height).set_range(1, 10000)
-    .describe("output height (required)");
-    DMLC_DECLARE_FIELD(width).set_range(1, 10000)
-    .describe("output width (required)");
+    DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000)
+    .describe("output height (required, but ignored if scale_height is defined)");
+    DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000)
+    .describe("output width (required, but ignored if scale_width is defined)");
+    DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
+    .describe("sampling scale of the height (optional, ignores height if defined)");
+    DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
+    .describe("sampling scale of the width (optional, ignores width if defined)");
   }
 };
 
@@ -129,8 +135,18 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
   const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
   TShape dshape(in_shape->at(0));
   if (dshape.ndim() == 0) return false;
-  dshape[2] = param.height;
-  dshape[3] = param.width;
+  if (param.scale_height.has_value()) {  // relative height overrides param.height
+    dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
+  } else {
+    dshape[2] = param.height;
+  }
+
+  if (param.scale_width.has_value()) {  // relative width overrides param.width
+    dshape[3] = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
+  } else {
+    dshape[3] = param.width;
+  }
+
   out_shape->clear();
   out_shape->push_back(dshape);
   return true;
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 67aeddf..3f34ade 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6533,6 +6533,11 @@ def test_bilinear_resize_op():
         x = mx.nd.random.uniform(shape=shape)
         y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width)
         assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width))
+
+        x_scale = width / shape[-1]
+        y_scale = height / shape[-2]
+        y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale)
+        assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width))
     shape = (2, 2, 10, 10)
     check_bilinear_resize_op(shape, 5, 5)
     check_bilinear_resize_op(shape, 10, 10)