Optimize Backward Time Complexity to O(MK) #1

Open · wants to merge 6 commits into base: caffe-face
Changes from 3 commits
23 changes: 22 additions & 1 deletion README.md
@@ -1,3 +1,24 @@
# Faster Center Loss Implementation
This branch is forked from [ydwen's caffe-face](https://github.com/ydwen/caffe-face) and modified by mfs6174 ( [email protected] )

Compared to the original implementation by the paper's author, the backward time complexity of this implementation is reduced from O(MK+NM) to O(MK).
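
For reference, the gradients involved (stated here in the center loss paper's notation, where x_i is the i-th feature in a batch of size M, y_i its label, and c_j the j-th of the N class centers) are:

$$
\frac{\partial L_C}{\partial x_i} = x_i - c_{y_i},
\qquad
\Delta c_j = \frac{\sum_{i=1}^{M} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{M} \delta(y_i = j)}
$$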

In the original implementation, the time complexity of the backward pass of the center loss layer is O(MK+NM). It becomes very slow when training with a large number of classes, since the running time of the backward pass grows with the number of classes (N). Unfortunately, this is a common situation when training face recognition models (e.g. 750k unique persons).

This implementation rewrites the backward code. The time complexity is reduced to O(MK) at the cost of O(N) additional space. Because M (batch size) << N and K (feature length) << N usually hold for face recognition problems, this modification improves training speed significantly.
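
The core of the rewritten backward pass can be sketched in plain C++ as follows (an illustrative standalone sketch over std::vector rather than Caffe blobs; the function name `center_backward_sketch` and its signature are made up for illustration and are not part of the layer):

```cpp
#include <vector>

// Sketch of the O(MK) center-gradient computation used by this PR.
// x: M x K features, label: M labels in [0, N), center: N x K class centers.
void center_backward_sketch(int M, int K, int N,
                            const std::vector<float>& x,
                            const std::vector<int>& label,
                            const std::vector<float>& center,
                            std::vector<float>* center_diff) {
  std::vector<float> variation_sum(N * K, 0.f);  // O(NK) scratch, reused every pass
  std::vector<int> count(N, 0);                  // O(N) scratch: samples per class
  // Pass 1, O(MK): accumulate (c_{y_i} - x_i) into the row of the sampled class.
  for (int m = 0; m < M; ++m) {
    const int n = label[m];
    ++count[n];
    for (int k = 0; k < K; ++k) {
      variation_sum[n * K + k] += center[n * K + k] - x[m * K + k];
    }
  }
  // Pass 2, O(MK): write the gradient only for classes that appear in the batch;
  // classes absent from the batch keep a zero gradient.
  center_diff->assign(N * K, 0.f);
  for (int m = 0; m < M; ++m) {
    const int n = label[m];
    for (int k = 0; k < K; ++k) {
      (*center_diff)[n * K + k] = variation_sum[n * K + k] / (count[n] + 1.f);
    }
  }
}
```

The loop over all N classes in the original code is gone. Note that inside Caffe the layer should accumulate into the existing parameter diff instead of resetting it; see the review discussion further down.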

For a GoogLeNet v2 model trained on Everphoto's dataset of 750k unique persons, on a single Nvidia GTX Titan X with a batch size of 24 and iter_size = 5, the average backward iteration times are:

1. Softmax only: 230ms
2. Softmax + Center loss, original implementation: 3485ms, center loss layer: 3332ms
3. Softmax + Center loss, implementation in this PR: 235.6ms, center loss layer: 5.4ms

This is a more than 600x improvement for the center loss layer's backward pass.

For the author's "mnist_example", running on a single GTX Titan X, the training time of the original implementation vs. this PR is 4min20s vs. 3min50s. This shows that even when training on a small dataset with only 10 classes, there is still some improvement.

The implementation also fixes the code style to pass Caffe's lint test (`make lint`), so that it may be ready to be merged into Caffe's master.

# Deep Face Recognition with Caffe Implementation

This branch is developed for deep face recognition; the related paper is as follows.
@@ -185,4 +206,4 @@ Please cite Caffe in your publications if it helps your research:
Journal = {arXiv preprint arXiv:1408.5093},
Title = {Caffe: Convolutional Architecture for Fast Feature Embedding},
Year = {2014}
}
}
4 changes: 2 additions & 2 deletions include/caffe/layers/center_loss_layer.hpp
@@ -38,11 +38,11 @@ class CenterLossLayer : public LossLayer<Dtype> {
int M_;
int K_;
int N_;

Blob<Dtype> distance_;
Blob<Dtype> variation_sum_;
Blob<int> count_;
};

} // namespace caffe

#endif // CAFFE_CENTER_LOSS_LAYER_HPP_
#endif // CAFFE_CENTER_LOSS_LAYER_HPP_
52 changes: 32 additions & 20 deletions src/caffe/layers/center_loss_layer.cpp
@@ -9,7 +9,7 @@ namespace caffe {
template <typename Dtype>
void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int num_output = this->layer_param_.center_loss_param().num_output();
const int num_output = this->layer_param_.center_loss_param().num_output();
N_ = num_output;
const int axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.center_loss_param().axis());
@@ -31,7 +31,6 @@ void CenterLossLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
shared_ptr<Filler<Dtype> > center_filler(GetFiller<Dtype>(
this->layer_param_.center_loss_param().center_filler()));
center_filler->Fill(this->blobs_[0].get());

} // parameter initialization
this->param_propagate_down_.resize(this->blobs_.size(), true);
}
@@ -48,6 +47,9 @@ void CenterLossLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
LossLayer<Dtype>::Reshape(bottom, top);
distance_.ReshapeLike(*bottom[0]);
variation_sum_.ReshapeLike(*this->blobs_[0]);
vector<int> count_shape(1);
count_shape[0] = N_;
count_.Reshape(count_shape);
}

template <typename Dtype>
@@ -57,54 +59,64 @@ void CenterLossLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* label = bottom[1]->cpu_data();
const Dtype* center = this->blobs_[0]->cpu_data();
Dtype* distance_data = distance_.mutable_cpu_data();

// the i-th distance_data
for (int i = 0; i < M_; i++) {
const int label_value = static_cast<int>(label[i]);
// D(i,:) = X(i,:) - C(y(i),:)
caffe_sub(K_, bottom_data + i * K_, center + label_value * K_, distance_data + i * K_);
caffe_sub(K_, bottom_data + i * K_,
center + label_value * K_, distance_data + i * K_);
}
Dtype dot = caffe_cpu_dot(M_ * K_, distance_.cpu_data(), distance_.cpu_data());
Dtype dot = caffe_cpu_dot(M_ * K_, distance_.cpu_data(),
distance_.cpu_data());
Dtype loss = dot / M_ / Dtype(2);
top[0]->mutable_cpu_data()[0] = loss;
}

template <typename Dtype>
void CenterLossLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
// Gradient with respect to centers
if (this->param_propagate_down_[0]) {
const Dtype* label = bottom[1]->cpu_data();
Dtype* center_diff = this->blobs_[0]->mutable_cpu_diff();
Dtype* variation_sum_data = variation_sum_.mutable_cpu_data();
int* count_data = count_.mutable_cpu_data();

const Dtype* distance_data = distance_.cpu_data();

// \sum_{y_i==j}
caffe_set(N_ * K_, (Dtype)0., variation_sum_.mutable_cpu_data());
for (int n = 0; n < N_; n++) {
int count = 0;
for (int m = 0; m < M_; m++) {
const int label_value = static_cast<int>(label[m]);
if (label_value == n) {
count++;
caffe_sub(K_, variation_sum_data + n * K_, distance_data + m * K_, variation_sum_data + n * K_);
}
}
caffe_axpy(K_, (Dtype)1./(count + (Dtype)1.), variation_sum_data + n * K_, center_diff + n * K_);
caffe_set(N_, 0 , count_.mutable_cpu_data());
caffe_set(N_ * K_, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());

Review comment:

Blob's diff should not be reset in Backward().
It will hurt the gradient accumulation when iter_size > 1. See discussion in https://groups.google.com/forum/#!searchin/caffe-users/iter_size|sort:relevance/caffe-users/PMbycfbpKcY/FTBiMKunEQAJ

@mfs6174 (Author) Nov 7, 2016:

Thank you for your review. It is true that the diff has been initialized in other code and should not be reset here. I have just fixed it.
Interestingly, I have successfully trained a network with center loss on a large dataset using iter_size > 1, and it seems that the result did not suffer much from this bug.

Reply:

Everything you wrote wrong works as a regularizer term :)
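
For context, a toy illustration of why resetting the diff inside Backward() conflicts with gradient accumulation when iter_size > 1 (plain C++, not Caffe code; the numbers are made up):

```cpp
#include <cstdio>

// With iter_size = 2, the effective gradient of a parameter should be the sum
// of both passes' gradients. A layer that clears its diff at the start of its
// own backward pass throws away the earlier pass's contribution.
int main() {
  const float g1 = 0.3f, g2 = 0.5f;  // per-pass gradients for one parameter

  // Accumulating layer (what the framework expects): the diff is cleared once,
  // by the solver, before the accumulation loop.
  float diff_accum = 0.f;            // solver clears the diff
  diff_accum += g1;                  // pass 1: Backward() adds its gradient
  diff_accum += g2;                  // pass 2: Backward() adds its gradient
  std::printf("accumulated: %.2f\n", diff_accum);  // prints 0.80

  // Resetting layer (the behaviour flagged above): each Backward() zeroes the
  // diff first, so only the last pass survives.
  float diff_reset = 0.f;
  diff_reset = 0.f; diff_reset += g1;  // pass 1: reset, then write
  diff_reset = 0.f; diff_reset += g2;  // pass 2: reset wipes pass 1's gradient
  std::printf("after reset:  %.2f\n", diff_reset);  // prints 0.50
  return 0;
}
```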


for (int m = 0; m < M_; m++) {
const int label_value = static_cast<int>(label[m]);
caffe_sub(K_, variation_sum_data + label_value * K_,
distance_data + m * K_, variation_sum_data + label_value * K_);
count_data[label_value]++;
}
for (int m = 0; m < M_; m++) {
const int n = static_cast<int>(label[m]);
caffe_cpu_axpby(K_, (Dtype)1./ (count_data[n] + (Dtype)1.),
variation_sum_data + n * K_,
(Dtype)0., center_diff + n * K_);
}
}
// Gradient with respect to bottom data
// Gradient with respect to bottom data
if (propagate_down[0]) {
caffe_copy(M_ * K_, distance_.cpu_data(), bottom[0]->mutable_cpu_diff());
caffe_scal(M_ * K_, top[0]->cpu_diff()[0] / M_, bottom[0]->mutable_cpu_diff());
caffe_copy(M_ * K_, distance_.cpu_data(),
bottom[0]->mutable_cpu_diff());
caffe_scal(M_ * K_, top[0]->cpu_diff()[0] / M_,
bottom[0]->mutable_cpu_diff());
}
if (propagate_down[1]) {
LOG(FATAL) << this->type()
<< " Layer cannot backpropagate to label inputs.";
}
}


#ifdef CPU_ONLY
STUB_GPU(CenterLossLayer);
#endif
83 changes: 50 additions & 33 deletions src/caffe/layers/center_loss_layer.cu
@@ -7,8 +7,11 @@
namespace caffe {

template <typename Dtype>
__global__ void Compute_distance_data_gpu(int nthreads, const int K, const Dtype* bottom,
const Dtype* label, const Dtype* center, Dtype* distance) {
__global__ void Compute_distance_data_gpu(int nthreads, const int K,
const Dtype* bottom,
const Dtype* label,
const Dtype* center,
Dtype* distance) {
CUDA_KERNEL_LOOP(index, nthreads) {
int m = index / K;
int k = index % K;
@@ -17,55 +20,69 @@ __global__ void Compute_distance_data_gpu(int nthreads, const int K, const Dtype
distance[index] = bottom[index] - center[label_value * K + k];
}
}

template <typename Dtype>
__global__ void Compute_center_diff_gpu(int nthreads, const int M, const int K,
const Dtype* label, const Dtype* distance, Dtype* variation_sum,
Dtype* center_diff) {
__global__ void Compute_variation_sum_gpu(int nthreads, const int K,
const Dtype* label,
const Dtype* distance,
Dtype* variation_sum, int * count) {
CUDA_KERNEL_LOOP(index, nthreads) {
int count = 0;
for (int m = 0; m < M; m++) {
const int label_value = static_cast<int>(label[m]);
if (label_value == index) {
count++;
for (int k = 0; k < K; k++) {
variation_sum[index * K + k] -= distance[m * K + k];
}
}
}
for (int k = 0; k < K; k++) {
center_diff[index * K + k] = variation_sum[index * K + k] /(count + (Dtype)1.);
}
int m = index / K;
int k = index % K;
const int label_value = static_cast<int>(label[m]);
variation_sum[label_value * K + k] -= distance[m * K + k];
count[label_value] += ((k == 0)?1:0);
}
}


template <typename Dtype>
__global__ void Compute_center_diff_gpu(int nthreads, const int K,
const Dtype* label,
Dtype* variation_sum,
int * count, Dtype* center_diff) {
CUDA_KERNEL_LOOP(index, nthreads) {
int m = index / K;
int k = index % K;
const int n = static_cast<int>(label[m]);
center_diff[n * K + k] = variation_sum[n * K + k]
/ (count[n] + (Dtype)1.);
}
}
template <typename Dtype>
void CenterLossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
int nthreads = M_ * K_;
Compute_distance_data_gpu<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[0]->gpu_data(), bottom[1]->gpu_data(),
this->blobs_[0]->gpu_data(), distance_.mutable_gpu_data());
Compute_distance_data_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[0]->gpu_data(),
bottom[1]->gpu_data(),
this->blobs_[0]->gpu_data(),
distance_.mutable_gpu_data());
Dtype dot;
caffe_gpu_dot(M_ * K_, distance_.gpu_data(), distance_.gpu_data(), &dot);
Dtype loss = dot / M_ / Dtype(2);
top[0]->mutable_cpu_data()[0] = loss;
}

template <typename Dtype>
void CenterLossLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom) {
int nthreads = N_;
caffe_gpu_set(N_ * K_, (Dtype)0., variation_sum_.mutable_cpu_data());
Compute_center_diff_gpu<Dtype><<<CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, M_, K_, bottom[1]->gpu_data(), distance_.gpu_data(),
variation_sum_.mutable_cpu_data(), this->blobs_[0]->mutable_gpu_diff());

caffe_gpu_set(N_ * K_, (Dtype)0., variation_sum_.mutable_gpu_data());
caffe_gpu_set(N_, 0 , count_.mutable_gpu_data());
caffe_gpu_set(N_ * K_, (Dtype)0., this->blobs_[0]->mutable_gpu_diff());
int nthreads = M_ * K_;
Compute_variation_sum_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[1]->gpu_data(),
distance_.gpu_data(),
variation_sum_.mutable_gpu_data(),
count_.mutable_gpu_data());
Compute_center_diff_gpu<Dtype> <<< CAFFE_GET_BLOCKS(nthreads),
CAFFE_CUDA_NUM_THREADS>>>(nthreads, K_, bottom[1]->gpu_data(),
variation_sum_.mutable_gpu_data(),
count_.mutable_gpu_data(),
this->blobs_[0]->mutable_gpu_diff());
if (propagate_down[0]) {
caffe_gpu_scale(M_ * K_, top[0]->cpu_diff()[0] / M_,
distance_.gpu_data(), bottom[0]->mutable_gpu_diff());
caffe_gpu_scale(M_ * K_,
top[0]->cpu_diff()[0] / M_,
distance_.gpu_data(),
bottom[0]->mutable_gpu_diff());
}
if (propagate_down[1]) {
LOG(FATAL) << this->type()