Repository: incubator-singa Updated Branches: refs/heads/dev 26df5ac03 -> 21e4b2d79
SINGA-184 Add Cross Entropy loss computation Implement Cross Entropy loss Pass cpplint.py, test pass compilation Todo: check test Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/efd7b627 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/efd7b627 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/efd7b627 Branch: refs/heads/dev Commit: efd7b627bacb4acd6a3322468350f2b5399f725b Parents: 3e2507b Author: kaiping <[email protected]> Authored: Fri May 27 12:09:30 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Tue May 31 22:14:09 2016 +0800 ---------------------------------------------------------------------- src/model/loss/cross_entropy.h | 105 ++++++++++++++++++++++++++++++++++ test/singa/test_cross_entropy.cc | 66 +++++++++++++++++++++ 2 files changed, 171 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/efd7b627/src/model/loss/cross_entropy.h ---------------------------------------------------------------------- diff --git a/src/model/loss/cross_entropy.h b/src/model/loss/cross_entropy.h new file mode 100644 index 0000000..815b795 --- /dev/null +++ b/src/model/loss/cross_entropy.h @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef SRC_MODEL_LOSS_CROSS_ENTROPY_H_ +#define SRC_MODEL_LOSS_CROSS_ENTROPY_H_ +#include <stack> +#include "singa/model/loss.h" + +namespace singa { + +/// Cross entropy is for cross entropy loss. +class CrossEntropy : public Loss<Tensor> { + public: + /// Compute the loss values for each sample/instance given the prediction + /// and the target, which is sum {-log(prob_of_truth)} + /// Users can call Average(const Tensor&) to get the average + /// loss value over all samples in the batch. + Tensor Forward(const Tensor& prediction, const Tensor& target) override; + + /// Compute the gradients of the loss values w.r.t. the prediction, + /// which is: if the entry x corresponds to ground truth, + /// then softmax(x) - 1; else, softmax(x) + Tensor Backward() override; + + private: + // to buffer intermediate data, i.e., softmax(prediction), target + std::stack<Tensor> buf_; +}; + +Tensor CrossEntropy::Forward(const Tensor& prediction, const Tensor& target) { + CHECK(buf_.empty()) << "Do not call Forward successively for more than twice." + << " The calling pattern is [Forward|Evaluate] Backward"; + + size_t batchsize = 1; + if (prediction.nDim() > 1) batchsize = prediction.shape().at(0); + size_t dim = prediction.Size() / batchsize; + // a temporal Softmax layer for forward computation +// LayerConf conf; // TODO(kaiping): this is currently commented +// Softmax softmax_tmp; +// softmax_tmp.Setup(conf); +// Tensor softmax = softmax_tmp.Forward(0, prediction); + + Tensor softmax(Shape{batchsize, dim}); // TODO(kaiping): Delete +// softmax.SetValue<float>(0.5f); // TODO(kaiping): Delete + + softmax.Reshape(Shape{batchsize, dim}); + // buffer intermediate data + buf_.push(softmax); + buf_.push(target); + + // Compute loss for each sample + Tensor loss(Shape{batchsize, 1}); + float * pre_ptr = reinterpret_cast<float*>(softmax.blob()->mutable_data()); + float * truth_ptr = reinterpret_cast<float*>(target.blob()->mutable_data()); + float * loss_ptr = reinterpret_cast<float*>(loss.blob()->mutable_data()); + for (size_t i = 0; i < batchsize; i++) { + int ilabel = static_cast<int>(truth_ptr[i]); + CHECK_GE(ilabel, 0); + float prob_of_truth = pre_ptr[ilabel]; + loss_ptr[i] = -log(prob_of_truth); + pre_ptr += dim; // change to the next sample + } + return loss; +} + +Tensor CrossEntropy::Backward() { + const Tensor& target = buf_.top(); + buf_.pop(); + Tensor softmax = buf_.top(); + buf_.pop(); + + size_t batchsize = 1; + if (softmax.nDim() > 1) + batchsize = softmax.shape().at(0); + size_t dim = softmax.Size() / batchsize; + float * truth_ptr = reinterpret_cast<float*>(target.blob()->mutable_data()); + float * pre_ptr = reinterpret_cast<float*>(softmax.blob()->mutable_data()); + for (size_t i = 0; i < batchsize; i++) { + int ilabel = static_cast<int>(truth_ptr[i]); + // CHECK_GE(ilabel, 0); + pre_ptr[ilabel] -= 1.0; + pre_ptr += dim; // change to the next sample + } + return softmax; +} +} // namespace singa + +#endif // SRC_MODEL_LOSS_CROSS_ENTROPY_H_ + + http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/efd7b627/test/singa/test_cross_entropy.cc ---------------------------------------------------------------------- diff --git a/test/singa/test_cross_entropy.cc b/test/singa/test_cross_entropy.cc new file mode 100644 index 0000000..9bb2321 --- /dev/null +++ b/test/singa/test_cross_entropy.cc @@ -0,0 +1,66 @@ +/************************************************************ +* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +#include "gtest/gtest.h" +#include "singa/core/tensor.h" +#include "singa/core/device.h" +#include "../src/model/loss/cross_entropy.h" + +using singa::Tensor; +class TestCrossEntropy : public ::testing::Test { + protected: + virtual void SetUp() { + p.Reshape(singa::Shape{2, 4}); + t.Reshape(singa::Shape{2, 1}); + p.CopyDataFromHostPtr(pdat, sizeof(pdat) / sizeof(float)); + t.CopyDataFromHostPtr(tdat, sizeof(pdat) / sizeof(float)); + } + const float pdat[8] = {0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1}; + const float tdat[2] = {0.0, 2.0}; + + singa::Tensor p, t; +}; + +TEST_F(TestCrossEntropy, CppForward) { + singa::CrossEntropy cross_entropy; + const Tensor& loss = cross_entropy.Forward(p, t); + auto ldat = loss.data<const float*>(); + + const float result_test = -log(0.25); + EXPECT_FLOAT_EQ(ldat[0], result_test); + EXPECT_FLOAT_EQ(ldat[1], result_test); +} + +TEST_F(TestCrossEntropy, CppBackward) { + singa::CrossEntropy cross_entropy; + cross_entropy.Forward(p, t); + const Tensor& grad = cross_entropy.Backward(); + + auto gdat = grad.data<const float*>(); + EXPECT_FLOAT_EQ(gdat[0], -0.75); + EXPECT_FLOAT_EQ(gdat[1], 0.25); + EXPECT_FLOAT_EQ(gdat[2], 0.25); + EXPECT_FLOAT_EQ(gdat[3], 0.25); + EXPECT_FLOAT_EQ(gdat[4], 0.25); + EXPECT_FLOAT_EQ(gdat[5], 0.25); + EXPECT_FLOAT_EQ(gdat[6], -0.75); + EXPECT_FLOAT_EQ(gdat[7], 0.25); +}
