You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2021/03/04 14:04:54 UTC

[GitHub] [incubator-mxnet] shuo-ouyang commented on a change in pull request #17952: [v1.x][KVStore]1Bit gradient compression

shuo-ouyang commented on a change in pull request #17952:
URL: https://github.com/apache/incubator-mxnet/pull/17952#discussion_r587495312



##########
File path: src/kvstore/gradient_compression.cc
##########
@@ -142,35 +169,52 @@ void GradientCompression::Dequantize(const mxnet::NDArray &from, mxnet::NDArray
   const int a = from.ctx().dev_mask();
   const int b = to->ctx().dev_mask();
   const float threshold = threshold_;
-  if (type_ == CompressionType::kTwoBit) {
-    if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+  if (a == mshadow::cpu::kDevMask && b == mshadow::cpu::kDevMask) {
+    if (type_ == CompressionType::kOneBit) {
+      mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+        std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+        Dequantize1BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
+      }, from.ctx(), {from.var()}, {to->var()},
+      mxnet::FnProperty::kNormal, priority, "DequantizeCPU");
+    } else if (type_ == CompressionType::kTwoBit) {
       mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
         std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
         Dequantize2BitImpl(ctx.get_stream<mshadow::cpu>(), inputs, threshold);
       }, from.ctx(), {from.var()}, {to->var()},
       mxnet::FnProperty::kNormal, priority, "DequantizeCPU");
     } else {
+      LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
+    }
+  } else {
 #if MXNET_USE_CUDA
-      if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+    if (a == mshadow::gpu::kDevMask && b == mshadow::gpu::kDevMask) {
+      if (type_ == CompressionType::kOneBit) {
         mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
           std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
-          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          Dequantize1BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
           // Wait GPU kernel to complete
           ctx.get_stream<mshadow::gpu>()->Wait();
         }, from.ctx(), {from.var()}, {to->var()},
         mxnet::FnProperty::kNormal, priority, "DequantizeGPU");
+      } else if (type_ == CompressionType::kTwoBit) {
+        mxnet::Engine::Get()->PushSync([from, to, threshold](mxnet::RunContext ctx) {
+          std::vector<mxnet::TBlob> inputs = {from.data(), to->data()};
+          Dequantize2BitImpl(ctx.get_stream<mshadow::gpu>(), inputs, threshold);
+          // Wait GPU kernel to complete
+          ctx.get_stream<mshadow::gpu>()->Wait();
+        }, from.ctx(), {from.var()}, {to->var()},
+        mxnet::FnProperty::kNormal, priority, "DequantizeGPU");
       } else {
-        LOG(FATAL) << "unknown device mask";
+        LOG(FATAL) << "Unsupported dequantization of type " << get_type_str();
       }
+    } else {
+      LOG(FATAL) << "unknown device mask";
+    }
 #else
-      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+    LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;

Review comment:
       Thanks for your review. We will solve it as soon as possible.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org