You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by la...@apache.org on 2020/07/17 23:34:08 UTC
[incubator-mxnet] 01/01: Revert "Add qr backward for wide matrices with m < n (#18197)"
This is an automated email from the ASF dual-hosted git repository.
lausen pushed a commit to branch revert-18197-qr_back_two
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit bd52fb75af3406e7618440fc39d44c7ceca62bf9
Author: Leonard Lausen <le...@lausen.nl>
AuthorDate: Fri Jul 17 16:33:24 2020 -0700
Revert "Add qr backward for wide matrices with m < n (#18197)"
This reverts commit 60d067288842b38a2a485e42e874f1551cba248c.
---
src/operator/numpy/linalg/np_qr-inl.h | 185 ++++++---------------------------
tests/python/unittest/test_numpy_op.py | 147 +++++++-------------------
2 files changed, 71 insertions(+), 261 deletions(-)
diff --git a/src/operator/numpy/linalg/np_qr-inl.h b/src/operator/numpy/linalg/np_qr-inl.h
index c204520..0f332e4 100644
--- a/src/operator/numpy/linalg/np_qr-inl.h
+++ b/src/operator/numpy/linalg/np_qr-inl.h
@@ -483,53 +483,19 @@ struct assign_helper {
}
};
-// backprop helper to get y, v
-struct QrBackHelper_G1 {
- template<typename DType>
- MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data,
- const int ldin, DType *out_data, const int ldout) {
- const int offin(k * m * ldin);
- const int offout(k * m * ldout);
- for (index_t i = 0; i < m; ++i) {
- for (index_t j = 0; j < n - m; ++j) {
- out_data[offout + i * ldout + j] = in_data[offin + m + i * ldin + j];
- }
- }
- }
-};
-
-// backprop helper to get da from dx, dy
-struct QrBackHelper_G2 {
- template<typename DType>
- MSHADOW_XINLINE static void Map(const int k, const int m, const int n, const DType *in_data_x,
- const int ldinx, const DType *in_data_y, const int ldiny,
- DType *out_data, const int ldout) {
- const int offiny(k * m * ldiny);
- const int offinx(k * m * ldinx);
- const int offout(k * m * ldout);
- for (index_t i = 0; i < m; ++i) {
- for (index_t j = 0; j < n - m; ++j) {
- out_data[offout + m + i * ldout + j] = in_data_y[offiny + i * ldiny + j];
- }
- for (index_t j = 0; j < m; ++j) {
- out_data[offout + i * ldout + j] = in_data_x[offinx + i * ldinx + j];
- }
- }
- }
-};
-
-// Reference https://journals.aps.org/prx/pdf/10.1103/PhysRevX.9.031041
struct qr_backward {
template<typename xpu, typename DType>
static void op(const Tensor<xpu, 3, DType>& dA,
const Tensor<xpu, 3, DType>& dQ,
const Tensor<xpu, 3, DType>& dR,
+ const Tensor<xpu, 3, DType>& A,
const Tensor<xpu, 3, DType>& Q,
const Tensor<xpu, 3, DType>& R,
const Tensor<xpu, 3, DType>& M,
const OpContext& ctx, const nnvm::NodeAttrs& attrs) {
- // Implements da = [dq + q@copyltu(M))]@r**(-T)
+ // Implements case m >= n; da = [dq + q@copyltu(M))]@r**(-T)
// Where M = r@(dr**T) - (dq**T)@q
+ // Reference: https://arxiv.org/abs/1710.08717
Stream<xpu> *s = ctx.get_stream<xpu>();
if (dQ.dptr_ != dA.dptr_) Copy(dA, dQ, s);
// M = R@dR_T
@@ -548,30 +514,15 @@ struct qr_backward {
template<typename xpu>
size_t QrBackwardWorkspaceSize(const TBlob& a,
- const TBlob& q,
const TBlob& r,
const TBlob& grad_a) {
- const mxnet::TShape& a_shape = a.shape_;
- const int a_ndim = a_shape.ndim();
- const int n = a.size(a_ndim - 1);
- const int m = a.size(a_ndim - 2);
-
if (0U == a.Size()) { return 0U; }
MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
size_t work_space_size = 0;
+ // for grad a and M
work_space_size += a.Size();
- if (m >= n) {
- work_space_size += r.Size();
- } else {
- const mxnet::TShape& q_shape = q.shape_;
- mxnet::TShape v_shape(q_shape);
- v_shape[a_ndim - 1] = n - m;
- // allocate space for: m, u, dq_prime, du, dx (shaped like Q)
- work_space_size += 5 * q.Size();
- // allocate space for: y, dv (shaped like V, the partition of R)
- work_space_size += 2 * v_shape.Size();
- }
+ work_space_size += r.Size();
return work_space_size * sizeof(DType);
});
LOG(FATAL) << "InternalError: cannot reach here";
@@ -591,10 +542,8 @@ void QrBackwardImpl(const TBlob& grad_a,
const nnvm::NodeAttrs& attrs) {
Stream<xpu> *s = ctx.get_stream<xpu>();
const mxnet::TShape& a_shape = a.shape_;
- const mxnet::TShape& q_shape = q.shape_;
const mxnet::TShape& r_shape = r.shape_;
const int a_ndim = a_shape.ndim();
- const int m = a.size(a_ndim - 2);
const int n = a.size(a_ndim - 1);
if (kNullOp == req[0]) { return; }
@@ -602,105 +551,27 @@ void QrBackwardImpl(const TBlob& grad_a,
if (0U == a_shape.Size()) { return; }
MSHADOW_SGL_DBL_TYPE_SWITCH(grad_a.type_flag_, DType, {
- // common for all shapes (m, n)
- DType *grad_a_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+ // case m >= n; Q of same shape with A and R is (n, n)
+ DType *m_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+ DType *grad_a_ptr = m_ptr + r_shape.Size();
+ TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
TBlob grad_a_data(grad_a_ptr, a_shape, xpu::kDevMask);
- if (m >= n) {
- // Q of same shape with A (m, n) and R is (n, n)
- DType *m_ptr = grad_a_ptr + a_shape.Size();
- TBlob temp_m(m_ptr, r_shape, xpu::kDevMask);
- // dR_T
- mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
- s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
- qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
- grad_q.FlatToKD<xpu, 3, DType>(s),
- grad_r.FlatToKD<xpu, 3, DType>(s),
- q.FlatToKD<xpu, 3, DType>(s),
- r.FlatToKD<xpu, 3, DType>(s),
- temp_m.FlatToKD<xpu, 3, DType>(s),
- ctx, attrs);
- } else {
- // R is same shape with A (m, n) and Q is (m, m)
- // Partition A = (X | Y); R = (U | V)
- // X and U are (m, m); Y and V are (m, n - m)
- mxnet::TShape v_shape(q_shape);
- v_shape[a_ndim - 1] = n - m;
-
- DType *m_ptr = grad_a_ptr + a_shape.Size();
- DType *u_ptr = m_ptr + q_shape.Size();
- DType *dq_prime_ptr = u_ptr + q_shape.Size();
- DType *dv_ptr = dq_prime_ptr + q_shape.Size();
- DType *y_ptr = dv_ptr + v_shape.Size();
- DType *du_ptr = y_ptr + v_shape.Size();
- DType *dx_ptr = du_ptr + q_shape.Size();
-
- TBlob temp_m(m_ptr, q_shape, xpu::kDevMask);
- TBlob u_data(u_ptr, q_shape, xpu::kDevMask);
- TBlob dq_prime_data(dq_prime_ptr, q_shape, xpu::kDevMask);
- TBlob dv_data(dv_ptr, v_shape, xpu::kDevMask);
- TBlob y_data(y_ptr, v_shape, xpu::kDevMask);
- TBlob du_data(du_ptr, q_shape, xpu::kDevMask);
- TBlob dx_data(dx_ptr, q_shape, xpu::kDevMask);
-
- Tensor<xpu, 3, DType> R = r.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dR = grad_r.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> Q = q.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dQ = grad_q.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dQ_prime = dq_prime_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> A = a.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dA = grad_a_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> U = u_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dU = du_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dV = dv_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> Y = y_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> dX = dx_data.FlatToKD<xpu, 3, DType>(s);
- Tensor<xpu, 3, DType> M = temp_m.FlatToKD<xpu, 3, DType>(s);
-
- // U
- for (index_t i = 0; i < R.size(0); ++i) {
- const Tensor<xpu, 2, DType>& Ri = R[i];
- const Tensor<xpu, 2, DType>& Ui = U[i];
- Tensor<xpu, 2, DType> Um(Ri.dptr_, Shape2(m, m), Ri.stride_, s);
- Copy(Ui, Um, s);
- }
- // dU
- for (index_t i = 0; i < dR.size(0); ++i) {
- const Tensor<xpu, 2, DType>& dRi = dR[i];
- const Tensor<xpu, 2, DType>& dUi = dU[i];
- Tensor<xpu, 2, DType> dUm(dRi.dptr_, Shape2(m, m), dRi.stride_, s);
- Copy(dUi, dUm, s);
- }
- // Y
- mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
- s, A.size(0), m, n, A.dptr_, A.stride_, Y.dptr_, Y.stride_);
- // dV
- mxnet_op::Kernel<QrBackHelper_G1, xpu>::Launch(
- s, dR.size(0), m, n, dR.dptr_, dR.stride_, dV.dptr_, dV.stride_);
- // store dU_T in M
- mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
- s, q_shape.Size(), dU.dptr_, m_ptr, m, m, m * m);
- // dq_prime = dQ
- Copy(dQ_prime, dQ, s);
- // dq_prime = dQ+Y@dV.T
- gemm::op(Y, dV, dQ_prime, DType(1.0), DType(1.0), false, true, s);
- // dX = op call
- qr_backward::op(dX,
- dQ_prime,
- dU,
- Q,
- U,
- M,
- ctx, attrs);
- // dY = Q@dV; reuse Y memory for dY
- gemm::op(Q, dV, Y, DType(1.0), DType(0.0), false, false, s);
- // copy dX and dY to dA
- mxnet_op::Kernel<QrBackHelper_G2, xpu>::Launch(
- s, dA.size(0), m, n, dX.dptr_, dX.stride_, Y.dptr_, Y.stride_, dA.dptr_, dA.stride_);
- }
- // common for all shapes
+ // dR_T
+ mxnet_op::Kernel<QrTypeTransposeHelper, xpu>::Launch(
+ s, r_shape.Size(), grad_r.dptr<DType>(), m_ptr, n, n, n * n);
+
+ qr_backward::op(grad_a_data.FlatToKD<xpu, 3, DType>(s),
+ grad_q.FlatToKD<xpu, 3, DType>(s),
+ grad_r.FlatToKD<xpu, 3, DType>(s),
+ a.FlatToKD<xpu, 3, DType>(s),
+ q.FlatToKD<xpu, 3, DType>(s),
+ r.FlatToKD<xpu, 3, DType>(s),
+ temp_m.FlatToKD<xpu, 3, DType>(s),
+ ctx, attrs);
+
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
- mxnet_op::Kernel<assign_helper<req_type>, xpu>::Launch(
- s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>());
+ mxnet_op::Kernel<assign_helper<req_type>, xpu>::Launch(
+ s, a_shape.Size(), grad_a_data.dptr<DType>(), grad_a.dptr<DType>());
});
});
}
@@ -723,8 +594,14 @@ void NumpyLaQrBackward(const nnvm::NodeAttrs& attrs,
const TBlob& q = inputs[3];
const TBlob& r = inputs[4];
const TBlob& grad_a = outputs[0];
+ const int a_ndim = a.shape_.ndim();
+ const int n = a.size(a_ndim - 1);
+ const int m = a.size(a_ndim - 2);
+
+ CHECK_LE(n, m)
+ << "QrBackward not implemented when ncols > nrows";
- size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, q, r, grad_a);
+ size_t workspace_size = QrBackwardWorkspaceSize<xpu>(a, r, grad_a);
Tensor<xpu, 1, char> workspace = ctx.requested[0]
.get_space_typed<xpu, 1, char>(Shape1(workspace_size), ctx.get_stream<xpu>());
QrBackwardImpl<xpu>(grad_a, grad_q, grad_r, a, q, r, req, workspace, ctx, attrs);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 7197949..88ad77f 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -5761,102 +5761,42 @@ def test_np_linalg_qr():
def hybrid_forward(self, F, data):
return F.np.linalg.qr(data)
- def get_expected_grad(a, q, r, dq, dr):
- # all shapes (..., m, n)
- # allow feeding different dq and dr values
+ def get_expected_grad(a, q, r):
if 0 in r.shape:
return r
- def _copyltu(M):
- eye = _np.array([_np.eye(M.shape[-1]) for i in range(M.shape[0])])
- lower = _np.tril(M) - eye * M
- lower_mask = _np.tril(_np.ones_like(M))
- ret = lower_mask * M + lower.swapaxes(-1, -2)
- return ret
- def _case_m_ge_n(a, q, r, dq, dr):
- dq_t = dq.swapaxes(-1, -2)
- dr_t = dr.swapaxes(-1, -2)
- r_inv = _np.linalg.inv(r)
- r_inv_t = r_inv.swapaxes(-1, -2)
- r_t = r.swapaxes(-1, -2)
- # Get M
- M = _np.matmul(r, dr_t) - _np.matmul(dq_t, q)
- da = _np.matmul(dq + _np.matmul(q, _copyltu(M)), r_inv_t)
- return da
- m, n = a.shape[-2], a.shape[-1]
- x = a[..., :, :m]
- x_shape = x.shape
- y = a[..., :, m:]
- y_shape = y.shape
- u = r[..., :, :m]
- v = r[..., :, m:]
- dv = dr[..., :, m:]
- du = dr[..., :, :m]
- q = q.reshape(-1, q.shape[-2], q.shape[-1])
- u = u.reshape(-1, u.shape[-2], u.shape[-1])
- dq = dq.reshape(-1, q.shape[-2], q.shape[-1])
- du = du.reshape(-1, du.shape[-2], du.shape[-1])
- if m >= n:
- dx = _case_m_ge_n(x, q, u, dq, du).reshape(x_shape)
- return dx
- else:
- dv = dv.reshape(-1, dv.shape[-2], dv.shape[-1])
- y = y.reshape(-1, y.shape[-2], y.shape[-1])
- dy = _np.matmul(q, dv).reshape(y_shape)
- dq_prime = dq + _np.matmul(y, dv.swapaxes(-1, -2))
- dx = _case_m_ge_n(x, q, u, dq_prime, du).reshape(x_shape)
- da = _np.concatenate([dx, dy], axis=-1)
- return da
-
- def _analytical_jacobian(x, dy, Q, R, Q_, R_, k):
- x_size = _np.prod(x.shape)
- dy_size = _np.prod(dy.shape)
- # jacobian has data_np size number of rows and dQ or dR size number of columns
- jacobian = _np.zeros((x_size, dy_size))
- # dQ and dR have all elements equal to zero to begin with
- dy_data = _np.zeros(dy.shape)
- dy_data_flat = dy_data.ravel()
- for col in range(dy_size):
- # we only feed dQ or dR with 1 element changed to 1 at a time
- dy_data_flat[col] = 1
- ret_ = dy_data_flat.reshape(dy.shape)
- if k == 0:
- # k is 0 when dy is dQ
- jacobian[:, col] = get_expected_grad(x, dy, R, ret_, R_).ravel()
- else:
- # k is 1 when dy is dR
- jacobian[:, col] = get_expected_grad(x, Q, dy, Q_, ret_).ravel()
- dy_data_flat[col] = 0
- return jacobian
-
- def _numerical_jacobian(x, y, delta, k, dtype):
- # compute central differences
- x_size = _np.prod(x.shape)
- y_size = _np.prod(y.shape)
- scale = _np.asarray(2 * delta)[()]
- # jacobian has data_np size number of rows and Q or R size number of columns
- jacobian_num = _np.zeros((x_size, y_size))
- for row in range(x_size):
- x_pos = x.copy()
- x_neg = x.copy()
- # add delta to one element of data_np at a time
- x_pos.ravel().view(dtype)[row] += delta # one element in x is added delta
- # get qr decomposition of new input with one changed element
- ret_pos = np.linalg.qr(np.array(x_pos))[k]
- # subtract delta from input data_np one element at a time
- x_neg.ravel().view(dtype)[row] -= delta
- # get qr decomposition of new input with one changed element
- ret_neg = np.linalg.qr(np.array(x_neg))[k]
- # get central differences
- diff = (ret_pos - ret_neg) / scale
- jacobian_num[row, :] = diff.asnumpy().ravel().view(dtype)
- return jacobian_num
+ def copyltu(M):
+ # shape of M is [batch, m, m]
+ eye = _np.array([_np.eye(M.shape[-1]) for i in range(M.shape[0])])
+ lower = _np.tril(M) - eye * M
+ lower_mask = _np.tril(_np.ones_like(M))
+ ret = lower_mask * M + lower.swapaxes(-1, -2)
+ return ret
+ shape_r = r.shape
+ shape_q = q.shape
+ shape_a = a.shape
+ r = r.reshape(-1, shape_r[-2], shape_r[-1])
+ q = q.reshape(-1, shape_q[-2], shape_q[-1])
+ dq = _np.ones_like(q)
+ dr = _np.ones_like(r)
+ dq_t = dq.swapaxes(-1, -2)
+ dr_t = dr.swapaxes(-1, -2)
+ r_inv = _np.linalg.inv(r)
+ r_inv_t = r_inv.swapaxes(-1, -2)
+ r_t = r.swapaxes(-1, -2)
+ # Get M
+ M = _np.matmul(r, dr_t) - _np.matmul(dq_t, q)
+ da = _np.matmul(dq + _np.matmul(q, copyltu(M)), r_inv_t)
+ return da.reshape(a.shape)
def well_conditioned_rectang_matrix_2D(shape, max_cond=4):
m, n = shape[-2], shape[-1]
while 1:
- Q1, R1 = _np.linalg.qr(_np.random.uniform(-10, 10, (m, m)))
- D = _np.eye(m, n)
- Q2, R2 = _np.linalg.qr(_np.random.uniform(-10, 10, (n, n)))
+ M1 = _np.random.uniform(-10, 10, (m, n))
+ Q1, R1 = _np.linalg.qr(M1)
+ s = _np.ones(n)
+ D = _np.diag(s)
+ M2 =_np.random.uniform(-10, 10, (n, n))
+ Q2, R2 = _np.linalg.qr(M2)
a = _np.matmul(_np.matmul(Q1, D), _np.swapaxes(Q2, -1, -2))
if (_np.linalg.cond(a, 2) < max_cond):
return a
@@ -5896,6 +5836,7 @@ def test_np_linalg_qr():
(3, 3),
(5, 5),
(8, 8),
+ (4, 5),
(4, 6),
(5, 4),
(6, 5),
@@ -5909,16 +5850,19 @@ def test_np_linalg_qr():
(4, 2, 2, 1),
(2, 3, 4, 3)
]
- dtypes = ['float64', 'float32']
+ dtypes = ['float32', 'float64']
for hybridize, shape, dtype in itertools.product([False, True], shapes, dtypes):
rtol = atol = 0.01
test_qr = TestQR()
if hybridize:
test_qr.hybridize()
+
if 0 in shape:
data_np = _np.ones(shape)
- else:
+ elif shape[-2] >= shape[-1]:
data_np = well_conditioned_rectang_matrix_nD(shape, max_cond=4)
+ else:
+ data_np = _np.random.uniform(-10.0, 10.0, shape)
data_np = _np.array(data_np, dtype=dtype)
data = np.array(data_np, dtype=dtype)
@@ -5928,24 +5872,13 @@ def test_np_linalg_qr():
Q, R = ret[0], ret[1]
check_qr(Q, R, data_np)
- if 0 not in R.shape:
+ # Only shapes m >= n have gradient
+ if 0 not in R.shape and shape[-2] >= shape[-1]:
assert data.grad.shape == data_np.shape
- backward_expected = get_expected_grad(data_np, Q.asnumpy(), R.asnumpy(),
- _np.ones(Q.shape), _np.ones(R.shape))
+ backward_expected = get_expected_grad(data_np, Q.asnumpy(), R.asnumpy())
mx.autograd.backward(ret)
assert_almost_equal(data.grad.asnumpy(), backward_expected, rtol=rtol, atol=atol)
- # for a few cases, check that the analytical jacobian is equal to
- # numerical jacobian computed via central differences
- # restrict this check to float64 for numerical precision
- if dtype == 'float64' and len(shape) == 2:
- epsilon = _np.finfo(dtype).eps
- delta = 0.1 * epsilon**(1.0 / 3.0) # Optimal delta for central differences
- for k, b in enumerate(ret):
- qr_num = _numerical_jacobian(data_np, b.asnumpy(), delta, k, dtype)
- qr_a = _analytical_jacobian(x=data_np, dy=b.asnumpy(), Q=Q.asnumpy(),
- R=R.asnumpy(), Q_=_np.zeros(Q.shape),
- R_=_np.zeros(R.shape), k=k)
- assert_almost_equal(qr_num, qr_a, rtol=rtol, atol=atol)
+
# check imperative once more; mode='reduced' is default
# behavior and optional parameter in original numpy
ret = np.linalg.qr(data, mode='reduced')