You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by ak...@apache.org on 2021/07/28 09:23:31 UTC
[incubator-mxnet] branch v1.x updated: [1.x] [FEATURE] Add API to
control denormalized computations (#20338)
This is an automated email from the ASF dual-hosted git repository.
akarbown pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new aeff388 [1.x] [FEATURE] Add API to control denormalized computations (#20338)
aeff388 is described below
commit aeff38880709b91fd1cb35a6d5a5e6c2c074d7ab
Author: bgawrych <ba...@intel.com>
AuthorDate: Wed Jul 28 11:21:18 2021 +0200
[1.x] [FEATURE] Add API to control denormalized computations (#20338)
* [1.x] Add API to control denormalized computations
* Edit name and description
* Add direct imports
* Edit description
Co-authored-by: Andrzej Kotłowski <An...@intel.com>
* Sanity & review
* Return previous state of the FTZ flag
* Utilize Engine::PushSync
* Disable FTZ for numpy_interoperability case
* Update python/mxnet/util.py
Co-authored-by: Sheng Zha <sz...@users.noreply.github.com>
* Add required header & fix test
* Fix macro expansion
* Don't include x86intrin.h when compiling with MSVC
* Update documentation
* Remove added empty line
Co-authored-by: Andrzej Kotłowski <An...@intel.com>
Co-authored-by: Sheng Zha <sz...@users.noreply.github.com>
---
include/mxnet/c_api.h | 11 ++++
python/mxnet/base.py | 2 +
python/mxnet/util.py | 24 +++++++++
src/c_api/c_api.cc | 63 ++++++++++++++++++++++
.../python/unittest/test_numpy_interoperability.py | 11 ++--
5 files changed, 108 insertions(+), 3 deletions(-)
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 98a7a70..6e441e2 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -262,6 +262,17 @@ MXNET_DLL int MXRandomSeed(int seed);
MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
/*!
+ * \brief Change floating-point calculations when dealing with denormalized values.
+ * Sets the flush-to-zero (FTZ) mode and, only on processors that support it, the
+ * denormals-are-zero (DAZ) mode of the CPU floating-point control register (MXCSR).
+ * Currently this option is only supported in the CPU backend; when MXNet is built
+ * without SSE flush-to-zero support the call changes nothing and *prev_state is false.
+ * The call blocks until all pending engine operations have completed.
+ * Flushing denormalized values to zero is enabled by default.
+ *
+ * \param value true to enable flush-to-zero and denormals-are-zero, false to disable them.
+ * \param prev_state output: the flush-to-zero state before setting the new state
+ *        (the previous denormals-are-zero state is not reported).
+ * \return 0 when success, -1 when failure happens.
+ */
+MXNET_DLL int MXSetFlushDenorms(bool value, bool* prev_state);
+
+/*!
* \brief Notify the engine about a shutdown,
* This can help engine to print less messages into display.
*
diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 496f1f5..ca98116 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -350,6 +350,8 @@ __version__ = libinfo.__version__
# library instance of mxnet
_LIB = _load_lib()
+check_call(_LIB.MXSetFlushDenorms(ctypes.c_bool(True),
+ ctypes.byref(ctypes.c_bool())))
# type definitions
mx_int = ctypes.c_int
mx_uint = ctypes.c_uint
diff --git a/python/mxnet/util.py b/python/mxnet/util.py
index aabd5fe..d0584b3 100644
--- a/python/mxnet/util.py
+++ b/python/mxnet/util.py
@@ -848,3 +848,27 @@ def setenv(name, value):
"""
passed_value = None if value is None else c_str(value)
check_call(_LIB.MXSetEnv(c_str(name), passed_value))
+
+def set_flush_denorms(value):
+ """Change floating-point calculations on CPU when dealing with denormalized values.
+ This is only applicable to architectures which supports flush-to-zero.
+ Denormalized values are positive and negative values that are very close to 0
+ (exponent is the smallest possible value).
+ Flushing denormalized values to 0 can speedup calculations if such values occurs,
+ but if fulfilling whole IEEE 754 standard is required this option should be disabled.
+ Flushing denormalized values is enabled in MXNet by default.
+
+ Parameters
+ ----------
+ value : bool
+ State of flush-to-zero and denormals-are-zero in MXCSR register
+
+ Returns
+ -------
+ prev_state : bool
+ Previous state of flush-to-zero in MXCSR register
+ """
+ ret = ctypes.c_bool()
+ passed_value = ctypes.c_bool(value)
+ check_call(_LIB.MXSetFlushDenorms(passed_value, ctypes.byref(ret)))
+ return ret.value
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 1cb5583..bcbdab1 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -59,6 +59,23 @@
#include "../common/utils.h"
#include "nnvm/pass_functions.h"
+// FTZ only applies to SSE and AVX instructions.
+// SUPPORT_FTZ_DMZ is 1 when compiling for an x86 target with SSE available
+// (__SSE__, an x86-64 target via __x86_64__/_M_X64, or 32-bit MSVC with _M_IX86_FP >= 1)
+// and 0 otherwise; when it is 0, MXSetFlushDenorms leaves the FP mode untouched.
+// Note: "DMZ" in the macro name refers to the DAZ (denormals-are-zero) MXCSR flag.
+#if defined(__SSE__) || defined(__x86_64__) || defined(_M_X64) || \
+ (defined(_M_IX86_FP) && _M_IX86_FP >= 1)
+#define SUPPORT_FTZ_DMZ 1
+#else
+#define SUPPORT_FTZ_DMZ 0
+#endif
+
+#if SUPPORT_FTZ_DMZ
+// Intrinsic headers for the MXCSR mode macros and _fxsave used by MXSetFlushDenorms.
+#include <immintrin.h>
+#include <xmmintrin.h>
+#endif
+// <x86intrin.h> is only included for non-MSVC compilers (guarded by _MSC_VER).
+#if SUPPORT_FTZ_DMZ && !defined(_MSC_VER)
+#include <x86intrin.h>
+#endif
+
+
using namespace mxnet;
// Internal function to get the information
@@ -1573,6 +1590,52 @@ int MXRandomSeedContext(int seed, int dev_type, int dev_id) {
API_END();
}
+#if SUPPORT_FTZ_DMZ
+/*!
+ * \brief Check whether the processor supports the DAZ (denormals-are-zero) MXCSR flag.
+ *
+ * Follows Intel 64 and IA-32 Architectures Software Developer's Manual, Vol. 1,
+ * "Checking for the DAZ Flag in the MXCSR Register": FXSAVE into a zeroed, 16-byte
+ * aligned 512-byte area, then test bit 6 of the MXCSR_MASK field (bytes 28-31).
+ * A zero MXCSR_MASK stands for the default mask 0x0000FFBF, whose bit 6 is clear,
+ * so the same bit test also reports "not supported" in that case.
+ *
+ * \return true if DAZ may be written without raising a general-protection fault.
+ */
+static bool IsDazFlagAvailable() {
+  constexpr unsigned int kMxcsrMaskOffset = 28;   // byte offset of MXCSR_MASK in the area
+  constexpr unsigned int kDazFlagBit = 6;         // DAZ is bit 6 of MXCSR / MXCSR_MASK
+  constexpr unsigned int kFxsaveAreaBytes = 512;  // size of the FXSAVE area
+
+  // FXSAVE raises #GP on an area that is not 16-byte aligned, and malloc() does not
+  // guarantee that alignment on every 32-bit target, so use an aligned local buffer.
+  // The area must be zeroed first so an unwritten MXCSR_MASK reads as 0.
+  alignas(16) uint32_t fxsave_area[kFxsaveAreaBytes / sizeof(uint32_t)] = {};
+  _fxsave(fxsave_area);
+
+  const uint32_t mxcsr_mask = fxsave_area[kMxcsrMaskOffset / sizeof(uint32_t)];
+  return ((mxcsr_mask >> kDazFlagBit) & 0x1u) != 0;
+}
+#endif  // SUPPORT_FTZ_DMZ
+
+/*!
+ * \brief Set flush-to-zero (and, when supported, denormals-are-zero) on the CPU engine.
+ * \param value true to enable FTZ/DAZ, false to disable them.
+ * \param prev_state output: previous FTZ state; false when FTZ support is not compiled in.
+ * \return 0 when success, -1 when failure happens.
+ */
+int MXSetFlushDenorms(bool value, bool* prev_state) {
+  API_BEGIN();
+  *prev_state = false;
+
+#if SUPPORT_FTZ_DMZ
+  // The mode change runs inside a synchronous CPU engine op instead of on the caller's
+  // thread. NOTE(review): MXCSR is per-thread state, so this affects only the worker
+  // thread that executes this op -- confirm it covers every CPU operator thread.
+  Engine::Get()->PushSync(
+      [value, prev_state](RunContext) {
+        const unsigned int daz_state = value ? _MM_DENORMALS_ZERO_ON : _MM_DENORMALS_ZERO_OFF;
+        const unsigned int ftz_state = value ? _MM_FLUSH_ZERO_ON : _MM_FLUSH_ZERO_OFF;
+        // Only the previous FTZ state is reported back to the caller.
+        *prev_state = (_MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON);
+        _MM_SET_FLUSH_ZERO_MODE(ftz_state);
+
+        // If the DAZ flag is not supported, it is a reserved bit and attempting to write
+        // a 1 to it causes a general-protection exception (#GP).
+        if (IsDazFlagAvailable()) {
+          _MM_SET_DENORMALS_ZERO_MODE(daz_state);
+        }
+      },
+      Context::CPU(), {}, {}, FnProperty::kNormal, 0, "SetFlushDenorms");
+
+  // Wait for the op above so *prev_state is populated before returning to the caller.
+  Engine::Get()->WaitForAll();
+#endif  // SUPPORT_FTZ_DMZ
+
+  API_END();
+}
+
int MXNotifyShutdown() {
API_BEGIN();
mxnet::op::custom::CustomOperator::Get()->Stop();
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index fd8abf1..cae2d46 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -22,9 +22,10 @@ from distutils.version import StrictVersion
import sys
import platform
import itertools
-import numpy as _np
import unittest
-from mxnet import np
+
+from mxnet import np, util
+import numpy as _np
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
@@ -3075,7 +3076,11 @@ def test_np_array_function_protocol():
@use_np
@with_array_ufunc_protocol
def test_np_array_ufunc_protocol():
- check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
+ # Flush-to-zero is switched off for this check and the previous state is restored
+ # in `finally`; presumably NumPy's reference results keep denormalized values that
+ # MXNet would otherwise flush to zero (TODO confirm).
+ prev_state = util.set_flush_denorms(False)
+ try:
+ check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)
+ finally:
+ util.set_flush_denorms(prev_state)
@with_seed()