piiswrong commented on a change in pull request #8342: [WIP] 2bit gradient compression
URL: https://github.com/apache/incubator-mxnet/pull/8342#discussion_r146112917
 
 

 ##########
 File path: src/ndarray/ndarray.cc
 ##########
 @@ -558,6 +558,101 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) {
   }
 }
 
+void Quantize(const NDArray &from, NDArray *to, NDArray *residual, const std::string& compress,
+              const float neg_threshold, const float pos_threshold,
+              int priority) {
+  CHECK(from.shape().ndim() != 0)
+      << "source operands have zero dimension shape";
+  // important: callback must always capture by value
+  NDArray ret = *to;
+  NDArray res = *residual;
+  int a = from.ctx().dev_mask();
+  int b = to->ctx().dev_mask();
+  if (a == cpu::kDevMask && b == cpu::kDevMask) {
+    if (compress == "2bit") {
+      Engine::Get()->PushSync([from, res, ret, neg_threshold, pos_threshold](RunContext ctx) {
+          std::vector<TBlob> inputs(3);
+          inputs[0] = from.data();
+          inputs[1] = res.data();
+          inputs[2] = ret.data();
+          mxnet::ndarray::Quantize2BitDispatch<cpu>(ctx.get_stream<cpu>(), inputs,
+                                                    neg_threshold, pos_threshold);
+        }, from.ctx(), {from.var()}, {ret.var(), res.var()},
+        FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeCPU"));
+    } else {
+      LOG(FATAL) << "Unsupported Quantization";
+    }
+  } else {
+#if MXNET_USE_CUDA
+    if (a == gpu::kDevMask && b == gpu::kDevMask) {
+      if (compress == "2bit") {
+        Engine::Get()->PushSync([from, res, ret, neg_threshold, pos_threshold](RunContext ctx) {
+            std::vector<TBlob> inputs(3);
+            inputs[0] = from.data();
+            inputs[1] = res.data();
+            inputs[2] = ret.data();
+            mxnet::ndarray::Quantize2BitDispatch<gpu>(ctx.get_stream<gpu>(), inputs,
+                                                      neg_threshold, pos_threshold);
+            // Wait GPU kernel to complete
+            ctx.get_stream<gpu>()->Wait();
+          }, from.ctx(), {from.var()}, {ret.var(), res.var()},
+          FnProperty::kNormal, priority, PROFILER_MESSAGE("QuantizeGPU"));
+        } else {
+          LOG(FATAL) << "Unsupported Quantization";
+        }
+    } else {
+      LOG(FATAL) << "unknown device mask";
+    }
+#else
+    LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
+#endif
+  }
+}
+
+void Dequantize(const NDArray &from, NDArray *to, const std::string& compress, int priority) {
+  CHECK(from.shape().ndim() != 0)
+    << "source operands have zero dimension shape";
+  // important: callback must always capture by value
+  NDArray ret = *to;
+  int a = from.ctx().dev_mask();
+  int b = to->ctx().dev_mask();
+  if (a == cpu::kDevMask && b == cpu::kDevMask) {
+    if (compress == "2bit") {
+      Engine::Get()->PushSync([from, ret](RunContext ctx) {
+        std::vector<TBlob> inputs(2);
+        inputs[0] = from.data();
+        inputs[1] = ret.data();
+        mxnet::ndarray::Dequantize2BitDispatch<cpu>(ctx.get_stream<cpu>(), inputs);
+      }, from.ctx(), {from.var()}, {ret.var()},
+      FnProperty::kNormal, priority, PROFILER_MESSAGE("DequantizeCPU"));
+    } else {
+      LOG(FATAL) << "Unsupported dequantization " << compress << std::endl;
+    }
+  } else {
+#if MXNET_USE_CUDA
+    if (a == gpu::kDevMask && b == gpu::kDevMask) {
+      if (compress == "2bit") {
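For context, here is a minimal scalar sketch of the kind of threshold-based quantization with an error-feedback residual that a kernel like Quantize2BitDispatch could implement. The encoding (codes of +1, 0, -1 and the residual update) and every name below are illustrative assumptions, not the PR's actual kernel, which presumably also packs the codes into 2-bit fields and runs on the engine's cpu/gpu streams as shown above.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: elements whose accumulated residual exceeds pos_threshold
// are sent as +1, those below neg_threshold as -1, everything else as 0.
// The part of the gradient that was not sent stays in the residual.
void quantize_2bit_sketch(const std::vector<float>& grad,
                          std::vector<float>* residual,    // same size as grad
                          std::vector<int8_t>* codes,      // one code per element
                          float neg_threshold, float pos_threshold) {
  codes->resize(grad.size());
  for (std::size_t i = 0; i < grad.size(); ++i) {
    float& r = (*residual)[i];
    r += grad[i];                        // error feedback: fold in the new gradient
    int8_t code = 0;
    if (r >= pos_threshold) {
      code = 1;
      r -= pos_threshold;                // keep the unsent remainder
    } else if (r <= neg_threshold) {
      code = -1;
      r -= neg_threshold;
    }
    (*codes)[i] = code;
  }
}

int main() {
  std::vector<float> grad = {0.9f, -0.7f, 0.1f};
  std::vector<float> residual(grad.size(), 0.0f);
  std::vector<int8_t> codes;
  quantize_2bit_sketch(grad, &residual, &codes, -0.5f, 0.5f);
  for (std::size_t i = 0; i < codes.size(); ++i)
    std::printf("code=%d residual=%.2f\n", codes[i], residual[i]);
  return 0;
}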
 
 Review comment:
   Also, I doubt this should be an NDArray method.
   It would be cleaner to abstract the gradient compression logic into a separate class.
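
A hypothetical sketch of what such a separate abstraction could look like (the class and method names here are illustrative only, not an existing MXNet API; the methods would wrap the engine pushes and the Quantize2BitDispatch / Dequantize2BitDispatch calls shown in the diff):

#include <string>

namespace mxnet {

class NDArray;  // forward declaration of the real NDArray class

// Hypothetical: owns the compression type and thresholds so callers do not
// pass them around, and keeps NDArray free of compression-specific methods.
class GradientCompression {
 public:
  GradientCompression(const std::string& type,
                      float neg_threshold, float pos_threshold)
      : type_(type), neg_threshold_(neg_threshold), pos_threshold_(pos_threshold) {}

  // Compress `from` into `to`, accumulating the quantization error in `residual`.
  void Quantize(const NDArray& from, NDArray* to, NDArray* residual, int priority) const;

  // Inverse operation, expanding the compressed array back into `to`.
  void Dequantize(const NDArray& from, NDArray* to, int priority) const;

 private:
  std::string type_;       // e.g. "2bit"
  float neg_threshold_;
  float pos_threshold_;
};

}  // namespace mxnet

With something along these lines, a caller such as the kvstore could hold one GradientCompression instance and invoke its Quantize/Dequantize methods instead of free functions on NDArray, keeping the thresholds and the supported compression types in one place.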

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
