Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/13 13:32:07 UTC

[GitHub] csgcmai closed pull request #9056: Face

URL: https://github.com/apache/incubator-mxnet/pull/9056
 
 
   

This is a pull request from a forked repository. As GitHub may not show the
original diff from the fork, it is reproduced below for the sake of provenance:

diff --git a/dmlc-core b/dmlc-core
index c39001019e..8afaaae57b 160000
--- a/dmlc-core
+++ b/dmlc-core
@@ -1 +1 @@
-Subproject commit c39001019e443c7a061789bd1180f58ce85fc3e6
+Subproject commit 8afaaae57b8832225c23070cd25093ce9f5b6e9c
diff --git a/mshadow b/mshadow
index e41ae71f70..0186f06e3c 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit e41ae71f7096f4b3592c30786328f95ad0eb6dd0
+Subproject commit 0186f06e3c1ffd0777775fedd670d82052317674
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index bc51e469db..d8d556a54d 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -142,6 +142,18 @@ def update(self, labels, preds):
             self.sum_metric += (pred_label.flat == label.flat).sum()
             self.num_inst += len(pred_label.flat)
 
+class MultiBinaryAccuracy(EvalMetric):
+    """Calculate multi-binary (+1 vs -1) accuracy"""
+
+    def __init__(self):
+        super(MultiBinaryAccuracy, self).__init__('multi_binary_accuracy')
+
+    def update(self, labels, preds):
+        check_label_shapes(labels, preds)
+        for i in range(len(labels)):
+            self.sum_metric += ((labels[i].asnumpy() >= 0) == (preds[i].asnumpy() >= 0)).sum()
+            self.num_inst += numpy.prod(labels[i].asnumpy().shape)
+
 class TopKAccuracy(EvalMetric):
     """Calculate top k predictions accuracy"""
 
@@ -314,26 +326,20 @@ class CustomMetric(EvalMetric):
     ----------
     feval : callable(label, pred)
         Customized evaluation function.
+
     name : str, optional
         The name of the metric
-    allow_extra_outputs : bool
-        If true, the prediction outputs can have extra outputs.
-        This is useful in RNN, where the states are also produced
-        in outputs for forwarding.
     """
-    def __init__(self, feval, name=None, allow_extra_outputs=False):
+    def __init__(self, feval, name=None):
         if name is None:
             name = feval.__name__
             if name.find('<') != -1:
                 name = 'custom(%s)' % name
         super(CustomMetric, self).__init__(name)
         self._feval = feval
-        self._allow_extra_outputs = allow_extra_outputs
 
     def update(self, labels, preds):
-        if not self._allow_extra_outputs:
-            check_label_shapes(labels, preds)
-
+        check_label_shapes(labels, preds)
         for pred, label in zip(preds, labels):
             label = label.asnumpy()
             pred = pred.asnumpy()
@@ -351,25 +357,22 @@ def update(self, labels, preds):
                 self.num_inst += 1
 
 # pylint: disable=invalid-name
-def np(numpy_feval, name=None, allow_extra_outputs=False):
+def np(numpy_feval, name=None):
     """Create a customized metric from numpy function.
 
     Parameters
     ----------
     numpy_feval : callable(label, pred)
         Customized evaluation function.
+
     name : str, optional
         The name of the metric.
-    allow_extra_outputs : bool
-        If true, the prediction outputs can have extra outputs.
-        This is useful in RNN, where the states are also produced
-        in outputs for forwarding.
     """
     def feval(label, pred):
         """Internal eval function."""
         return numpy_feval(label, pred)
     feval.__name__ = numpy_feval.__name__
-    return CustomMetric(feval, name, allow_extra_outputs)
+    return CustomMetric(feval, name)
 # pylint: enable=invalid-name
 
 def create(metric, **kwargs):
@@ -395,6 +398,7 @@ def create(metric, **kwargs):
     metrics = {
         'acc': Accuracy,
         'accuracy': Accuracy,
+        'multi_binary_acc': MultiBinaryAccuracy,
         'ce': CrossEntropy,
         'f1': F1,
         'mae': MAE,
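
A minimal numpy sketch of what the new MultiBinaryAccuracy metric computes, assuming labels and predictions are coded +1 / -1 per attribute as in the update() method above (the toy arrays here are illustrative, not part of the PR):

    import numpy as np

    # Toy batch: 2 samples x 3 binary attributes, labels coded +1 / -1.
    labels = np.array([[ 1, -1,  1],
                       [-1, -1,  1]])
    preds = np.array([[ 0.3, -0.2, -0.5],
                      [-0.7,  0.1,  0.9]])

    # As in MultiBinaryAccuracy.update(): a prediction counts as correct
    # when its sign (>= 0) agrees with the sign of the label.
    sum_metric = ((labels >= 0) == (preds >= 0)).sum()
    num_inst = np.prod(labels.shape)
    print(sum_metric / float(num_inst))  # 4/6 for this toy batch
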
diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h
index bbe231d755..c5f18e8e61 100644
--- a/src/operator/batch_norm-inl.h
+++ b/src/operator/batch_norm-inl.h
@@ -88,7 +88,8 @@ class BatchNormOp : public Operator {
     Tensor<xpu, 1> bias = in_data[batchnorm::kBeta].get<xpu, 1, real_t>(s);
     Tensor<xpu, 1> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, real_t>(s);
     Tensor<xpu, 1> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, real_t>(s);
-
+
+    if (param_.fix_gamma) slope = 1.0f;
     // whether use global statistics
     if (ctx.is_train && !param_.use_global_stats) {
       Tensor<xpu, 1> mean = out_data[batchnorm::kMean].get<xpu, 1, real_t>(s);
diff --git a/src/operator/cudnn_batch_norm-inl.h b/src/operator/cudnn_batch_norm-inl.h
index c58baad7a7..e98bba80c2 100644
--- a/src/operator/cudnn_batch_norm-inl.h
+++ b/src/operator/cudnn_batch_norm-inl.h
@@ -89,6 +89,7 @@ class CuDNNBatchNormOp : public Operator {
     Tensor<gpu, 4> x = in_data[cudnnbatchnorm::kData].get_with_shape<gpu, 4, real_t>(shape_, s);
     Tensor<gpu, 1> gamma =
       in_data[cudnnbatchnorm::kGamma].get_with_shape<gpu, 1, real_t>(Shape1(shape_[1]), s);
+    if (param_.fix_gamma) gamma = 1.0f;
     Tensor<gpu, 1> beta =
       in_data[cudnnbatchnorm::kBeta].get_with_shape<gpu, 1, real_t>(Shape1(shape_[1]), s);
     Tensor<gpu, 4> y = out_data[cudnnbatchnorm::kOut].get_with_shape<gpu, 4, real_t>(shape_, s);
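
Both fix_gamma changes above pin the learned scale to 1, so BatchNorm only normalizes the input and shifts it by beta. A small numpy sketch of that effect, with illustrative names that are not part of the diff:

    import numpy as np

    def batchnorm_fix_gamma(x, beta, eps=1e-5):
        # Per-channel batch normalization with the scale (gamma) pinned to 1,
        # which is what the added fix_gamma lines enforce.
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        x_hat = (x - mean) / np.sqrt(var + eps)
        gamma = np.ones(x.shape[1])  # fixed scale
        return gamma * x_hat + beta

    x = np.random.randn(4, 3)  # (batch, channels)
    y = batchnorm_fix_gamma(x, beta=np.zeros(3))
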
diff --git a/src/operator/moon_output-inl.h b/src/operator/moon_output-inl.h
new file mode 100644
index 0000000000..511e231d2a
--- /dev/null
+++ b/src/operator/moon_output-inl.h
@@ -0,0 +1,212 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file moon_output-inl.h
+ * \brief
+ *  This is the MOON loss operator, which comes from the paper:
+ *  Rudd E, Günther M, Boult T. MOON: A Mixed Objective Optimization Network for the Recognition of Facial Attributes.
+ *  arXiv preprint arXiv:1603.07027, 2016.
+ *  The MOON loss operator is typically used in multi-binary-label applications, where every binary label is either +1 or -1.
+ * \author Wei Wu
+*/
+#ifndef MXNET_OPERATOR_MOON_OUTPUT_INL_H_
+#define MXNET_OPERATOR_MOON_OUTPUT_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <cstring>
+#include <map>
+#include <string>
+#include <vector>
+#include <utility>
+#include <fstream>
+#include "./operator_common.h"
+
+namespace mxnet {
+namespace op {
+
+namespace moonout_enum {
+enum MoonOutputOpInputs {kData, kLabel};
+enum MoonOutputOpOutputs {kOut};
+}  // namespace moonout_enum
+
+struct MoonOutputParam : public dmlc::Parameter<MoonOutputParam> {
+  float grad_scale;
+  std::string src_dist_path;
+  DMLC_DECLARE_PARAMETER(MoonOutputParam) {
+    DMLC_DECLARE_FIELD(grad_scale).set_default(1.0f)
+    .describe("Scale the gradient by a float factor");
+	DMLC_DECLARE_FIELD(src_dist_path).set_default("src_dist.txt")
+		.describe("Path of the text file holding the source-distribution parameters (one float per attribute)");
+  };
+};
+
+template<typename xpu, typename DType>
+class MoonOutputOp : public Operator {
+ public:
+  explicit MoonOutputOp(MoonOutputParam param) : param_(param) {
+	  std::ifstream ifs;
+	  ifs.open(param_.src_dist_path.c_str(), std::ifstream::in);
+	  float tmp;
+	  while (ifs >> tmp) {
+		  src_dist_.push_back(tmp);
+	  }
+	  ifs.close();
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+	using namespace mshadow;
+	using namespace mshadow::expr;
+    CHECK_EQ(in_data.size(), 2) << "MoonOutput Input: [data, label]";
+    CHECK_EQ(out_data.size(), 1) << "MoonOutput Output: [output]";
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Tensor<xpu, 2, DType> data = in_data[moonout_enum::kData].FlatTo2D<xpu, DType>(s);
+	Tensor<xpu, 2, DType> label = in_data[moonout_enum::kLabel].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = out_data[moonout_enum::kOut].FlatTo2D<xpu, DType>(s);
+	CHECK_EQ(data.shape_, out.shape_) << "Moon: shape mismatch between input and output";
+	CHECK_EQ(label.shape_, out.shape_) << "Moon: shape mismatch between label and output";
+	CHECK_EQ(data.shape_[1], src_dist_.size()) << "Moon: mismatch between the number of input channels and the number of parameters in src_dist.txt";
+	Assign(out, req[moonout_enum::kOut], F<mshadow::op::identity>(data));
+	//out = data;
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    CHECK_EQ(in_data.size(), 2);
+    CHECK_EQ(out_grad.size(), 1);
+    CHECK_GE(in_grad.size(), 1);
+    CHECK_GE(req.size(), 1);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+	Tensor<xpu, 2, DType> label = in_data[moonout_enum::kLabel].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> out = out_data[moonout_enum::kOut].FlatTo2D<xpu, DType>(s);
+    Tensor<xpu, 2, DType> grad = in_grad[moonout_enum::kData].FlatTo2D<xpu, DType>(s);
+	MoonBackward(grad, out, label, src_dist_);
+    grad *= DType(param_.grad_scale/label.size(1));  // normalize the gradient by the number of labels
+  }
+
+ private:
+  MoonOutputParam param_;
+  std::vector<float> src_dist_;
+};  // class MoonOutputOp
+
+// Declare Factory function, used for dispatch specialization
+template<typename xpu>
+Operator* CreateOp(MoonOutputParam param, int dtype);
+
+#if DMLC_USE_CXX11
+class MoonOutputProp : public OperatorProperty {
+ public:
+  std::vector<std::string> ListArguments() const override {
+    return {"data", "label"};
+  }
+
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    param_.Init(kwargs);
+  }
+
+  std::map<std::string, std::string> GetParams() const override {
+    return param_.__DICT__();
+  }
+
+  bool InferShape(std::vector<TShape> *in_shape,
+                  std::vector<TShape> *out_shape,
+                  std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
+    CHECK_EQ(in_shape->size(), 2) << "Input:[data, label]";
+    const TShape &dshape = in_shape->at(0);
+    if (dshape.ndim() == 0) return false;
+    SHAPE_ASSIGN_CHECK(*in_shape, moonout_enum::kLabel,
+                       Shape2(dshape[0], dshape[1]));
+    out_shape->clear();
+    out_shape->push_back(dshape);
+    return true;
+  }
+
+  bool InferType(std::vector<int> *in_type,
+                 std::vector<int> *out_type,
+                 std::vector<int> *aux_type) const override {
+    CHECK_GE(in_type->size(), 1);
+    int dtype = (*in_type)[0];
+    CHECK_NE(dtype, -1) << "First input must have specified type";
+    for (index_t i = 0; i < in_type->size(); ++i) {
+      if ((*in_type)[i] == -1) {
+        (*in_type)[i] = dtype;
+      } else {
+        CHECK_EQ((*in_type)[i], dtype) << "This layer requires uniform type. "
+                                       << "Expected " << dtype << " v.s. given "
+                                       << (*in_type)[i] << " at " << ListArguments()[i];
+      }
+    }
+    out_type->clear();
+    out_type->push_back(dtype);
+    return true;
+  }
+
+  OperatorProperty* Copy() const override {
+    auto ptr = new MoonOutputProp();
+    ptr->param_ = param_;
+    return ptr;
+  }
+
+  std::string TypeString() const override {
+    return "MoonOutput";
+  }
+
+  std::vector<int> DeclareBackwardDependency(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data) const override {
+    return {in_data[moonout_enum::kLabel], out_data[moonout_enum::kOut]};
+  }
+
+  std::vector<std::pair<int, void*> > BackwardInplaceOption(
+    const std::vector<int> &out_grad,
+    const std::vector<int> &in_data,
+    const std::vector<int> &out_data,
+    const std::vector<void*> &in_grad) const override {
+    return {{out_data[moonout_enum::kOut], in_grad[moonout_enum::kData]}};
+  }
+
+  std::vector<std::pair<int, void*> > ForwardInplaceOption(
+    const std::vector<int> &in_data,
+    const std::vector<void*> &out_data) const override {
+    return {{in_data[moonout_enum::kData], out_data[moonout_enum::kOut]}};
+  }
+
+  Operator* CreateOperator(Context ctx) const override {
+    LOG(FATAL) << "Not Implemented.";
+    return NULL;
+  }
+
+  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                             std::vector<int> *in_type) const override;
+
+ protected:
+  MoonOutputParam param_;
+};  // class MoonOutputProp
+
+class DeprecatedMoonProp : public MoonOutputProp {
+ public:
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+    MoonOutputProp::param_.Init(kwargs);
+  }
+
+  std::string TypeString() const override {
+    return "Moon";
+  }
+};
+#endif  // DMLC_USE_CXX11
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_MOON_OUTPUT_INL_H_
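
The operator property above fixes the interface: the label must have the same (batch, num_attributes) shape as the data, and src_dist_path points at a text file with one float per attribute. A hypothetical Python usage sketch, assuming the operator is compiled into libmxnet (the symbol name and parameters follow the registration in moon_output.cc below; the attribute count of 40 is only an example):

    import mxnet as mx

    data = mx.symbol.Variable('data')
    label = mx.symbol.Variable('label')  # shape (batch, 40), values +1 / -1
    fc = mx.symbol.FullyConnected(data=data, num_hidden=40, name='fc_attr')
    out = mx.symbol.MoonOutput(data=fc, label=label, grad_scale=1.0,
                               src_dist_path='src_dist.txt', name='moon')
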
diff --git a/src/operator/moon_output.cc b/src/operator/moon_output.cc
new file mode 100644
index 0000000000..a73b799a59
--- /dev/null
+++ b/src/operator/moon_output.cc
@@ -0,0 +1,75 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file moon_output.cc
+ * \brief
+ * \author Wei Wu
+*/
+#include <vector>
+#include <math.h>
+#include "./moon_output-inl.h"
+
+namespace mshadow {
+template<typename Dtype>
+inline void MoonBackward(const Tensor<cpu, 2, Dtype> &grad_data,
+	const Tensor<cpu, 2, Dtype> &out_data,
+	const Tensor<cpu, 2, Dtype> &input_label,
+	const std::vector<float> &src_dist) {
+	const Dtype *data = out_data.dptr_;
+	const Dtype *label = input_label.dptr_;
+	Dtype *grad = grad_data.dptr_;
+	Dtype weight = 0.0;
+	for (index_t n = 0; n < out_data.size(0); ++n) {
+		for (index_t c = 0; c < out_data.size(1); ++c) {
+			const int index = c * out_data.size(0) + n;
+			if (1 == int(label[index]) && src_dist[c] > 0.5) {
+				weight = (1 - src_dist[c]) / src_dist[c];
+			}
+			else if (-1 == int(label[index]) && src_dist[c] < 0.5) {
+				weight = src_dist[c] / (1 - src_dist[c]);
+			}
+			else {
+				weight = 1.0;
+			}
+			grad[index] = 2.0 * (data[index] - label[index]) * weight;
+		}
+	}
+}
+} // namespace mshadow
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>(MoonOutputParam param, int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new MoonOutputOp<cpu, DType>(param);
+  })
+  return op;
+}
+
+// DO_BIND_DISPATCH comes from operator_common.h
+Operator *MoonOutputProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                     std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+}
+
+DMLC_REGISTER_PARAMETER(MoonOutputParam);
+
+MXNET_REGISTER_OP_PROPERTY(MoonOutput, MoonOutputProp)
+.describe("Perform a MOON transformation on the input and backpropagate a class-balanced squared-error gradient.")
+.add_argument("data", "Symbol", "Input data to moon.")
+.add_argument("label", "Symbol", "Label data.")
+.add_arguments(MoonOutputParam::__FIELDS__());
+
+MXNET_REGISTER_OP_PROPERTY(Moon, DeprecatedMoonProp)
+.describe("DEPRECATED: Perform a MOON transformation on the input. Please use MoonOutput.")
+.add_argument("data", "Symbol", "Input data to moon.")
+.add_arguments(MoonOutputParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
+
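
The MoonBackward loop above is the gradient of a per-attribute weighted squared error. Writing x_c for the prediction, y_c in {+1, -1} for the label, and S_c for the corresponding entry of src_dist.txt, the implemented loss and gradient are, in LaTeX notation:

    \ell_c = w_c\,(x_c - y_c)^2, \qquad
    \frac{\partial \ell_c}{\partial x_c} = 2\,w_c\,(x_c - y_c), \qquad
    w_c = \begin{cases}
      (1 - S_c)/S_c & \text{if } y_c = +1 \text{ and } S_c > 0.5, \\
      S_c/(1 - S_c) & \text{if } y_c = -1 \text{ and } S_c < 0.5, \\
      1 & \text{otherwise.}
    \end{cases}

The Backward method in moon_output-inl.h then scales the whole gradient by grad_scale / num_labels.
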
diff --git a/src/operator/moon_output.cu b/src/operator/moon_output.cu
new file mode 100644
index 0000000000..b540372d03
--- /dev/null
+++ b/src/operator/moon_output.cu
@@ -0,0 +1,84 @@
+/*!
+ * Copyright (c) 2016 by Contributors
+ * \file moon_output.cu
+ * \brief
+ * \author Wei Wu
+*/
+#include <vector>
+#include <mshadow/tensor.h>
+#include <mshadow/cuda/tensor_gpu-inl.cuh>
+#include "./moon_output-inl.h"
+
+#define CU2DBLOCK_X 32
+#define CU2DBLOCK_Y 32
+
+namespace mshadow {
+namespace cuda{
+template<typename DType>
+__global__ void MoonBackwardKernel(DType *grad, const DType *data, const DType *label, const float *src_dist,
+	const int cols, const int rows, const int stride) {
+	int i = blockIdx.x * blockDim.x + threadIdx.x;
+	int j = blockIdx.y * blockDim.y + threadIdx.y;
+	int num_threads_x = blockDim.x * gridDim.x;
+	int num_threads_y = blockDim.y * gridDim.y;
+	DType weight = 0.0;
+	for (int index = 0; i < cols && j < rows; i += num_threads_x, j += num_threads_y) {
+		index = i * stride + j;
+		if (1 == int(label[index]) && src_dist[i] > 0.5) {
+			weight = (1 - src_dist[i]) / src_dist[i];
+		}
+		else if (-1 == int(label[index]) && src_dist[i] < 0.5) {
+			weight = src_dist[i] / (1 - src_dist[i]);
+		}
+		else {
+			weight = 1.0;
+		}
+		grad[index] = 2.0 * (data[index] - label[index]) * weight;
+	}
+}
+
+template<typename DType>
+inline void MoonBackward(const Tensor<gpu, 2, DType> &grad_data,
+							const Tensor<gpu, 2, DType> &out_data,
+							const Tensor<gpu, 2, DType> &input_label,
+							const std::vector<float> &src_dist) {
+	const DType *data = out_data.dptr_;
+	const DType *label = input_label.dptr_;
+	DType *grad = grad_data.dptr_;
+	dim3 threads_per_block(CU2DBLOCK_X, CU2DBLOCK_Y);
+	dim3 num_blocks((out_data.size(1) + threads_per_block.x - 1) / threads_per_block.x, 
+					(out_data.size(0) + threads_per_block.y - 1) / threads_per_block.y);
+	CheckLaunchParam(num_blocks, threads_per_block, "Moon Backward");
+	cudaStream_t stream = Stream<gpu>::GetStream(grad_data.stream_);
+	// maybe there is a better way to construct a Tensor<gpu> from a std::vector
+	float *dist;
+	cudaMalloc((void**)&dist, src_dist.size()*sizeof(float));
+	cudaMemcpyAsync(dist, &src_dist[0], src_dist.size()*sizeof(float), cudaMemcpyHostToDevice, stream);
+	MoonBackwardKernel<DType><<<num_blocks, threads_per_block, 0, stream>>>(grad, data, label, dist,
+		out_data.size(1), out_data.size(0), out_data.size(0));
+}
+} // namespace cuda
+
+template<typename DType>
+inline void MoonBackward(const Tensor<gpu, 2, DType> &grad_data,
+	const Tensor<gpu, 2, DType> &out_data,
+	const Tensor<gpu, 2, DType> &input_label,
+	const std::vector<float> &src_dist) {
+	cuda::MoonBackward(grad_data, out_data, input_label, src_dist);
+}
+} // namespace mshadow
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<gpu>(MoonOutputParam param, int dtype) {
+  Operator *op = NULL;
+  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+    op = new MoonOutputOp<gpu, DType>(param);
+  })
+  return op;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
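
To tie the pieces together, a hedged end-to-end sketch: write one positive-class fraction per attribute to src_dist.txt (the MoonOutputOp constructor reads whitespace-separated floats), then train with the multi_binary_acc metric added in metric.py. The label array and the commented-out module calls are placeholders, not part of the PR:

    import numpy as np

    # src_dist.txt: one float per attribute, the fraction of +1 labels in the
    # training set (whitespace-separated, as read by MoonOutputOp's constructor).
    train_labels = np.sign(np.random.randn(1000, 40))  # placeholder +1 / -1 labels
    src_dist = (train_labels > 0).mean(axis=0)
    np.savetxt('src_dist.txt', src_dist[np.newaxis, :], fmt='%.6f')

    # With a symbol ending in MoonOutput as sketched earlier, training could use
    # the matching metric, e.g.:
    #   mod = mx.mod.Module(symbol=out, label_names=('label',))
    #   mod.fit(train_iter, eval_metric='multi_binary_acc', num_epoch=10)
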


 
