You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/08/22 22:11:23 UTC
[GitHub] haojin2 closed pull request #12261: Support fp16 in synchronized
batchnorm
haojin2 closed pull request #12261: Support fp16 in synchronized batchnorm
URL: https://github.com/apache/incubator-mxnet/pull/12261
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:
As this is a foreign pull request (from a fork), the diff is supplied
below (GitHub does not display the original diff for merged foreign pull requests):
diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h
index 1f548dbc7e5..4c04f7ef8b2 100644
--- a/src/operator/contrib/sync_batch_norm-inl.h
+++ b/src/operator/contrib/sync_batch_norm-inl.h
@@ -259,6 +259,8 @@ class SyncBatchNorm : public Operator {
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
+ using namespace mshadow_op;
+ using namespace mxnet_op;
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(aux_states.size(), 2U);
if (ctx.is_train) {
@@ -271,69 +273,102 @@ class SyncBatchNorm : public Operator {
}
Stream<xpu> *s = ctx.get_stream<xpu>();
- const real_t scale = static_cast<real_t>(in_data[syncbatchnorm::kData].shape_[1]) /
- static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
- Tensor<xpu, 4> data;
- Tensor<xpu, 4> out;
- if (in_data[syncbatchnorm::kData].ndim() == 2) {
- Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
- in_data[syncbatchnorm::kData].shape_[1], 1, 1);
- data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
- out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
- } else {
- data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
- out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
- }
- Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> moving_mean = aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 1, real_t>(s);
-
- if (param_.fix_gamma) slope = 1.f;
-
- // whether use global statistics
- if (ctx.is_train && !param_.use_global_stats) {
- // get my rank
- Barrier *global_barrier = global_shared_barrier_forward.Register(param_.key, param_.ndev);
- int myRank = global_shared_rank_forward.Register(param_.key, param_.ndev);
- // get the mean and var
- Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, real_t>(s);
- CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] == kWriteTo);
- CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == kWriteTo);
- // E(x) and E(x^2)
- mean = scale * sumall_except_dim<1>(data);
- var = scale * sumall_except_dim<1>(F<mshadow_op::square>(data));
- SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedMean =
- global_shared_mean.Register(param_.key, param_.ndev);
- SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedVar =
- global_shared_var.Register(param_.key, param_.ndev);
- // copy to cpu, push and pull
- Tensor<cpu, 1, real_t>* mean_cpu_ptr = sharedMean->Retrieve(mean.shape_, myRank);
- Tensor<cpu, 1, real_t>* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, myRank);
- mshadow::Copy(*mean_cpu_ptr, mean, s);
- mshadow::Copy(*var_cpu_ptr, var, s);
- sharedMean->SetReady(myRank);
- sharedVar->SetReady(myRank);
- global_barrier->Wait();
- Tensor<cpu, 1, real_t> mean_cpu = sharedMean->Pop(myRank);
- Tensor<cpu, 1, real_t> var_cpu = sharedVar->Pop(myRank);
- // copy back to gpu
- mshadow::Copy(mean, mean_cpu, s);
- mshadow::Copy(var, var_cpu, s);
-
- var = var-F<mshadow_op::square>(mean);
- Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope, out.shape_) *
- (data - broadcast<1>(mean, data.shape_)) /
- F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_)) +
- broadcast<1>(bias, out.shape_));
- } else {
- Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope /
- F<mshadow_op::square_root>(moving_var + param_.eps),
- data.shape_) * data +
- broadcast<1>(bias - (slope * moving_mean) /
- F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
- }
+ MSHADOW_TYPE_SWITCH(in_data[syncbatchnorm::kData].type_flag_, DType, {
+ const bool is_double = std::is_same<DType, double>::value;
+ CHECK_EQ(is_double, false)
+ << "Synchronized BatchNorm does not support double-precision floating number yet...";
+ const real_t scale = static_cast<real_t>(in_data[syncbatchnorm::kData].shape_[1]) /
+ static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
+ const size_t data_size = in_data[syncbatchnorm::kData].Size();
+ Tensor<xpu, 4> data;
+ Tensor<xpu, 4> out;
+ Tensor<xpu, 1> workspace;
+ if (!std::is_same<DType, real_t>::value) {
+ workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu, 1>(
+ Shape1(data_size * 2), s);
+ }
+ if (in_data[syncbatchnorm::kData].ndim() == 2) {
+ Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
+ in_data[syncbatchnorm::kData].shape_[1], 1, 1);
+ if (std::is_same<DType, real_t>::value) {
+ data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
+ out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
+ } else {
+ data = Tensor<xpu, 4>(workspace.dptr_, dshape, s);
+ out = Tensor<xpu, 4>(workspace.dptr_ + data_size, dshape, s);
+ }
+ } else {
+ if (std::is_same<DType, real_t>::value) {
+ data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+ out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+ } else {
+ Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
+ in_data[syncbatchnorm::kData].shape_[1],
+ in_data[syncbatchnorm::kData].shape_[2],
+ in_data[syncbatchnorm::kData].shape_[3]);
+ data = Tensor<xpu, 4>(workspace.dptr_, dshape, s);
+ out = Tensor<xpu, 4>(workspace.dptr_ + data_size, dshape, s);
+ }
+ }
+ if (!std::is_same<DType, real_t>::value) {
+ Kernel<identity_with_cast, xpu>::Launch(
+ s, data.shape_.Size(), data.dptr_, in_data[syncbatchnorm::kData].dptr<DType>());
+ }
+ Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> moving_mean = aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 1, real_t>(s);
+
+ if (param_.fix_gamma) slope = 1.f;
+
+ // whether use global statistics
+ if (ctx.is_train && !param_.use_global_stats) {
+ // get my rank
+ Barrier *global_barrier = global_shared_barrier_forward.Register(param_.key, param_.ndev);
+ int myRank = global_shared_rank_forward.Register(param_.key, param_.ndev);
+ // get the mean and var
+ Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, real_t>(s);
+ CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] == kWriteTo);
+ CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == kWriteTo);
+ // E(x) and E(x^2)
+ mean = scale * sumall_except_dim<1>(data);
+ var = scale * sumall_except_dim<1>(F<mshadow_op::square>(data));
+ SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedMean =
+ global_shared_mean.Register(param_.key, param_.ndev);
+ SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedVar =
+ global_shared_var.Register(param_.key, param_.ndev);
+ // copy to cpu, push and pull
+ Tensor<cpu, 1, real_t>* mean_cpu_ptr = sharedMean->Retrieve(mean.shape_, myRank);
+ Tensor<cpu, 1, real_t>* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, myRank);
+ mshadow::Copy(*mean_cpu_ptr, mean, s);
+ mshadow::Copy(*var_cpu_ptr, var, s);
+ sharedMean->SetReady(myRank);
+ sharedVar->SetReady(myRank);
+ global_barrier->Wait();
+ Tensor<cpu, 1, real_t> mean_cpu = sharedMean->Pop(myRank);
+ Tensor<cpu, 1, real_t> var_cpu = sharedVar->Pop(myRank);
+ // copy back to gpu
+ mshadow::Copy(mean, mean_cpu, s);
+ mshadow::Copy(var, var_cpu, s);
+
+ var = var-F<mshadow_op::square>(mean);
+ Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope, out.shape_) *
+ (data - broadcast<1>(mean, data.shape_)) /
+ F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_)) +
+ broadcast<1>(bias, out.shape_));
+ } else {
+ Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope /
+ F<mshadow_op::square_root>(moving_var + param_.eps),
+ data.shape_) * data +
+ broadcast<1>(bias - (slope * moving_mean) /
+ F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
+ }
+ if (!std::is_same<DType, real_t>::value) {
+ Kernel<identity_with_cast, xpu>::Launch(
+ s, out.shape_.Size(), out_data[syncbatchnorm::kOut].dptr<DType>(), out.dptr_);
+ }
+ });
}
virtual void Backward(const OpContext &ctx,
@@ -345,6 +380,8 @@ class SyncBatchNorm : public Operator {
const std::vector<TBlob> &aux_states) {
using namespace mshadow;
using namespace mshadow::expr;
+ using namespace mshadow_op;
+ using namespace mxnet_op;
CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U);
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);
@@ -352,102 +389,152 @@ class SyncBatchNorm : public Operator {
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> data, grad, grad_in;
- const real_t scale = static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
- static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
- if (in_data[syncbatchnorm::kData].ndim() == 2) {
- Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
- out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
- data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
- grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
- grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
- } else {
- data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
- grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
- grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
- }
+ Tensor<xpu, 1> workspace;
+ const size_t data_size = in_data[syncbatchnorm::kData].Size();
+ MSHADOW_TYPE_SWITCH(in_data[syncbatchnorm::kData].type_flag_, DType, {
+ const bool is_double = std::is_same<DType, double>::value;
+ CHECK_EQ(is_double, false)
+ << "Synchronized BatchNorm does not support double-precision floating number yet...";
+ size_t total_workspace_size = 0;
- Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
- // Tensor<xpu, 1> bias = in_data[kBeta].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> gslope = in_grad[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> gbias = in_grad[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
- // update moving avg
- Tensor<xpu, 1> moving_mean = aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
- Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 1, real_t>(s);
-
- if (param_.fix_gamma) slope = 1.f;
-
- if (ctx.is_train && !param_.use_global_stats) {
- // get my rank
- Barrier *global_barrier = global_shared_barrier_backward.Register(param_.key, param_.ndev);
- int myRank = global_shared_rank_backward.Register(param_.key, param_.ndev);
- // get requested temp space
- Tensor<xpu, 2> workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu>(
- mshadow::Shape2(5, mean.shape_[0]), s);
- Tensor<xpu, 1> gmean = workspace[0];
- Tensor<xpu, 1> gvar = workspace[1];
-
- moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum);
- moving_var = moving_var * param_.momentum + var * (1 - param_.momentum);
- // cal
- Tensor<xpu, 1> sumGrad = workspace[3];
- Tensor<xpu, 1> sumProd = workspace[4];
- sumGrad = sumall_except_dim<1>(grad);
- sumProd = sumall_except_dim<1>(grad * (data - broadcast<1>(mean, data.shape_)));
- SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedGrad =
- global_shared_grad.Register(param_.key, param_.ndev);
- SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedProd =
- global_shared_prod.Register(param_.key, param_.ndev);
- // copy to cpu, push and pull
- Tensor<cpu, 1, real_t>* grad_cpu_ptr = sharedGrad->Retrieve(sumGrad.shape_, myRank);
- Tensor<cpu, 1, real_t>* prod_cpu_ptr = sharedProd->Retrieve(sumGrad.shape_, myRank);
- mshadow::Copy(*grad_cpu_ptr, sumGrad, s);
- mshadow::Copy(*prod_cpu_ptr, sumProd, s);
- sharedGrad->SetReady(myRank);
- sharedProd->SetReady(myRank);
- global_barrier->Wait();
- Tensor<cpu, 1, real_t> grad_cpu = sharedGrad->Pop(myRank);
- Tensor<cpu, 1, real_t> prod_cpu = sharedProd->Pop(myRank);
- // copy back to gpu
- mshadow::Copy(sumGrad, grad_cpu, s);
- mshadow::Copy(sumProd, prod_cpu, s);
-
- gvar = -1.0f * sumProd * slope *
- F<mshadow_op::power>(var + param_.eps, -1.5f);
- gmean = sumGrad * slope;
- gmean *= -1.0f / F<mshadow_op::square_root>(var + param_.eps);
- // assign
- if (!param_.fix_gamma) {
- Assign(gslope, req[syncbatchnorm::kGamma],
- sumall_except_dim<1>(
- grad * (data - broadcast<1>(mean, data.shape_)) /
- F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_))));
+ Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> var = out_data[syncbatchnorm::kVar].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> gslope = in_grad[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> gbias = in_grad[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
+ // update moving avg
+ Tensor<xpu, 1> moving_mean = aux_states[syncbatchnorm::kMovingMean].get<xpu, 1, real_t>(s);
+ Tensor<xpu, 1> moving_var = aux_states[syncbatchnorm::kMovingVar].get<xpu, 1, real_t>(s);
+
+ if (ctx.is_train && !param_.use_global_stats) {
+ total_workspace_size += 4 * mean.shape_[0];
+ }
+ if (!std::is_same<DType, real_t>::value) {
+ total_workspace_size += 3 * data_size;
+ }
+
+ workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space<xpu, 1>(
+ mshadow::Shape1(total_workspace_size), s);
+
+ const real_t scale = static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
+ static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
+ if (in_data[syncbatchnorm::kData].ndim() == 2) {
+ Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
+ out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
+ if (!std::is_same<DType, real_t>::value) {
+ real_t* starting_ptr = (ctx.is_train && !param_.use_global_stats) ?
+ workspace.dptr_ + 4 * mean.shape_[0] :
+ workspace.dptr_;
+ data = Tensor<xpu, 4>(starting_ptr, dshape, s);
+ grad = Tensor<xpu, 4>(starting_ptr + data_size, dshape, s);
+ grad_in = Tensor<xpu, 4>(starting_ptr + 2 * data_size, dshape, s);
+ } else {
+ data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
+ grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
+ grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
+ }
} else {
- Assign(gslope, req[syncbatchnorm::kGamma], 0.0f);
+ Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
+ out_grad[syncbatchnorm::kOut].shape_[1],
+ out_grad[syncbatchnorm::kOut].shape_[2],
+ out_grad[syncbatchnorm::kOut].shape_[3]);
+ if (!std::is_same<DType, real_t>::value) {
+ real_t* starting_ptr = (ctx.is_train && !param_.use_global_stats) ?
+ workspace.dptr_ + 4 * mean.shape_[0] :
+ workspace.dptr_;
+ data = Tensor<xpu, 4>(starting_ptr, dshape, s);
+ grad = Tensor<xpu, 4>(starting_ptr + data_size, dshape, s);
+ grad_in = Tensor<xpu, 4>(starting_ptr + 2 * data_size, dshape, s);
+ } else {
+ data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+ grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
+ grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
+ }
}
- Assign(grad_in, req[syncbatchnorm::kData],
- (grad * broadcast<1>(slope, data.shape_)) *
- broadcast<1>(1.0f / F<mshadow_op::square_root>(var + param_.eps), data.shape_) +
- broadcast<1>(gvar, data.shape_) *
- scale * (data - broadcast<1>(mean, data.shape_)) +
- broadcast<1>(gmean, data.shape_) * scale);
- Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad));
- } else {
- // use global statistics with freeze moving mean and var.
- if (!param_.fix_gamma) {
- Assign(gslope, req[syncbatchnorm::kGamma],
- sumall_except_dim<1>(
- grad * (data - broadcast<1>(moving_mean, data.shape_)) /
- F<mshadow_op::square_root>(broadcast<1>(moving_var + param_.eps, data.shape_))));
+
+ if (!std::is_same<DType, real_t>::value) {
+ Kernel<identity_with_cast, xpu>::Launch(
+ s, data.shape_.Size(), data.dptr_, in_data[syncbatchnorm::kData].dptr<DType>());
+ Kernel<identity_with_cast, xpu>::Launch(
+ s, grad.shape_.Size(), grad.dptr_, out_grad[syncbatchnorm::kOut].dptr<DType>());
+ }
+
+ if (param_.fix_gamma) slope = 1.f;
+
+ if (ctx.is_train && !param_.use_global_stats) {
+ // get my rank
+ Barrier *global_barrier = global_shared_barrier_backward.Register(param_.key, param_.ndev);
+ int myRank = global_shared_rank_backward.Register(param_.key, param_.ndev);
+
+ Shape<1> dshape = Shape1(mean.shape_[0]);
+ Tensor<xpu, 1> gmean = Tensor<xpu, 1>(workspace.dptr_, dshape, s);
+ Tensor<xpu, 1> gvar = Tensor<xpu, 1>(workspace.dptr_ + mean.shape_[0], dshape, s);
+
+ moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum);
+ moving_var = moving_var * param_.momentum + var * (1 - param_.momentum);
+ // cal
+ Tensor<xpu, 1> sumGrad = Tensor<xpu, 1>(workspace.dptr_ + 2 * mean.shape_[0], dshape, s);
+ Tensor<xpu, 1> sumProd = Tensor<xpu, 1>(workspace.dptr_ + 3 * mean.shape_[0], dshape, s);
+ sumGrad = sumall_except_dim<1>(grad);
+ sumProd = sumall_except_dim<1>(grad * (data - broadcast<1>(mean, data.shape_)));
+ SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedGrad =
+ global_shared_grad.Register(param_.key, param_.ndev);
+ SharedND<mshadow::Tensor<cpu, 1, real_t>> *sharedProd =
+ global_shared_prod.Register(param_.key, param_.ndev);
+ // copy to cpu, push and pull
+ Tensor<cpu, 1, real_t>* grad_cpu_ptr = sharedGrad->Retrieve(sumGrad.shape_, myRank);
+ Tensor<cpu, 1, real_t>* prod_cpu_ptr = sharedProd->Retrieve(sumGrad.shape_, myRank);
+ mshadow::Copy(*grad_cpu_ptr, sumGrad, s);
+ mshadow::Copy(*prod_cpu_ptr, sumProd, s);
+ sharedGrad->SetReady(myRank);
+ sharedProd->SetReady(myRank);
+ global_barrier->Wait();
+ Tensor<cpu, 1, real_t> grad_cpu = sharedGrad->Pop(myRank);
+ Tensor<cpu, 1, real_t> prod_cpu = sharedProd->Pop(myRank);
+ // copy back to gpu
+ mshadow::Copy(sumGrad, grad_cpu, s);
+ mshadow::Copy(sumProd, prod_cpu, s);
+
+ gvar = -1.0f * sumProd * slope *
+ F<mshadow_op::power>(var + param_.eps, -1.5f);
+ gmean = sumGrad * slope;
+ gmean *= -1.0f / F<mshadow_op::square_root>(var + param_.eps);
+ // assign
+ if (!param_.fix_gamma) {
+ Assign(gslope, req[syncbatchnorm::kGamma],
+ sumall_except_dim<1>(
+ grad * (data - broadcast<1>(mean, data.shape_)) /
+ F<mshadow_op::square_root>(broadcast<1>(var + param_.eps, data.shape_))));
+ } else {
+ Assign(gslope, req[syncbatchnorm::kGamma], 0.0f);
+ }
+ Assign(grad_in, req[syncbatchnorm::kData],
+ (grad * broadcast<1>(slope, data.shape_)) *
+ broadcast<1>(1.0f / F<mshadow_op::square_root>(var + param_.eps), data.shape_) +
+ broadcast<1>(gvar, data.shape_) *
+ scale * (data - broadcast<1>(mean, data.shape_)) +
+ broadcast<1>(gmean, data.shape_) * scale);
+ Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad));
} else {
- Assign(gslope, req[syncbatchnorm::kGamma], 0.0f);
+ // use global statistics with freeze moving mean and var.
+ if (!param_.fix_gamma) {
+ Assign(gslope, req[syncbatchnorm::kGamma],
+ sumall_except_dim<1>(
+ grad * (data - broadcast<1>(moving_mean, data.shape_)) /
+ F<mshadow_op::square_root>(broadcast<1>(moving_var + param_.eps, data.shape_))));
+ } else {
+ Assign(gslope, req[syncbatchnorm::kGamma], 0.0f);
+ }
+ Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad));
+ Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) *
+ broadcast<1>(
+ 1.0f / F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
}
- Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad));
- Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) *
- broadcast<1>(
- 1.0f / F<mshadow_op::square_root>(moving_var + param_.eps), data.shape_));
- }
+ if (!std::is_same<DType, real_t>::value) {
+ Kernel<identity_with_cast, xpu>::Launch(
+ s, grad_in.shape_.Size(), in_grad[syncbatchnorm::kData].dptr<DType>(), grad_in.dptr_);
+ }
+ });
}
private:
@@ -532,6 +619,11 @@ class SyncBatchNormProp : public OperatorProperty {
return "_contrib_SyncBatchNorm";
}
+ std::vector<ResourceRequest> ForwardResource(
+ const std::vector<TShape> &in_shape) const override {
+ return {ResourceRequest::kTempSpace};
+ }
+
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 42d65dab5fd..38306f87a86 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -139,6 +139,8 @@ def _syncParameters(bn1, bn2, ctx):
input1 = input.copy()
input2 = input.copy()
+ rtol, atol = (1e-2, 1e-2) if input.dtype is np.float16 else (1e-3, 1e-3)
+
if cuda:
input1 = input.as_in_context(mx.gpu(0))
ctx_list = [mx.gpu(i) for i in range(num_devices)]
@@ -152,9 +154,6 @@ def _syncParameters(bn1, bn2, ctx):
bn1.initialize(ctx=ctx_list[0])
bn2.initialize(ctx=ctx_list)
- # using the same values for gamma and beta
- #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
-
input1.attach_grad()
inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
for xi in inputs2:
@@ -170,18 +169,19 @@ def _syncParameters(bn1, bn2, ctx):
output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0)
# assert forwarding
- assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3)
- assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
+ assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=atol, rtol=rtol)
+ assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=atol, rtol=rtol)
assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
_find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
- atol=1e-3, rtol=1e-3)
+ atol=atol, rtol=rtol)
assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
_find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
- atol=1e-3, rtol=1e-3)
+ atol=atol, rtol=rtol)
input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
- assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
+ assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=atol, rtol=rtol)
+@with_seed()
def test_sync_batchnorm():
def get_num_devices():
for i in range(100):
@@ -193,10 +193,12 @@ def get_num_devices():
if get_num_devices() < 2:
return
ndev = 2
+ dtypes = [np.float16, np.float32]
# check with unsync version
for i in range(10):
- _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
- num_devices=ndev, cuda=True)
+ for dtype in dtypes:
+ _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)).astype(dtype),
+ num_devices=ndev, cuda=True)
if __name__ == '__main__':
import nose
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services