You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/07/30 12:34:11 UTC

[incubator-mxnet] branch master updated: [BACKPORT] [FEATURE] Add API to control denormalized computations (#20387)

This is an automated email from the ASF dual-hosted git repository.

akarbown pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 1155c9e  [BACKPORT] [FEATURE] Add API to control denormalized computations (#20387)
1155c9e is described below

commit 1155c9e99d142877307b6423013fe5b2c43cc4cc
Author: bgawrych <ba...@intel.com>
AuthorDate: Fri Jul 30 14:31:41 2021 +0200

    [BACKPORT] [FEATURE] Add API to control denormalized computations (#20387)
    
    * [1.x] Add API to control denormalized computations
    
    * Edit name and description
    
    * Add direct imports
    
    * Edit description
    
    Co-authored-by: Andrzej Kotłowski <An...@intel.com>
    
    * Sanity & review
    
    * Return previous state of the FTZ flag
    
    * Utilize Engine::PushSync
    
    * Disable FTZ for numpy_interoperability case
    
    * Update python/mxnet/util.py
    
    Co-authored-by: Sheng Zha <sz...@users.noreply.github.com>
    
    * Add required header & fix test
    
    * Fix macro expansion
    
    * Don't include x86intrin.h when compiling with MSVC
    
    * Update documentation
    
    Co-authored-by: Andrzej Kotłowski <An...@intel.com>
    Co-authored-by: Sheng Zha <sz...@users.noreply.github.com>
---
 include/mxnet/c_api.h                              | 11 ++++
 python/mxnet/base.py                               |  2 +
 python/mxnet/util.py                               | 24 +++++++++
 src/c_api/c_api.cc                                 | 63 ++++++++++++++++++++++
 .../python/unittest/test_numpy_interoperability.py |  8 ++-
 5 files changed, 106 insertions(+), 2 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index b68765e..977e3e0 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -273,6 +273,17 @@ MXNET_DLL int MXRandomSeed(int seed);
 MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
 
 /*!
+ * \brief Enable or disable flushing of denormalized floating-point values to zero.
+ * Currently this option is only supported in the CPU backend on x86 (SSE) processors.
+ * The Python frontend enables flushing denormalized values when `mxnet` is imported.
+ *
+ * \param value whether to enable flush-to-zero and, if supported, denormals-are-zero.
+ * \param prev_state output: flush-to-zero state before this call (false if unsupported).
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetFlushDenorms(bool value, bool* prev_state);
+
+/*!
  * \brief Notify the engine about a shutdown,
  *  This can help engine to print less messages into display.
  *
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 1f9f37d..2e8d4b4 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -311,6 +311,8 @@ __version__ = libinfo.__version__
 # library instance of mxnet
 _LIB = _load_lib()
 
+check_call(_LIB.MXSetFlushDenorms(ctypes.c_bool(True),  # flush denormals to zero by default
+                                  ctypes.byref(ctypes.c_bool())))  # previous state is unused
 # type definitions
 mx_int = ctypes.c_int
 mx_uint = ctypes.c_uint
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index cafff0f..ea75030 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -1200,3 +1200,27 @@ def get_rtc_compile_opts(ctx):
     arch_opt = "--gpu-architecture={}_{}".format("sm" if should_compile_to_SASS else "compute",
                                                  device_cc_as_used)
     return [arch_opt]
+
+def set_flush_denorms(value):
+    """Enable or disable flushing of denormalized values in CPU floating-point math.
+
+    Denormalized (subnormal) values are positive and negative numbers very close to 0
+    whose exponent is the smallest representable one. Flushing them to 0 can speed up
+    calculations in which such values occur, but it departs from strict IEEE 754
+    semantics, so disable it when full IEEE 754 compliance is required. This only has
+    an effect on CPUs that support flush-to-zero; MXNet enables it on import by default.
+
+    Parameters
+    ----------
+    value : bool
+        Desired state of flush-to-zero and denormals-are-zero in the MXCSR register.
+
+    Returns
+    -------
+    prev_state : bool
+        State of flush-to-zero in the MXCSR register before this call.
+        Always False on platforms without flush-to-zero support.
+    """
+    previous = ctypes.c_bool(False)
+    check_call(_LIB.MXSetFlushDenorms(ctypes.c_bool(value), ctypes.byref(previous)))
+    return previous.value
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index eac1944..c54cc0e 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -62,6 +62,23 @@
 #include "miniz.h"
 #include "nnvm/pass_functions.h"
 
+// FTZ only applies to SSE and AVX instructions.
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+    (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+#define SUPPORT_FTZ_DMZ 1
+#else
+#define SUPPORT_FTZ_DMZ 0
+#endif
+
+#if SUPPORT_FTZ_DMZ
+#include <immintrin.h>
+#include <xmmintrin.h>
+#endif
+#if SUPPORT_FTZ_DMZ && !defined(_MSC_VER)
+#include <x86intrin.h>
+#endif
+
+
 using namespace mxnet;
 
 // Internal function to get the information
@@ -1587,6 +1604,52 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) {
   API_END();
 }
 
+int MXSetFlushDenorms(bool value, bool* prev_state) {
+  // Toggles FTZ (and DAZ when the CPU supports it) from a CPU engine task.
+  API_BEGIN();
+  if (prev_state != nullptr) *prev_state = false;  // reported when FTZ is unsupported
+  #if SUPPORT_FTZ_DMZ
+    // Intel 64 and IA-32 Architectures Software Developer's Manual, Vol. 1,
+    // "Checking for the DAZ Flag in the MXCSR Register": FXSAVE into a zeroed,
+    // 16-byte aligned 512-byte area; DAZ is supported iff bit 6 of MXCSR_MASK
+    // (bytes 28-31) is set. This is a processor capability, so probe it once.
+    static const bool daz_supported = []() {
+      constexpr unsigned int mxcsr_mask_offset = 28;
+      constexpr unsigned int daz_flag_bit = 6;  // same bit as _MM_DENORMALS_ZERO_ON
+      alignas(16) char fxsave_area[512] = {};  // FXSAVE faults on unaligned memory
+      _fxsave(fxsave_area);
+      uint32_t mxcsr_mask = 0;
+      memcpy(&mxcsr_mask, fxsave_area + mxcsr_mask_offset, sizeof(mxcsr_mask));
+      return ((mxcsr_mask >> daz_flag_bit) & 0x1u) != 0;
+    }();
+
+    // MXCSR is per-thread state. NOTE(review): this task only updates the engine
+    // worker thread that runs it; confirm that is the intended scope.
+    Engine::Get()->PushSync(
+      [value, prev_state](RunContext) {
+        const unsigned int daz_state = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF;
+        const unsigned int ftz_state = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF;
+        if (prev_state != nullptr) {
+          *prev_state = (_MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON);
+        }
+        _MM_SET_FLUSH_ZERO_MODE(ftz_state);
+
+        // If the DAZ flag is not supported, then it is a reserved bit and attempting to write a 1
+        // to it will cause a general-protection exception (#GP)
+        if (daz_supported) {
+          _MM_SET_DENORMALS_ZERO_MODE(daz_state);
+        }
+      }, Context::CPU(), {}, {},
+      FnProperty::kNormal, 0, "SetFlushDenorms");
+
+    // Block until the task has run so *prev_state is populated before returning.
+    Engine::Get()->WaitForAll();
+
+  #endif
+
+  API_END();
+}
+
 int MXNotifyShutdown() {
   API_BEGIN();
   mxnet::op::custom::CustomOperator::Get()->Stop();
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 1fa7d52..1b8fe4d 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -25,7 +25,7 @@ import itertools
 import numpy as _np
 import unittest
 import pytest
-from mxnet import np
+from mxnet import np, util
 from mxnet.test_utils import assert_almost_equal
 from mxnet.test_utils import use_np
 from mxnet.test_utils import is_op_runnable
@@ -3341,7 +3341,11 @@ def test_np_array_function_protocol():
 @with_array_ufunc_protocol
 @pytest.mark.serial
 def test_np_array_ufunc_protocol():
-    check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
+    saved_flush_state = util.set_flush_denorms(False)  # FTZ/DAZ off for the NumPy comparison
+    try:
+        check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
+    finally:
+        util.set_flush_denorms(saved_flush_state)  # restore the previous setting
 
 
 @use_np