You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2017/12/14 19:16:48 UTC

[GitHub] piiswrong closed pull request #7781: Implement Khatri-Rao operator

piiswrong closed pull request #7781: Implement Khatri-Rao operator
URL: https://github.com/apache/incubator-mxnet/pull/7781
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/c_lapack_api.h b/src/operator/c_lapack_api.h
index 293c3f2f81..46c8b963f4 100644
--- a/src/operator/c_lapack_api.h
+++ b/src/operator/c_lapack_api.h
@@ -143,19 +143,7 @@ inline char loup(char uplo, bool invert) { return invert ? (uplo == 'U' ? 'L' :
  * \param lda leading dimension of a
  */
 template <typename xpu, typename DType>
-inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda);
-
-template <>
-inline void flip<cpu, float>(int m, int n,
-  float *b, int ldb, float *a, int lda) {
-  for (int i = 0; i < m; ++i)
-    for (int j = 0; j < n; ++j)
-      b[j * ldb + i] = a[i * lda + j];
-}
-
-template <>
-inline void flip<cpu, double>(int m, int n,
-  double *b, int ldb, double *a, int lda) {
+inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {
   for (int i = 0; i < m; ++i)
     for (int j = 0; j < n; ++j)
       b[j * ldb + i] = a[i * lda + j];
diff --git a/src/operator/contrib/krprod.cc b/src/operator/contrib/krprod.cc
new file mode 100644
index 0000000000..b5f9117ef3
--- /dev/null
+++ b/src/operator/contrib/krprod.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+/*!
+ *  \file krprod.cc
+ *  \brief Operator registration for Khatri-Rao product
+ *  \author Chris Swierczewski
+ */
+
+#include <mshadow/tensor.h>
+#include <mxnet/op_attr_types.h>
+#include <mxnet/operator_util.h>
+#include <vector>
+#include <string>
+#include <algorithm>
+#include "../mshadow_op.h"
+#include "../mxnet_op.h"
+#include "../operator_common.h"
+#include "../elemwise_op_common.h"
+#include "../../ndarray/ndarray_function.h"
+#include "krprod.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Shape inference for khatri_rao.
+ *
+ * Every input must be a 2-D matrix, and all inputs must share the same number
+ * of columns N. The output has shape (M_1 * ... * M_n, N), where M_i is the
+ * row count of the i-th input. Returns false (defer inference) while any
+ * input shape is still unknown.
+ */
+inline bool KhatriRaoShape(
+      const nnvm::NodeAttrs& attrs,
+      std::vector<TShape> *in_attrs,
+      std::vector<TShape> *out_attrs) {
+  CHECK_EQ(out_attrs->size(), 1);
+  CHECK_GE(in_attrs->size(), 1);
+
+  int num_columns = -1;  // taken from the first input, checked on the rest
+  int num_rows = 1;
+  for (const TShape& attr_shape : (*in_attrs)) {
+    // An unknown shape has ndim 0; indexing it would read out of bounds.
+    if (attr_shape.ndim() == 0) return false;
+    CHECK_EQ(attr_shape.ndim(), 2U)
+      << "khatri_rao expects every input to be a matrix, got " << attr_shape;
+    if (num_columns < 0) num_columns = static_cast<int>(attr_shape[1]);
+    CHECK_EQ(num_columns, static_cast<int>(attr_shape[1]))
+      << "khatri_rao expects all inputs to have the same number of columns";
+    num_rows *= attr_shape[0];
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, Shape2(num_rows, num_columns));
+  return true;
+}
+
+
+/*!
+ * \brief Parameters of the khatri_rao operator.
+ *
+ * Only `num_args` (the number of input matrices) is configurable; the product
+ * is always the column-wise Khatri-Rao product.
+ */
+struct KhatriRaoParam : public dmlc::Parameter<KhatriRaoParam> {
+  int num_args;
+  DMLC_DECLARE_PARAMETER(KhatriRaoParam) {
+    DMLC_DECLARE_FIELD(num_args)
+      .set_lower_bound(1)
+      .describe("Number of input matrices.");
+  }
+};
+DMLC_REGISTER_PARAMETER(KhatriRaoParam);
+
+
+NNVM_REGISTER_OP(khatri_rao)
+.describe(R"code(Computes the Khatri-Rao product of the input matrices.
+
+Given a collection of :math:`n` input matrices,
+
+.. math::
+   A_1 \in \mathbb{R}^{M_1 \times N}, \ldots, A_n \in \mathbb{R}^{M_n \times N},
+
+the (column-wise) Khatri-Rao product is defined as the matrix,
+
+.. math::
+   X = A_1 \odot \cdots \odot A_n \in \mathbb{R}^{(M_1 \cdots M_n) \times N},
+
+where the :math:`k`-th column of :math:`X` is the Kronecker product
+:math:`{A_1}_k \otimes \cdots \otimes {A_n}_k`, with :math:`{A_i}_k` denoting
+the :math:`k`-th column of the :math:`i`-th matrix. All inputs must be 2-D and
+have the same number of columns :math:`N`.
+
+Example::
+
+  >>> A = mx.nd.array([[1, -1],
+  ...                  [2, -3]])
+  >>> B = mx.nd.array([[1, 4],
+  ...                  [2, 5],
+  ...                  [3, 6]])
+  >>> C = mx.nd.khatri_rao(A, B)
+  >>> print(C.asnumpy())
+  [[  1.  -4.]
+   [  2.  -5.]
+   [  3.  -6.]
+   [  2. -12.]
+   [  4. -15.]
+   [  6. -18.]]
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<KhatriRaoParam>)
+// The input count is variadic and comes from the `num_args` parameter.
+.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
+    uint32_t ret = dmlc::get<KhatriRaoParam>(attrs.parsed).num_args;
+    return ret;
+  })
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", KhatriRaoShape)
+// All inputs and the output share one dtype.
+.set_attr<nnvm::FInferType>("FInferType",
+  [](const nnvm::NodeAttrs& attrs,
+     std::vector<int> *in_attrs,
+     std::vector<int> *out_attrs) {
+    return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
+      attrs, in_attrs, out_attrs, -1);
+  })
+// Inputs are named arg0, arg1, ..., arg{num_args-1}.
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    uint32_t num_args = dmlc::get<KhatriRaoParam>(attrs.parsed).num_args;
+    std::vector<std::string> ret;
+    for (uint32_t i = 0; i < num_args; ++i)
+      ret.push_back(std::string("arg") + std::to_string(i));
+    return ret;
+  })
+.set_attr<FCompute>("FCompute<cpu>", KhatriRaoCompute<cpu>)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("args", "NDArray-or-Symbol[]", "Positional input matrices")
+.add_arguments(KhatriRaoParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/krprod.h b/src/operator/contrib/krprod.h
index 90a6179e07..4fc362e963 100644
--- a/src/operator/contrib/krprod.h
+++ b/src/operator/contrib/krprod.h
@@ -21,14 +21,17 @@
  *  Copyright (c) 2017 by Contributors
  *  \file krprod.h
  *  \brief Core function for Khatri-Rao product
- *  \author Jencir Lee
+ *  \author Jencir Lee, Chris Swierczewski
  */
 #ifndef MXNET_OPERATOR_CONTRIB_KRPROD_H_
 #define MXNET_OPERATOR_CONTRIB_KRPROD_H_
+#include <algorithm>
+#include <utility>
 #include <vector>
 #include "mshadow/tensor.h"
 #include "../c_lapack_api.h"
 
+
 namespace mxnet {
 namespace op {
 
@@ -247,6 +250,40 @@ inline void inv_khatri_rao
     LOG(FATAL) << "The linear solver in inv_prod() returns " << info;
 }
 
+
+/*!
+ * \brief Dtype-specialised worker behind KhatriRaoCompute.
+ *
+ * Views every input blob and the single output blob as 2-D tensors and hands
+ * them to khatri_rao(), which writes the product into the output tensor.
+ */
+template<typename xpu, typename DType>
+inline void KhatriRaoCompute_(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &in_data,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &out_data) {
+  using namespace mxnet_op;
+  if (req[0] == kNullOp) return;
+  // khatri_rao() overwrites `out`; honouring kAddTo would need a temporary
+  // buffer, so reject it rather than silently discarding the existing values.
+  CHECK_NE(req[0], kAddTo) << "khatri_rao does not support req=kAddTo";
+
+  Stream<xpu> *stream = ctx.get_stream<xpu>();
+  Tensor<xpu, 2, DType> out = out_data[0].get<xpu, 2, DType>(stream);
+  std::vector<Tensor<xpu, 2, DType> > ts_arr;
+  ts_arr.reserve(in_data.size());
+  for (const TBlob &blob : in_data) {
+    ts_arr.push_back(blob.get<xpu, 2, DType>(stream));
+  }
+  khatri_rao(out, ts_arr);
+}
+
+
+/*!
+ * \brief FCompute entry point for khatri_rao.
+ *
+ * Requires exactly one output and dispatches to KhatriRaoCompute_ on the
+ * output's dtype.
+ */
+template<typename xpu>
+inline void KhatriRaoCompute(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  using namespace mxnet_op;
+  CHECK_EQ(outputs.size(), 1U);
+  const int out_dtype = outputs[0].type_flag_;
+  MSHADOW_TYPE_SWITCH(out_dtype, DType, {
+    KhatriRaoCompute_<xpu, DType>(attrs, ctx, inputs, req, outputs);
+  });
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/tests/python/unittest/test_contrib_krprod.py b/tests/python/unittest/test_contrib_krprod.py
new file mode 100644
index 0000000000..07c0fb843b
--- /dev/null
+++ b/tests/python/unittest/test_contrib_krprod.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+
+from __future__ import print_function
+import numpy as np
+import mxnet as mx
+
+from numpy.testing import assert_allclose
+
+def assert_mx_allclose(A, B, **kwds):
+    return assert_allclose(A.asnumpy(), B.asnumpy(), **kwds)
+
+
+def test_krprod_one_input():
+    A = mx.nd.arange(1,9).reshape((2,4))
+    out = mx.nd.khatri_rao(A)
+    assert_mx_allclose(out, A, rtol=1e-12)
+
+
+def test_krprod_two_inputs():
+    A = mx.nd.arange(1,7).reshape((3,2))
+    B = mx.nd.arange(1,3).reshape((1,2))
+    out = mx.nd.khatri_rao(A, B)
+    expected = mx.nd.array([[1,4],[3,8],[5,12]])
+    assert_mx_allclose(out, expected, rtol=1e-12)
+
+    A = mx.nd.arange(1,7).reshape((3,2))
+    B = mx.nd.arange(1,9).reshape((4,2))
+    out = mx.nd.khatri_rao(A, B)
+    expected = mx.nd.array([[1,4],[3,8],[5,12],[7,16],[3,8],[9,16],[15,24],
+                            [21,32],[5,12],[15,24],[25,36],[35,48]])
+    assert_mx_allclose(out, expected, rtol=1e-12)
+
+
+def test_krprod_three_inputs():
+    A = mx.nd.arange(1,7).reshape((3,2))
+    B = mx.nd.arange(1,3).reshape((1,2))
+    C = mx.nd.arange(1,5).reshape((2,2))
+    out = mx.nd.khatri_rao(A, B, C)
+    expected = mx.nd.array([[1,8],[3,16],[3,16],[9,32],[5,24],[15,48]])
+    assert_mx_allclose(out, expected, rtol=1e-12)
+
+    out_AB = mx.nd.khatri_rao(A, B)
+    out = mx.nd.khatri_rao(out_AB, C)
+    assert_mx_allclose(out, expected, rtol=1e-12)
+
+    out_BC = mx.nd.khatri_rao(B, C)
+    out = mx.nd.khatri_rao(A, out_BC)
+    assert_mx_allclose(out, expected, rtol=1e-12)


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services