Posted to commits@mxnet.apache.org by la...@apache.org on 2021/04/02 17:21:40 UTC

[incubator-mxnet] branch master updated: [FFI] fix masked_softmax (#20114)

This is an automated email from the ASF dual-hosted git repository.

lausen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 78d7b1d  [FFI] fix masked_softmax (#20114)
78d7b1d is described below

commit 78d7b1d209c8b5427071d75a9cb8618cf585226b
Author: barry-jin <69...@users.noreply.github.com>
AuthorDate: Fri Apr 2 10:19:06 2021 -0700

    [FFI] fix masked_softmax (#20114)
---
 python/mxnet/ndarray/numpy_extension/_op.py        |  38 ++------
 python/mxnet/numpy_extension/_op.py                |  18 ++--
 src/api/operator/numpy_extension/npx_softmax_op.cc | 100 ++++++++++++++++++++-
 src/operator/nn/softmax-inl.h                      |  10 ++-
 tests/python/unittest/test_numpy_op.py             |  82 ++++++++++++++---
 5 files changed, 190 insertions(+), 58 deletions(-)
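
For reference, the user-facing signature after this change takes a boolean
`normalize` flag in place of the removed `dtype` argument. A minimal usage
sketch follows; the mask values are an assumption chosen to be consistent
with the output shown in the masked_softmax docstring below, and the boolean
mask dtype is likewise an assumption, not part of the patch:

    from mxnet import np, npx
    npx.set_np()

    data = np.arange(10).reshape((2, 5))
    mask = np.array([[1, 0, 1, 0, 1],
                     [1, 0, 1, 0, 1]], dtype=bool)
    # columns where mask is 0 come back as exactly 0 in the output
    out = npx.masked_softmax(data, mask, axis=0, temperature=1.0, normalize=True)
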

diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py
index 718022d..1ff9c2d 100644
--- a/python/mxnet/ndarray/numpy_extension/_op.py
+++ b/python/mxnet/ndarray/numpy_extension/_op.py
@@ -19,7 +19,6 @@
 used in Gluon dispatched by F=ndarray module."""
 
 import numpy as _np
-from .. import numpy as np  # pylint: disable=reimported
 from .._internal import NDArrayBase
 from . import _api_internal
 from ...util import set_module
@@ -134,7 +133,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False,
 
 # pylint: disable=too-many-arguments
 @set_module('mxnet.ndarray.numpy_extension')
-def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
     r"""Applies the softmax function masking elements according to the mask provided
 
     Parameters
@@ -147,9 +146,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
         The axis along which to compute softmax.
     temperature : double or None, optional, default=None
         Temperature parameter in softmax
-    dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
-        DType of the output in case this can't be inferred. Defaults to
-        the same as input's dtype if not defined (dtype=None).
     normalize : boolean or None, optional, default=1
         Whether to normalize input data x: x = x - max(x)
 
@@ -167,22 +163,15 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
     >>> data = np.arange(10).reshape((2, 5))
     >>> npx.masked_softmax(data, mask, axis=0)
     array([[0.00669285, 0.        , 0.00669285, 0.        , 0.00669285],
-        [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
+           [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
     """
-    if mask is not None:
-        neg = -1e18
-        if _np.dtype(dtype) == _np.float16:
-            neg = -1e4
-        data = np.where(mask, data, neg)
-        logits = (softmax(data, axis=axis) / temperature) * mask
-    else:
-        logits = softmax(data, axis=axis) / temperature
-    return logits
+    assert data is not None and mask is not None, "Missing input data or mask"
+    return _api_internal.masked_softmax(data, mask, axis, temperature, normalize)
 
 
 # pylint: disable=too-many-arguments
 @set_module('mxnet.ndarray.numpy_extension')
-def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
     r"""Computes the masked log softmax of the input.
     This is equivalent to computing masked softmax followed by log.
 
@@ -196,9 +185,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
         The axis along which to compute softmax.
     temperature : double or None, optional, default=None
         Temperature parameter in softmax
-    dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
-        DType of the output in case this can't be inferred. Defaults to
-        the same as input's dtype if not defined (dtype=None).
     normalize : boolean or None, optional, default=1
         Whether to normalize input data x: x = x - max(x)
 
@@ -216,18 +202,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
     >>> data = np.arange(10).reshape((2, 5))
     >>> npx.masked_log_softmax(data, mask, axis=0)
     array([[-5.0067153 ,        -inf, -5.0067153 ,        -inf, -5.0067153 ],
-       [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
+           [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
     """
-    if mask is not None:
-        neg = -1e18
-        inf = -_np.inf
-        if _np.dtype(dtype) == _np.float16:
-            neg = -1e4
-        data = np.where(mask, data, neg)
-        logits = np.where(mask, log_softmax(data, axis=axis) / temperature, inf)
-    else:
-        logits = log_softmax(data, axis=axis) / temperature
-    return logits
+    assert data is not None and mask is not None, "Missing input data or mask"
+    return _api_internal.masked_log_softmax(data, mask, axis, temperature, normalize)
 
 
 # pylint: disable=too-many-arguments, unused-argument
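
The removed pure-Python fallback above doubles as documentation of the
operator's semantics. A NumPy-only reference that mirrors the deleted code
(and the np_masked_softmax helper added to the tests further down) might look
like the sketch below; it is illustrative only and not part of the patch:

    import numpy as _np

    def masked_softmax_ref(data, mask, axis=-1, temperature=1.0):
        # Fill masked positions with a large negative value so they vanish
        # after exponentiation; float16 needs a smaller fill to avoid overflow.
        neg = -1e4 if data.dtype == _np.float16 else -1e18
        filled = _np.where(mask, data, neg)
        # normalize=True behaviour: subtract the per-slice max for stability
        shifted = filled - filled.max(axis=axis, keepdims=True)
        probs = _np.exp(shifted)
        probs = probs / probs.sum(axis=axis, keepdims=True)
        # Scale by temperature and zero out masked outputs, as the old code did.
        return (probs / temperature) * mask
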
diff --git a/python/mxnet/numpy_extension/_op.py b/python/mxnet/numpy_extension/_op.py
index b7d75ff..124eb00 100644
--- a/python/mxnet/numpy_extension/_op.py
+++ b/python/mxnet/numpy_extension/_op.py
@@ -118,7 +118,7 @@ def log_softmax(data, axis=-1, length=None, temperature=None, use_length=False,
 
 # pylint: disable=too-many-arguments
 @set_module('mxnet.numpy_extension')
-def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+def masked_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
     r"""Applies the softmax function masking elements according to the mask provided
 
     Parameters
@@ -131,9 +131,6 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
         The axis along which to compute softmax.
     temperature : double or None, optional, default=None
         Temperature parameter in softmax
-    dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
-        DType of the output in case this can't be inferred. Defaults to
-        the same as input's dtype if not defined (dtype=None).
     normalize : boolean or None, optional, default=1
         Whether to normalize input data x: x = x - max(x)
 
@@ -151,15 +148,15 @@ def masked_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
     >>> data = np.arange(10).reshape((2, 5))
     >>> npx.masked_softmax(data, mask, axis=0)
     array([[0.00669285, 0.        , 0.00669285, 0.        , 0.00669285],
-        [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
+           [0.9933072 , 0.        , 0.9933072 , 0.        , 0.9933072 ]])
     """
     return _mx_nd_npx.masked_softmax(data, mask, axis=axis, temperature=temperature,
-                                     dtype=dtype)
+                                     normalize=normalize)
 
 
 # pylint: disable=too-many-arguments
 @set_module('mxnet.numpy_extension')
-def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
+def masked_log_softmax(data, mask, axis=-1, temperature=1.0, normalize=True):
     r"""Computes the masked log softmax of the input.
     This is equivalent to computing masked softmax followed by log.
 
@@ -173,9 +170,6 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
         The axis along which to compute softmax.
     temperature : double or None, optional, default=None
         Temperature parameter in softmax
-    dtype : {None, 'float16', 'float32', 'float64'},optional, default='None'
-        DType of the output in case this can't be inferred. Defaults to
-        the same as input's dtype if not defined (dtype=None).
     normalize : boolean or None, optional, default=1
         Whether to normalize input data x: x = x - max(x)
 
@@ -193,10 +187,10 @@ def masked_log_softmax(data, mask, axis=-1, temperature=1.0, dtype=None):
     >>> data = np.arange(10).reshape((2, 5))
     >>> npx.masked_log_softmax(data, mask, axis=0)
     array([[-5.0067153 ,        -inf, -5.0067153 ,        -inf, -5.0067153 ],
-       [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
+           [-0.00671535,        -inf, -0.00671535,        -inf, -0.00671535]])
     """
     return _mx_nd_npx.masked_log_softmax(data, mask, axis=axis, temperature=temperature,
-                                         dtype=dtype)
+                                         normalize=normalize)
 
 
 # pylint: disable=too-many-arguments, unused-argument
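
This frontend is a thin wrapper that forwards to the ndarray implementation
with the renamed `normalize` keyword. For the log variant, masked positions
come back as -inf rather than 0, as the docstring example above shows. A
NumPy-only reference for masked_log_softmax, analogous to the sketch after
the previous file and mirroring the np_masked_log_softmax test helper added
below, could be:

    import numpy as _np

    def masked_log_softmax_ref(data, mask, axis=-1, temperature=1.0):
        # Same masking trick as masked_softmax, but stay in log space.
        neg = -1e4 if data.dtype == _np.float16 else -1e18
        filled = _np.where(mask, data, neg)
        shifted = filled - filled.max(axis=axis, keepdims=True)
        log_probs = shifted - _np.log(_np.exp(shifted).sum(axis=axis, keepdims=True))
        # Masked positions are forced to -inf in the output.
        return _np.where(mask, log_probs / temperature, -_np.inf)
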
diff --git a/src/api/operator/numpy_extension/npx_softmax_op.cc b/src/api/operator/numpy_extension/npx_softmax_op.cc
index 641129e..6e934ed 100644
--- a/src/api/operator/numpy_extension/npx_softmax_op.cc
+++ b/src/api/operator/numpy_extension/npx_softmax_op.cc
@@ -53,15 +53,17 @@ MXNET_REGISTER_API("_npx.softmax")
   // parse axis
   if (args[args_size - 4].type_code() == kDLInt) {
     param.axis = args[args_size - 4].operator int();
-  } else {
+  } else if (args[args_size - 4].type_code() == kDLFloat) {
     param.axis = static_cast<int>(args[args_size - 4].operator double());
+  } else {
+    param.axis = -1;
   }
 
   // parse temperature
   if (args[args_size - 3].type_code() == kNull) {
     param.temperature = dmlc::nullopt;
   } else {
-    param.temperature = args[args_size - 3].operator int64_t();
+    param.temperature = args[args_size - 3].operator double();
   }
 
   // parse dtype
@@ -106,15 +108,17 @@ MXNET_REGISTER_API("_npx.log_softmax")
   // parse axis
   if (args[args_size - 4].type_code() == kDLInt) {
     param.axis = args[args_size - 4].operator int();
-  } else {
+  } else if (args[args_size - 4].type_code() == kDLFloat) {
     param.axis = static_cast<int>(args[args_size - 4].operator double());
+  } else {
+    param.axis = -1;
   }
 
   // parse temperature
   if (args[args_size - 3].type_code() == kNull) {
     param.temperature = dmlc::nullopt;
   } else {
-    param.temperature = args[args_size - 3].operator int64_t();
+    param.temperature = args[args_size - 3].operator double();
   }
 
   // parse dtype
@@ -133,4 +137,92 @@ MXNET_REGISTER_API("_npx.log_softmax")
   *ret = ndoutputs[0];
 });
 
+MXNET_REGISTER_API("_npx.masked_softmax")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  static const nnvm::Op* op = Op::Get("_npx_masked_softmax");
+  op::MaskedSoftmaxParam param;
+
+  // inputs
+  int num_inputs = 2;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // parse axis
+  if (args[2].type_code() == kDLInt) {
+    param.axis = args[2].operator int();
+  } else if (args[2].type_code() == kDLFloat) {
+    param.axis = static_cast<int>(args[2].operator double());
+  } else {
+    param.axis = -1;
+  }
+  // parse temperature
+  if (args[3].type_code() == kNull) {
+    param.temperature = dmlc::nullopt;
+  } else {
+    param.temperature = args[3].operator double();
+  }
+  // parse normalize
+  if (args[4].type_code() == kNull) {
+    param.normalize = true;
+  } else {
+    param.normalize = args[4].operator bool();
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::MaskedSoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
+MXNET_REGISTER_API("_npx.masked_log_softmax")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  nnvm::NodeAttrs attrs;
+  static const nnvm::Op* op = Op::Get("_npx_masked_log_softmax");
+  op::MaskedSoftmaxParam param;
+
+  // inputs
+  int num_inputs = 2;
+  std::vector<NDArray*> inputs;
+  inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    inputs.push_back(args[i].operator mxnet::NDArray*());
+  }
+  // parse axis
+  if (args[2].type_code() == kDLInt) {
+    param.axis = args[2].operator int();
+  } else if (args[2].type_code() == kDLFloat) {
+    param.axis = static_cast<int>(args[2].operator double());
+  } else {
+    param.axis = -1;
+  }
+  // parse temperature
+  if (args[3].type_code() == kNull) {
+    param.temperature = dmlc::nullopt;
+  } else {
+    param.temperature = args[3].operator double();
+  }
+  // parse normalize
+  if (args[4].type_code() == kNull) {
+    param.normalize = true;
+  } else {
+    param.normalize = args[4].operator bool();
+  }
+
+  attrs.parsed = param;
+  attrs.op = op;
+  SetAttrDict<op::MaskedSoftmaxParam>(&attrs);
+
+  int num_outputs = 0;
+  auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs.data(), &num_outputs, nullptr);
+  *ret = ndoutputs[0];
+});
+
 }  // namespace mxnet
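
Two details of the FFI changes are easy to miss. In the new masked handlers,
axis may arrive as an int or a float and is coerced to int, with None falling
back to -1; temperature falls back to dmlc::nullopt and normalize to true
when None is passed. In the pre-existing _npx.softmax and _npx.log_softmax
handlers, temperature is now read as a double instead of int64_t, so
fractional temperatures are no longer truncated before reaching the kernel.
A quick illustrative call (frontend names as in the Python files above):

    from mxnet import np, npx
    npx.set_np()

    data = np.random.uniform(size=(2, 5))
    # 0.5 now reaches the kernel as-is instead of being truncated to an integer
    out = npx.softmax(data, axis=-1, temperature=0.5)
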
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 7f64b74..3f037f9 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -1199,7 +1199,6 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
 struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
   int axis;
   dmlc::optional<double> temperature;
-  dmlc::optional<int> dtype;
   dmlc::optional<bool> normalize;
   DMLC_DECLARE_PARAMETER(MaskedSoftmaxParam) {
     DMLC_DECLARE_FIELD(axis).set_default(-1)
@@ -1210,6 +1209,15 @@ struct MaskedSoftmaxParam : public dmlc::Parameter<MaskedSoftmaxParam> {
     .set_default(dmlc::optional<bool>(true))
     .describe("Whether to normalize input data x: x = x - max(x)");
   }
+  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
+    std::ostringstream axis_s, temperature_s, normalize_s;
+    axis_s << axis;
+    temperature_s << temperature;
+    normalize_s << normalize;
+    (*dict)["axis"] = axis_s.str();
+    (*dict)["temperature"] = temperature_s.str();
+    (*dict)["normalize"] = normalize_s.str();
+  }
 };
 
 static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
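
The new SetAttrDict on MaskedSoftmaxParam is what allows the imperative FFI
call to be re-serialized into string attributes when it is recorded, for
example under hybridization. With the frontend defaults, the recorded
dictionary would look roughly like the line below; the exact rendering of the
optional fields follows dmlc::optional's stream output, so treat it as
illustrative:

    # hypothetical stringified attributes for masked_softmax with default arguments
    {"axis": "-1", "temperature": "1", "normalize": "1"}
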
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 6bea510..83ef73a 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1931,6 +1931,18 @@ def test_npx_batch_norm(shape, fix_gamma, cudnn_off, output_mean_var):
                     _test_batchnorm_impl(axis,
                         data_grad_req, gamma_grad_req, beta_grad_req)
 
+
+def np_softmax(x, axis=-1):
+    if (x.shape[axis] == 0):
+        return _np.sum(x, axis=axis, keepdims=True)
+    x = x - _np.max(x, axis=axis, keepdims=True)
+    x = _np.exp(x)
+    x /= _np.sum(x, axis=axis, keepdims=True)
+    return x
+
+def np_log_softmax(x, axis=-1):
+    return _np.log(np_softmax(x, axis))
+
 @use_np
 def test_npx_softmax():
     class TestSoftmax(HybridBlock):
@@ -1949,17 +1961,6 @@ def test_npx_softmax():
         def hybrid_forward(self, F, a):
             return F.npx.log_softmax(a, axis=axis)
 
-    def np_softmax(x, axis=-1):
-        if (x.shape[axis] == 0):
-            return _np.sum(x, axis=axis, keepdims=True)
-        x = x - _np.max(x, axis=axis, keepdims=True)
-        x = _np.exp(x)
-        x /= _np.sum(x, axis=axis, keepdims=True)
-        return x
-
-    def np_log_softmax(x, axis=-1):
-        return _np.log(np_softmax(x, axis))
-
     #(operator, function) tuples
     tested_ops = [(TestSoftmax, np_softmax),
                   (TestLogSoftmax, np_log_softmax)]
@@ -1988,6 +1989,65 @@ def test_npx_softmax():
                     assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)
 
 
+def np_masked_softmax(data, mask, axis=-1, temperature=1.0):
+    neg = -1e18
+    if data.dtype == _np.float16:
+        neg = -1e4
+    temp = _np.where(mask, data, neg)
+    result = (np_softmax(temp, axis=axis) / temperature) * mask
+    return result
+
+def np_masked_log_softmax(data, mask, axis=-1, temperature=1.0):
+    neg = -1e18
+    if data.dtype == _np.float16:
+        neg = -1e4
+    data = _np.where(mask, data, neg)
+    return _np.where(mask, np_log_softmax(data, axis=axis) / temperature, -_np.inf)
+
+@use_np
+@pytest.mark.parametrize('hybridize', [True, False])
+@pytest.mark.parametrize('shape', [(3, 0, 4), (0, 0)])
+def test_npx_masked_softmax(hybridize, shape):
+    class TestMaskedSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestMaskedSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a, mask):
+            return F.npx.masked_softmax(a, mask, axis=self._axis)
+
+    class TestMaskedLogSoftmax(HybridBlock):
+        def __init__(self, axis):
+            super(TestMaskedLogSoftmax, self).__init__()
+            self._axis = axis
+
+        def hybrid_forward(self, F, a, mask):
+            return F.npx.masked_log_softmax(a, mask, axis=self._axis)
+
+    #(operator, function) tuples
+    tested_ops = [(TestMaskedSoftmax, np_masked_softmax),
+                  (TestMaskedLogSoftmax, np_masked_log_softmax)]
+
+    # only testing 0-size shaped inputs here; other input cases have been tested in test_operator.py
+    for SoftmaxOp, softmax_function in tested_ops:
+        mx_a = np.random.uniform(size=shape)
+        mask = np.random.randint(0, 2, shape)
+        mx_a.attach_grad()
+        mask.attach_grad()
+        for axis in range(-len(shape), len(shape)):
+            test_softmax_op = SoftmaxOp(axis)
+            if hybridize:
+                test_softmax_op.hybridize()
+
+            with mx.autograd.record():
+                mx_out = test_softmax_op(mx_a, mask)
+
+            mx_out.wait_to_read()
+
+            np_out = softmax_function(mx_a.asnumpy(), mask.asnumpy(), axis)
+            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)
+
+
 @use_np
 def test_npi_boolean_assign():
     class TestBooleanAssignScalar(HybridBlock):