Repository: incubator-singa
Updated Branches:
  refs/heads/master 4a0db51f5 -> 0cb6cc222


SINGA-116 Fix a bug in InnerProductLayer caused by weight matrix sharing

Fix the bug by considering the transpose of the weight matrix explicitly.
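
A standalone sketch (not SINGA code; the function name and flat row-major layout are assumptions) of the shape bookkeeping behind the fix: the layer multiplies by the stored weight directly when it is kept in transposed (vdim x hdim) form, and by its transpose when it is kept in the usual (hdim x vdim) form.

    #include <cstddef>
    #include <vector>

    // data[b][h] = sum_v src[b][v] * W(h, v); how W(h, v) is read from the
    // row-major buffer depends on whether the shared weight is stored transposed.
    void InnerProductForward(const std::vector<float>& src, size_t batch,
                             size_t vdim, size_t hdim,
                             const std::vector<float>& weight, bool transposed,
                             std::vector<float>* data) {
      data->assign(batch * hdim, 0.0f);
      for (size_t b = 0; b < batch; ++b) {
        for (size_t h = 0; h < hdim; ++h) {
          float sum = 0.0f;
          for (size_t v = 0; v < vdim; ++v) {
            // transposed: weight is (vdim x hdim); otherwise (hdim x vdim)
            const float w = transposed ? weight[v * hdim + h] : weight[h * vdim + v];
            sum += src[b * vdim + v] * w;
          }
          (*data)[b * hdim + h] = sum;
        }
      }
    }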


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/cf6ef824
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/cf6ef824
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/cf6ef824

Branch: refs/heads/master
Commit: cf6ef8247420b0869b1f93613ad47420f73ee2d4
Parents: 4a0db51
Author: Wei Wang <[email protected]>
Authored: Sun Dec 27 22:34:57 2015 +0800
Committer: Wei Wang <[email protected]>
Committed: Mon Dec 28 11:49:20 2015 +0800

----------------------------------------------------------------------
 src/neuralnet/neuron_layer/inner_product.cc | 15 ++++++++++++---
 src/utils/param.cc                          |  2 +-
 2 files changed, 13 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf6ef824/src/neuralnet/neuron_layer/inner_product.cc
----------------------------------------------------------------------
diff --git a/src/neuralnet/neuron_layer/inner_product.cc b/src/neuralnet/neuron_layer/inner_product.cc
index 7e6318d..f50afba 100644
--- a/src/neuralnet/neuron_layer/inner_product.cc
+++ b/src/neuralnet/neuron_layer/inner_product.cc
@@ -57,7 +57,10 @@ void InnerProductLayer::Setup(const LayerProto& conf,
 
 void InnerProductLayer::ComputeFeature(int flag,
     const vector<Layer*>& srclayers) {
-  MMDot(srclayers[0]->data(this), weight_->data().T(), &data_);
+  if (transpose_)
+    MMDot(srclayers[0]->data(this), weight_->data(), &data_);
+  else
+    MMDot(srclayers[0]->data(this), weight_->data().T(), &data_);
   MVAddRow(bias_->data(), &data_);
 }
 
@@ -65,9 +68,15 @@ void InnerProductLayer::ComputeGradient(int flag,
     const vector<Layer*>& srclayers) {
 
   MVSumRow(1.0f, 0.0f, grad_, bias_->mutable_grad());
-  MMDot(grad_.T(), srclayers[0]->data(this), weight_->mutable_grad());
+  if (transpose_)
+    MMDot(srclayers[0]->data(this).T(), grad_, weight_->mutable_grad());
+  else
+    MMDot(grad_.T(), srclayers[0]->data(this), weight_->mutable_grad());
   if (srclayers[0]->mutable_grad(this) != nullptr) {
-    MMDot(grad_, weight_->data(), srclayers[0]->mutable_grad(this));
+    if (transpose_)
+      MMDot(grad_, weight_->data().T(), srclayers[0]->mutable_grad(this));
+    else
+      MMDot(grad_, weight_->data(), srclayers[0]->mutable_grad(this));
   }
 }
 }  // namespace singa
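
The hunk above switches the gradient computations on the same flag. A standalone sketch (not SINGA code; names assumed) of the weight-gradient accumulation, showing that the result is written in whichever layout the shared weight actually uses:

    #include <cstddef>
    #include <vector>

    // dW(h, v) = sum_b grad[b][h] * src[b][v]; the sum is stored either as
    // (hdim x vdim) or, for a transposed shared weight, as (vdim x hdim).
    void InnerProductWeightGrad(const std::vector<float>& src,
                                const std::vector<float>& grad,
                                size_t batch, size_t vdim, size_t hdim,
                                bool transposed, std::vector<float>* dweight) {
      dweight->assign(vdim * hdim, 0.0f);
      for (size_t b = 0; b < batch; ++b)
        for (size_t h = 0; h < hdim; ++h)
          for (size_t v = 0; v < vdim; ++v) {
            const float g = grad[b * hdim + h] * src[b * vdim + v];
            if (transposed)
              (*dweight)[v * hdim + h] += g;   // grad of a (vdim x hdim) weight
            else
              (*dweight)[h * vdim + v] += g;   // grad of a (hdim x vdim) weight
          }
    }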

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/cf6ef824/src/utils/param.cc
----------------------------------------------------------------------
diff --git a/src/utils/param.cc b/src/utils/param.cc
index 70a969f..bdae72f 100644
--- a/src/utils/param.cc
+++ b/src/utils/param.cc
@@ -168,7 +168,7 @@ void Param::InitValues(int version) {
 
 void Param::ShareFrom(Param* other, bool cpu_only) {
   proto_.set_owner(other->owner());
-  CHECK(data_.shape() == other->data_.shape());
+  CHECK_EQ(data_.count(), other->data_.count());
   data_.ShareData(&(other->data_), cpu_only);
   if (grad_.count() == 0)
     grad_.Reshape(data_.shape());
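
A weight shared in transposed form has its two dimensions swapped, so the shapes of the owner and the sharer differ even though they hold the same number of elements; a tiny standalone illustration (not SINGA code) of why the check is relaxed to element counts:

    #include <cassert>
    #include <cstddef>

    int main() {
      const size_t hdim = 3, vdim = 4;
      const size_t owner[2] = {hdim, vdim};    // owner stores (hdim x vdim)
      const size_t sharer[2] = {vdim, hdim};   // sharer stores (vdim x hdim)
      // The shapes differ, so a shape-equality CHECK would fail ...
      assert(owner[0] != sharer[0] || owner[1] != sharer[1]);
      // ... but the element counts match, which is what CHECK_EQ on count() tests.
      assert(owner[0] * owner[1] == sharer[0] * sharer[1]);
      return 0;
    }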
