You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2018/08/30 22:43:07 UTC
[3/3] systemml git commit: [SYSTEMML-445] Removed batch_norm builtin functions
[SYSTEMML-445] Removed batch_norm builtin functions
- Removed batch_norm builtin functions to exploit codegen in CP.
- Added rewrites for compiling efficient CuDNN operators.
- Added rewrites for SGD update operations.
- To simplify adding new GPU rewrites, added HopDagPatternMatcher, which allows pattern matching at the HOP level. It can be extended to other rewrites as well.
- Added GPU tests to validate the rewrites.
- Updated the DML language documentation.
Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/0f36780a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/0f36780a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/0f36780a
Branch: refs/heads/master
Commit: 0f36780a8244c6e728d37c32a79e00ed181211ad
Parents: 81419ae
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Thu Aug 30 15:40:44 2018 -0700
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Thu Aug 30 15:40:44 2018 -0700
----------------------------------------------------------------------
docs/dml-language-reference.md | 2 -
scripts/nn/layers/batch_norm2d.dml | 60 +-
scripts/nn/layers/batch_norm2d_old.dml | 200 ----
src/main/cpp/kernels/SystemML.cu | 56 +-
src/main/cpp/kernels/SystemML.ptx | 321 +++++-
src/main/java/org/apache/sysml/hops/DnnOp.java | 56 +-
.../java/org/apache/sysml/hops/FunctionOp.java | 30 +-
src/main/java/org/apache/sysml/hops/Hop.java | 8 +-
.../hops/rewrite/HopDagPatternMatcher.java | 378 +++++++
.../sysml/hops/rewrite/HopPatternRewriter.java | 72 ++
.../HopRewriteRuleWithPatternMatcher.java | 98 ++
.../sysml/hops/rewrite/HopRewriteUtils.java | 20 +
.../hops/rewrite/RewriteGPUSpecificOps.java | 1027 +++++-------------
.../org/apache/sysml/lops/DnnTransform.java | 53 +-
.../sysml/parser/BuiltinFunctionExpression.java | 61 +-
.../org/apache/sysml/parser/DMLTranslator.java | 2 -
.../org/apache/sysml/parser/Expression.java | 2 +-
.../instructions/GPUInstructionParser.java | 10 +-
.../instructions/gpu/DnnGPUInstruction.java | 526 +++++----
.../gpu/GPUDenseInputPointerFetcher.java | 111 ++
.../gpu/context/GPUMemoryManager.java | 2 +-
.../runtime/matrix/data/LibMatrixCUDA.java | 110 +-
.../runtime/matrix/data/LibMatrixCuDNN.java | 37 +-
.../apache/sysml/test/gpu/BatchNormTest.java | 47 +-
24 files changed, 1818 insertions(+), 1471 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/docs/dml-language-reference.md
----------------------------------------------------------------------
diff --git a/docs/dml-language-reference.md b/docs/dml-language-reference.md
index 924336a..cdcc529 100644
--- a/docs/dml-language-reference.md
+++ b/docs/dml-language-reference.md
@@ -1522,8 +1522,6 @@ Hence, the images are internally represented as a matrix with dimension (N, C *
| bias_add | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Adds the bias (row vector of size num_channels) to input with the given num_channels |
| bias_multiply | input, bias | [batch_size X num_channels* height_image* width_image] | [num_channels X 1] | [batch_size X num_channels* height_image* width_image] | | Multiplies the bias (row vector of size num_channels) to input with the given num_channels |
| lstm | X, W, bias, out0, c0 | [batch_size X seq_length*num_features] | [num_features+hidden_size X 4*hidden_size] | [batch_size X seq_length*hidden_size] if return_sequences else [batch_size X hidden_size] | return_sequences | Perform computation for single-layer unidirectional LSTM (outputs: out, carryOut) |
-| batch_norm2d | input | [batch_size X num_channels* height_image* width_image] | | [batch_size X num_channels* height_image* width_image] | scale, shift, exponentialMovingAverage_Mean, exponentialMovingAverage_Variance, mode, epsilon, momentum | Performs batch normalization operation (outputs: updated exponential moving average mean and variance, cache of the batch mean and variance) |
-| batch_norm2d_backward | input, dout | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | [batch_size X num_channels* height_image* width_image] | scale, epsilon, cache_mean (from forward), cache_inv_var (from forward) | Computed backpropagation error for batch normalization operation |
Note: the builtin functions `batch_norm2d` and `batch_norm2d_backward` are deprecated and will be removed in the next release. The `lstm` builtin function is in experimental phase and is only supported for the GPU backend.
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d.dml b/scripts/nn/layers/batch_norm2d.dml
index 2a98857..c68f23d 100644
--- a/scripts/nn/layers/batch_norm2d.dml
+++ b/scripts/nn/layers/batch_norm2d.dml
@@ -83,8 +83,41 @@ forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
* - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
* Note: This is used for performance during training.
*/
- out = X; ema_mean_upd = ema_mean; ema_var_upd = ema_var; cache_mean = ema_mean; cache_inv_var = ema_var
- [out, ema_mean_upd, ema_var_upd, cache_mean, cache_inv_var] = batch_norm2d(X, gamma, beta, ema_mean, ema_var, mode, epsilon, mu)
+ N = nrow(X)
+
+ if (mode == 'train') {
+ # Compute channel-wise mean and variance
+ # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
+ # - mean of total group is mean of subgroup means
+ # - variance is the mean of the subgroup variances + the variance of the subgroup means
+ subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
+ subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances
+ mean = rowMeans(subgrp_means) # shape (C, 1)
+ var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1)
+ # Update moving averages
+ ema_mean_upd = mu*ema_mean + (1-mu)*mean
+ ema_var_upd = mu*ema_var + (1-mu)*var
+ }
+ else {
+ # Use moving averages of mean and variance during testing
+ mean = ema_mean
+ var = ema_var
+ ema_mean_upd = ema_mean
+ ema_var_upd = ema_var
+ }
+
+ # Save variable for backward pass
+ cache_mean = mean
+ cache_inv_var = 1/sqrt(var+epsilon)
+
+ # Normalize, shift, and scale
+ # norm = (X-mean)*(var+epsilon)^(-1/2)
+ # = (X-mean) / sqrt(var+epsilon)
+ centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
+ norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
+ # out = norm*gamma + beta
+ scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win)
+ out = bias_add(scaled, beta) # shape (N, C*Hin*Win)
}
backward = function(matrix[double] dout,
@@ -119,9 +152,27 @@ backward = function(matrix[double] dout,
* - dbeta: Gradient wrt `b`, of shape (C, 1).
*
*/
+ N = nrow(X)
+ oneByN = 1/N
+ oneByHW = 1/(Hin*Win)
+
+ mean = cache_mean
+ centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
+ norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
# Compute gradients during training
- dX = X; dgamma = gamma; dbeta = gamma;
- [dX, dgamma, dbeta] = batch_norm2d_backward(X, dout, gamma, epsilon, cache_mean, cache_inv_var)
+ dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
+ dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
+ dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
+ dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
+ C, Hin, Win) # shape (C, 1)
+ dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
+ dmean_var_branch = util::channel_sums((-2*oneByN*oneByHW) * centered, C, Hin, Win)
+ dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet
+ dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1)
+ dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
+ dX_mean_branch = (oneByN*oneByHW) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
+ dX_var_branch = (2*oneByN*oneByHW) * bias_multiply(centered, dvar)
+ dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win)
}
init = function(int C)
@@ -149,3 +200,4 @@ init = function(int C)
ema_mean = matrix(0, rows=C, cols=1)
ema_var = matrix(1, rows=C, cols=1)
}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/scripts/nn/layers/batch_norm2d_old.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm2d_old.dml b/scripts/nn/layers/batch_norm2d_old.dml
deleted file mode 100644
index 2aba2e6..0000000
--- a/scripts/nn/layers/batch_norm2d_old.dml
+++ /dev/null
@@ -1,200 +0,0 @@
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-/*
- * 2D (Spatial) Batch Normalization layer.
- */
-source("nn/util.dml") as util
-
-forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
- int C, int Hin, int Win, string mode,
- matrix[double] ema_mean, matrix[double] ema_var,
- double mu, double epsilon)
- return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
- matrix[double] cache_mean, matrix[double] cache_inv_var) {
- /*
- * Computes the forward pass for a 2D (spatial) batch normalization
- * layer. The input data has N examples, each represented as a 3D
- * volume unrolled into a single vector.
- *
- * A spatial batch normalization layer uses the per-channel sample
- * mean and per-channel uncorrected sample variance during training
- * to normalize each channel of the input data. Additionally, it
- * introduces learnable parameters (gamma, beta) to control the
- * amount of normalization.
- *
- * `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
- *
- * This implementation maintains exponential moving averages of the
- * mean and variance during training for use during testing.
- *
- * Reference:
- * - Batch Normalization: Accelerating Deep Network Training by
- * Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
- * - https://arxiv.org/abs/1502.03167
- *
- * Inputs:
- * - X: Inputs, of shape (N, C*Hin*Win).
- * - gamma: Scale parameters, of shape (C, 1).
- * - beta: Shift parameters, of shape (C, 1).
- * - C: Number of input channels (dimensionality of input depth).
- * - Hin: Input height.
- * - Win: Input width.
- * - mode: 'train' or 'test' to indicate if the model is currently
- * being trained or tested. During training, the current batch
- * mean and variance will be used to normalize the inputs, while
- * during testing, the exponential average of the mean and
- * variance over all previous batches will be used.
- * - ema_mean: Exponential moving average of the mean, of
- * shape (C, 1).
- * - ema_var: Exponential moving average of the variance, of
- * shape (C, 1).
- * - mu: Momentum value for moving averages.
- * Typical values are in the range of [0.9, 0.999].
- * - epsilon: Smoothing term to avoid divide by zero errors.
- * Typical values are in the range of [1e-5, 1e-3].
- *
- * Outputs:
- * - out: Outputs, of shape (N, C*Hin*Win).
- * - ema_mean_upd: Updated exponential moving average of the mean,
- * of shape (C, 1).
- * - ema_var_upd: Updated exponential moving average of the variance,
- * of shape (C, 1).
- * - cache_mean: Cache of the batch mean, of shape (C, 1).
- * Note: This is used for performance during training.
- * - cache_inv_var: Cache of the inverse variance, of shape (C, 1).
- * Note: This is used for performance during training.
- */
- N = nrow(X)
-
- if (mode == 'train') {
- # Compute channel-wise mean and variance
- # Since we don't have tensors, we will compute the means and variances in a piece-wise fashion.
- # - mean of total group is mean of subgroup means
- # - variance is the mean of the subgroup variances + the variance of the subgroup means
- subgrp_means = matrix(colMeans(X), rows=C, cols=Hin*Win)
- subgrp_vars = matrix(colVars(X) * ((N-1)/N), rows=C, cols=Hin*Win) # uncorrected variances
- mean = rowMeans(subgrp_means) # shape (C, 1)
- var = rowMeans(subgrp_vars) + rowVars(subgrp_means)*(((Hin*Win)-1)/(Hin*Win)) # shape (C, 1)
- # Update moving averages
- ema_mean_upd = mu*ema_mean + (1-mu)*mean
- ema_var_upd = mu*ema_var + (1-mu)*var
- }
- else {
- # Use moving averages of mean and variance during testing
- mean = ema_mean
- var = ema_var
- ema_mean_upd = ema_mean
- ema_var_upd = ema_var
- }
-
- # Save variable for backward pass
- cache_mean = mean
- cache_inv_var = 1/sqrt(var+epsilon)
-
- # Normalize, shift, and scale
- # norm = (X-mean)*(var+epsilon)^(-1/2)
- # = (X-mean) / sqrt(var+epsilon)
- centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
- norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
- # out = norm*gamma + beta
- scaled = bias_multiply(norm, gamma) # shape (N, C*Hin*Win)
- out = bias_add(scaled, beta) # shape (N, C*Hin*Win)
-}
-
-backward = function(matrix[double] dout,
- matrix[double] cache_mean, matrix[double] cache_inv_var,
- matrix[double] X, matrix[double] gamma,
- int C, int Hin, int Win, double epsilon)
- return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
- /*
- * Computes the backward pass for a 2D (spatial) batch normalization
- * layer.
- *
- * Inputs:
- * - dout: Gradient wrt `out` from upstream, of shape (N, C*Hin*Win).
- * - cache_mean: Cache of the batch mean from the forward pass, of
- * shape (C, 1). Note: This is used for performance during
- * training.
- * - cache_inv_var: Cache of the inverse variance from the forward pass,
- * of shape (C, 1). Note: This is used for performance during
- * training.
- * - X: Input data matrix to the forward pass, of
- * shape (N, C*Hin*Win).
- * - gamma: Scale parameters, of shape (C, 1).
- * - C: Number of input channels (dimensionality of input depth).
- * - Hin: Input height.
- * - Win: Input width.
- * - epsilon: Smoothing term to avoid divide by zero errors.
- * Typical values are in the range of [1e-5, 1e-3].
- *
- * Outputs:
- * - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
- * - dgamma: Gradient wrt `W`, of shape (C, 1).
- * - dbeta: Gradient wrt `b`, of shape (C, 1).
- *
- */
- N = nrow(X)
- mean = cache_mean
- centered = bias_add(X, -mean) # shape (N, C*Hin*Win)
- norm = bias_multiply(centered, cache_inv_var) # shape (N, C*Hin*Win)
- # Compute gradients during training
- dgamma = util::channel_sums(dout*norm, C, Hin, Win) # shape (C, 1)
- dbeta = util::channel_sums(dout, C, Hin, Win) # shape (C, 1)
- dnorm = bias_multiply(dout, gamma) # shape (N, C*Hin*Win)
- dvar = util::channel_sums((-1/2) * bias_multiply(centered, cache_inv_var^3) * dnorm,
- C, Hin, Win) # shape (C, 1)
- dmean_norm_branch = util::channel_sums(bias_multiply(dnorm, -cache_inv_var), C, Hin, Win)
- dmean_var_branch = util::channel_sums((-2/(N*Hin*Win)) * centered, C, Hin, Win)
- dmean_var_branch = dmean_var_branch * dvar # we can't use a function within an expression yet
- dmean = dmean_norm_branch + dmean_var_branch # shape (C, 1)
- dX_norm_branch = bias_multiply(dnorm, cache_inv_var)
- dX_mean_branch = (1/(N*Hin*Win)) * bias_add(matrix(0, rows=1, cols=C*Hin*Win), dmean)
- dX_var_branch = (2/(N*Hin*Win)) * bias_multiply(centered, dvar)
- dX = dX_norm_branch + dX_mean_branch + dX_var_branch # shape (N, C*Hin*Win)
-}
-
-init = function(int C)
- return (matrix[double] gamma, matrix[double] beta,
- matrix[double] ema_mean, matrix[double] ema_var) {
- /*
- * Initialize the parameters of this layer.
- *
- * Note: This is just a convenience function, and parameters
- * may be initialized manually if needed.
- *
- * Inputs:
- * - C: Number of input channels (dimensionality of input depth).
- *
- * Outputs:
- * - gamma: Scale parameters, of shape (C, 1).
- * - beta: Shift parameters, of shape (C, 1).
- * - ema_mean: Exponential moving average of the mean, of
- * shape (C, 1).
- * - ema_var: Exponential moving average of the variance, of
- * shape (C, 1).
- */
- gamma = matrix(1, rows=C, cols=1)
- beta = matrix(0, rows=C, cols=1)
- ema_mean = matrix(0, rows=C, cols=1)
- ema_var = matrix(1, rows=C, cols=1)
-}
-
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.cu
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.cu b/src/main/cpp/kernels/SystemML.cu
index 9ddaaff..b874cdd 100644
--- a/src/main/cpp/kernels/SystemML.cu
+++ b/src/main/cpp/kernels/SystemML.cu
@@ -2289,4 +2289,58 @@ extern "C" __global__ void update_nesterov_x_d(double *X, double *v, double *v_p
extern "C" __global__ void update_nesterov_x_f(float *X, float *v, float *v_prev, double mu, float *out, unsigned int size) {
update_nesterov_x(X, v, v_prev, mu, out, size);
-}
\ No newline at end of file
+}
+
+// Performs the in-place operation C = a*X + b*C element-wise, one thread per element.
+template <typename T>
+__device__ void aXplusbC(T *X, T *C, double a, double b, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x; // flat index built from the x-dimension of block and thread ids
+ if (index < size) { // threads with index >= size do nothing
+ C[index] = a*X[index] + b*C[index]; // a, b are double: evaluated in double, narrowed to T on store
+ }
+}
+
+extern "C" __global__ void aXplusbC_d(double *X, double *C, double a, double b, unsigned int size) {
+ aXplusbC(X, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbC_f(float *X, float *C, double a, double b, unsigned int size) {
+ aXplusbC(X, C, a, b, size);
+}
+
+
+// Performs the operation: C = a*X + b*Y element-wise, one thread per element; the result is written to C.
+template <typename T>
+__device__ void aXplusbY(T *X, T* Y, T *C, double a, double b, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x; // flat index built from the x-dimension of block and thread ids
+ if (index < size) { // threads with index >= size write nothing
+ C[index] = a*X[index] + b*Y[index]; // a, b are double, so the sum is formed in double before the store to T
+ }
+}
+
+extern "C" __global__ void aXplusbY_d(double *X, double* Y, double *C, double a, double b, unsigned int size) { // double-precision entry point
+ aXplusbY(X, Y, C, a, b, size);
+}
+
+extern "C" __global__ void aXplusbY_f(float *X, float* Y, float *C, double a, double b, unsigned int size) { // single-precision entry point
+ aXplusbY(X, Y, C, a, b, size);
+}
+
+
+// Performs the operation: C = 1 / sqrt(X + eps) element-wise, one thread per element; the result is written to C.
+template <typename T>
+__device__ void invVar(T *X, T *C, double eps, unsigned int size) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x; // flat index built from the x-dimension of block and thread ids
+ if (index < size) { // threads with index >= size write nothing
+ C[index] = 1.0 / sqrt(X[index] + eps); // eps and 1.0 are double, so the value is computed in double before the store to T
+ }
+}
+
+extern "C" __global__ void invVar_d(double *X, double *C, double eps, unsigned int size) { // double-precision entry point
+ invVar(X, C, eps, size);
+}
+
+extern "C" __global__ void invVar_f(float *X, float *C, double eps, unsigned int size) { // single-precision entry point
+ invVar(X, C, eps, size);
+}
+
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/cpp/kernels/SystemML.ptx
----------------------------------------------------------------------
diff --git a/src/main/cpp/kernels/SystemML.ptx b/src/main/cpp/kernels/SystemML.ptx
index 8a14876..1ab32f5 100644
--- a/src/main/cpp/kernels/SystemML.ptx
+++ b/src/main/cpp/kernels/SystemML.ptx
@@ -13084,12 +13084,279 @@ BB115_2:
ret;
}
+ // .globl aXplusbC_d
+.visible .entry aXplusbC_d(
+ .param .u64 aXplusbC_d_param_0,
+ .param .u64 aXplusbC_d_param_1,
+ .param .f64 aXplusbC_d_param_2,
+ .param .f64 aXplusbC_d_param_3,
+ .param .u32 aXplusbC_d_param_4
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [aXplusbC_d_param_0];
+ ld.param.u64 %rd2, [aXplusbC_d_param_1];
+ ld.param.f64 %fd1, [aXplusbC_d_param_2];
+ ld.param.f64 %fd2, [aXplusbC_d_param_3];
+ ld.param.u32 %r2, [aXplusbC_d_param_4];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB116_2;
+
+ cvta.to.global.u64 %rd3, %rd2;
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 8;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f64 %fd3, [%rd6];
+ add.s64 %rd7, %rd3, %rd5;
+ ld.global.f64 %fd4, [%rd7];
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ st.global.f64 [%rd7], %fd6;
+
+BB116_2:
+ ret;
+}
+
+ // .globl aXplusbC_f
+.visible .entry aXplusbC_f(
+ .param .u64 aXplusbC_f_param_0,
+ .param .u64 aXplusbC_f_param_1,
+ .param .f64 aXplusbC_f_param_2,
+ .param .f64 aXplusbC_f_param_3,
+ .param .u32 aXplusbC_f_param_4
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<4>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [aXplusbC_f_param_0];
+ ld.param.u64 %rd2, [aXplusbC_f_param_1];
+ ld.param.f64 %fd1, [aXplusbC_f_param_2];
+ ld.param.f64 %fd2, [aXplusbC_f_param_3];
+ ld.param.u32 %r2, [aXplusbC_f_param_4];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB117_2;
+
+ cvta.to.global.u64 %rd3, %rd2;
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 4;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f32 %f1, [%rd6];
+ cvt.f64.f32 %fd3, %f1;
+ add.s64 %rd7, %rd3, %rd5;
+ ld.global.f32 %f2, [%rd7];
+ cvt.f64.f32 %fd4, %f2;
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvt.rn.f32.f64 %f3, %fd6;
+ st.global.f32 [%rd7], %f3;
+
+BB117_2:
+ ret;
+}
+
+ // .globl aXplusbY_d
+.visible .entry aXplusbY_d(
+ .param .u64 aXplusbY_d_param_0,
+ .param .u64 aXplusbY_d_param_1,
+ .param .u64 aXplusbY_d_param_2,
+ .param .f64 aXplusbY_d_param_3,
+ .param .f64 aXplusbY_d_param_4,
+ .param .u32 aXplusbY_d_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<11>;
+
+
+ ld.param.u64 %rd1, [aXplusbY_d_param_0];
+ ld.param.u64 %rd2, [aXplusbY_d_param_1];
+ ld.param.u64 %rd3, [aXplusbY_d_param_2];
+ ld.param.f64 %fd1, [aXplusbY_d_param_3];
+ ld.param.f64 %fd2, [aXplusbY_d_param_4];
+ ld.param.u32 %r2, [aXplusbY_d_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB118_2;
+
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 8;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f64 %fd3, [%rd6];
+ cvta.to.global.u64 %rd7, %rd2;
+ add.s64 %rd8, %rd7, %rd5;
+ ld.global.f64 %fd4, [%rd8];
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd5;
+ st.global.f64 [%rd10], %fd6;
+
+BB118_2:
+ ret;
+}
+
+ // .globl aXplusbY_f
+.visible .entry aXplusbY_f(
+ .param .u64 aXplusbY_f_param_0,
+ .param .u64 aXplusbY_f_param_1,
+ .param .u64 aXplusbY_f_param_2,
+ .param .f64 aXplusbY_f_param_3,
+ .param .f64 aXplusbY_f_param_4,
+ .param .u32 aXplusbY_f_param_5
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<4>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<7>;
+ .reg .b64 %rd<11>;
+
+
+ ld.param.u64 %rd1, [aXplusbY_f_param_0];
+ ld.param.u64 %rd2, [aXplusbY_f_param_1];
+ ld.param.u64 %rd3, [aXplusbY_f_param_2];
+ ld.param.f64 %fd1, [aXplusbY_f_param_3];
+ ld.param.f64 %fd2, [aXplusbY_f_param_4];
+ ld.param.u32 %r2, [aXplusbY_f_param_5];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB119_2;
+
+ cvta.to.global.u64 %rd4, %rd1;
+ mul.wide.s32 %rd5, %r1, 4;
+ add.s64 %rd6, %rd4, %rd5;
+ ld.global.f32 %f1, [%rd6];
+ cvt.f64.f32 %fd3, %f1;
+ cvta.to.global.u64 %rd7, %rd2;
+ add.s64 %rd8, %rd7, %rd5;
+ ld.global.f32 %f2, [%rd8];
+ cvt.f64.f32 %fd4, %f2;
+ mul.f64 %fd5, %fd4, %fd2;
+ fma.rn.f64 %fd6, %fd3, %fd1, %fd5;
+ cvt.rn.f32.f64 %f3, %fd6;
+ cvta.to.global.u64 %rd9, %rd3;
+ add.s64 %rd10, %rd9, %rd5;
+ st.global.f32 [%rd10], %f3;
+
+BB119_2:
+ ret;
+}
+
+ // .globl invVar_d
+.visible .entry invVar_d(
+ .param .u64 invVar_d_param_0,
+ .param .u64 invVar_d_param_1,
+ .param .f64 invVar_d_param_2,
+ .param .u32 invVar_d_param_3
+)
+{
+ .reg .pred %p<2>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<6>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [invVar_d_param_0];
+ ld.param.u64 %rd2, [invVar_d_param_1];
+ ld.param.f64 %fd1, [invVar_d_param_2];
+ ld.param.u32 %r2, [invVar_d_param_3];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB120_2;
+
+ cvta.to.global.u64 %rd3, %rd1;
+ mul.wide.s32 %rd4, %r1, 8;
+ add.s64 %rd5, %rd3, %rd4;
+ ld.global.f64 %fd2, [%rd5];
+ add.f64 %fd3, %fd2, %fd1;
+ sqrt.rn.f64 %fd4, %fd3;
+ rcp.rn.f64 %fd5, %fd4;
+ cvta.to.global.u64 %rd6, %rd2;
+ add.s64 %rd7, %rd6, %rd4;
+ st.global.f64 [%rd7], %fd5;
+
+BB120_2:
+ ret;
+}
+
+ // .globl invVar_f
+.visible .entry invVar_f(
+ .param .u64 invVar_f_param_0,
+ .param .u64 invVar_f_param_1,
+ .param .f64 invVar_f_param_2,
+ .param .u32 invVar_f_param_3
+)
+{
+ .reg .pred %p<2>;
+ .reg .f32 %f<3>;
+ .reg .b32 %r<6>;
+ .reg .f64 %fd<6>;
+ .reg .b64 %rd<8>;
+
+
+ ld.param.u64 %rd1, [invVar_f_param_0];
+ ld.param.u64 %rd2, [invVar_f_param_1];
+ ld.param.f64 %fd1, [invVar_f_param_2];
+ ld.param.u32 %r2, [invVar_f_param_3];
+ mov.u32 %r3, %ctaid.x;
+ mov.u32 %r4, %ntid.x;
+ mov.u32 %r5, %tid.x;
+ mad.lo.s32 %r1, %r4, %r3, %r5;
+ setp.ge.u32 %p1, %r1, %r2;
+ @%p1 bra BB121_2;
+
+ cvta.to.global.u64 %rd3, %rd1;
+ mul.wide.s32 %rd4, %r1, 4;
+ add.s64 %rd5, %rd3, %rd4;
+ ld.global.f32 %f1, [%rd5];
+ cvt.f64.f32 %fd2, %f1;
+ add.f64 %fd3, %fd2, %fd1;
+ sqrt.rn.f64 %fd4, %fd3;
+ rcp.rn.f64 %fd5, %fd4;
+ cvt.rn.f32.f64 %f2, %fd5;
+ cvta.to.global.u64 %rd6, %rd2;
+ add.s64 %rd7, %rd6, %rd4;
+ st.global.f32 [%rd7], %f2;
+
+BB121_2:
+ ret;
+}
+
.func (.param .b64 func_retval0) __internal_trig_reduction_slowpathd(
.param .b64 __internal_trig_reduction_slowpathd_param_0,
.param .b64 __internal_trig_reduction_slowpathd_param_1
)
{
- .local .align 8 .b8 __local_depot116[40];
+ .local .align 8 .b8 __local_depot122[40];
.reg .b64 %SP;
.reg .b64 %SPL;
.reg .pred %p<9>;
@@ -13098,7 +13365,7 @@ BB115_2:
.reg .b64 %rd<102>;
- mov.u64 %rd101, __local_depot116;
+ mov.u64 %rd101, __local_depot122;
cvta.local.u64 %SP, %rd101;
ld.param.f64 %fd4, [__internal_trig_reduction_slowpathd_param_0];
ld.param.u64 %rd37, [__internal_trig_reduction_slowpathd_param_1];
@@ -13112,7 +13379,7 @@ BB115_2:
shr.u32 %r3, %r1, 20;
bfe.u32 %r4, %r1, 20, 11;
setp.eq.s32 %p1, %r4, 2047;
- @%p1 bra BB116_13;
+ @%p1 bra BB122_13;
add.s32 %r15, %r4, -1024;
shr.u32 %r16, %r15, 6;
@@ -13125,7 +13392,7 @@ BB115_2:
mov.u64 %rd94, 0;
setp.ge.s32 %p2, %r5, %r6;
mov.u64 %rd93, %rd1;
- @%p2 bra BB116_4;
+ @%p2 bra BB122_4;
mov.b64 %rd41, %fd4;
shl.b64 %rd42, %rd41, 11;
@@ -13142,7 +13409,7 @@ BB115_2:
mov.u64 %rd91, %rd1;
mov.u32 %r39, %r5;
-BB116_3:
+BB122_3:
.pragma "nounroll";
ld.const.u64 %rd47, [%rd89];
// inline asm
@@ -13172,15 +13439,15 @@ BB116_3:
add.s64 %rd93, %rd93, 8;
add.s64 %rd89, %rd89, 8;
setp.lt.s32 %p3, %r39, %r6;
- @%p3 bra BB116_3;
+ @%p3 bra BB122_3;
-BB116_4:
+BB122_4:
st.local.u64 [%rd93], %rd94;
ld.local.u64 %rd95, [%rd1+16];
ld.local.u64 %rd96, [%rd1+24];
and.b32 %r9, %r3, 63;
setp.eq.s32 %p4, %r9, 0;
- @%p4 bra BB116_6;
+ @%p4 bra BB122_6;
mov.u32 %r27, 64;
sub.s32 %r28, %r27, %r9;
@@ -13192,7 +13459,7 @@ BB116_4:
shr.u64 %rd55, %rd54, %r28;
or.b64 %rd95, %rd55, %rd53;
-BB116_6:
+BB122_6:
cvta.to.local.u64 %rd56, %rd37;
shr.u64 %rd57, %rd96, 62;
cvt.u32.u64 %r29, %rd57;
@@ -13209,7 +13476,7 @@ BB116_6:
selp.b32 %r34, %r32, %r33, %p5;
st.local.u32 [%rd56], %r34;
setp.eq.s32 %p6, %r31, 0;
- @%p6 bra BB116_8;
+ @%p6 bra BB122_8;
mov.u64 %rd64, 0;
// inline asm
@@ -13229,10 +13496,10 @@ BB116_6:
// inline asm
xor.b32 %r40, %r40, -2147483648;
-BB116_8:
+BB122_8:
clz.b64 %r41, %rd98;
setp.eq.s32 %p7, %r41, 0;
- @%p7 bra BB116_10;
+ @%p7 bra BB122_10;
shl.b64 %rd67, %rd98, %r41;
mov.u32 %r35, 64;
@@ -13240,7 +13507,7 @@ BB116_8:
shr.u64 %rd68, %rd97, %r36;
or.b64 %rd98, %rd68, %rd67;
-BB116_10:
+BB122_10:
mov.u64 %rd72, -3958705157555305931;
// inline asm
{
@@ -13261,7 +13528,7 @@ BB116_10:
}
// inline asm
setp.lt.s64 %p8, %rd100, 1;
- @%p8 bra BB116_12;
+ @%p8 bra BB122_12;
// inline asm
{
@@ -13280,7 +13547,7 @@ BB116_10:
// inline asm
add.s32 %r41, %r41, 1;
-BB116_12:
+BB122_12:
cvt.u64.u32 %rd79, %r40;
shl.b64 %rd80, %rd79, 32;
mov.u32 %r37, 1022;
@@ -13295,7 +13562,7 @@ BB116_12:
or.b64 %rd88, %rd87, %rd80;
mov.b64 %fd4, %rd88;
-BB116_13:
+BB122_13:
st.param.f64 [func_retval0+0], %fd4;
ret;
}
@@ -13323,7 +13590,7 @@ BB116_13:
}
shr.u32 %r51, %r50, 20;
setp.ne.s32 %p1, %r51, 0;
- @%p1 bra BB117_2;
+ @%p1 bra BB123_2;
mul.f64 %fd14, %fd12, 0d4350000000000000;
{
@@ -13337,13 +13604,13 @@ BB116_13:
shr.u32 %r16, %r50, 20;
add.s32 %r51, %r16, -54;
-BB117_2:
+BB123_2:
add.s32 %r52, %r51, -1023;
and.b32 %r17, %r50, -2146435073;
or.b32 %r18, %r17, 1072693248;
mov.b64 %fd135, {%r49, %r18};
setp.lt.u32 %p2, %r18, 1073127583;
- @%p2 bra BB117_4;
+ @%p2 bra BB123_4;
{
.reg .b32 %temp;
@@ -13357,7 +13624,7 @@ BB117_2:
mov.b64 %fd135, {%r19, %r21};
add.s32 %r52, %r51, -1022;
-BB117_4:
+BB123_4:
add.f64 %fd15, %fd135, 0d3FF0000000000000;
rcp.approx.ftz.f64 %fd16, %fd15;
neg.f64 %fd17, %fd15;
@@ -13520,13 +13787,13 @@ BB117_4:
mov.b32 %f2, %r35;
abs.f32 %f1, %f2;
setp.lt.f32 %p4, %f1, 0f4086232B;
- @%p4 bra BB117_7;
+ @%p4 bra BB123_7;
setp.lt.f64 %p5, %fd4, 0d0000000000000000;
add.f64 %fd129, %fd4, 0d7FF0000000000000;
selp.f64 %fd136, 0d0000000000000000, %fd129, %p5;
setp.geu.f32 %p6, %f1, 0f40874800;
- @%p6 bra BB117_7;
+ @%p6 bra BB123_7;
mov.f64 %fd134, 0d4338000000000000;
mov.f64 %fd133, 0d3FF71547652B82FE;
@@ -13548,26 +13815,26 @@ BB117_4:
mov.b64 %fd131, {%r44, %r43};
mul.f64 %fd136, %fd130, %fd131;
-BB117_7:
+BB123_7:
{
.reg .b32 %temp;
mov.b64 {%temp, %r45}, %fd136;
}
and.b32 %r46, %r45, 2147483647;
setp.ne.s32 %p7, %r46, 2146435072;
- @%p7 bra BB117_9;
+ @%p7 bra BB123_9;
{
.reg .b32 %temp;
mov.b64 {%r47, %temp}, %fd136;
}
setp.eq.s32 %p8, %r47, 0;
- @%p8 bra BB117_10;
+ @%p8 bra BB123_10;
-BB117_9:
+BB123_9:
fma.rn.f64 %fd136, %fd136, %fd5, %fd136;
-BB117_10:
+BB123_10:
st.param.f64 [func_retval0+0], %fd136;
ret;
}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/DnnOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/DnnOp.java b/src/main/java/org/apache/sysml/hops/DnnOp.java
index a7d37dc..c4ce466 100644
--- a/src/main/java/org/apache/sysml/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysml/hops/DnnOp.java
@@ -110,8 +110,6 @@ public class DnnOp extends MultiThreadedHop
if( getLops() != null )
return getLops();
- ExecType et = optFindExecType();
-
ArrayList<Hop> inputs = getInput();
switch( op )
{
@@ -125,6 +123,7 @@ public class DnnOp extends MultiThreadedHop
case BIASADD:
case BIASMULT:
{
+ ExecType et = optFindExecType();
if(et == ExecType.CP || et == ExecType.GPU) {
setLops(constructDnnLops(et, inputs));
break;
@@ -137,15 +136,15 @@ public class DnnOp extends MultiThreadedHop
case BATCH_NORM2D_TEST:
case CHANNEL_SUMS:
case UPDATE_NESTEROV_X:
+ case UPDATE_EMA_VAR:
+ case RESHAPE_COLMEANS:
+ case UPDATE_EMA:
+ case INV_VAR:
+ case BATCH_NORM2D_BACKWARD_DX:
{
- if(et == ExecType.GPU) {
- setLops(constructDnnLops(et, inputs));
- break;
- }
- else {
- throw new HopsException("Unimplemented DnnOp for execution type: " + et.name());
- }
- // break;
+ // GPU-specific operators
+ setLops(constructDnnLops(ExecType.GPU, inputs));
+ break;
}
default:
throw new HopsException("Unsupported lops construction for operation type '"+op+"'.");
@@ -171,10 +170,16 @@ public class DnnOp extends MultiThreadedHop
return 14;
case BIASADD:
case BIASMULT:
+ case INV_VAR:
return 2;
case BATCH_NORM2D_TEST:
return 6;
+ case UPDATE_EMA_VAR:
+ case BATCH_NORM2D_BACKWARD_DX:
+ return 5;
+ case RESHAPE_COLMEANS:
case CHANNEL_SUMS:
+ case UPDATE_EMA:
return 3;
case UPDATE_NESTEROV_X:
return 4;
@@ -532,7 +537,8 @@ public class DnnOp extends MultiThreadedHop
long[] ret = new long[3];
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
- op == OpOpDnn.UPDATE_NESTEROV_X) {
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
// Same dimension as the first input
MatrixCharacteristics[] mc = memo.getAllInputStats(getInput());
ret[0] = mc[0].rowsKnown() ? mc[0].getRows() : -1;
@@ -540,13 +546,21 @@ public class DnnOp extends MultiThreadedHop
ret[2] = -1;
return (ret[0]>=0 && ret[1]>=0) ? ret : null;
}
- else if(op == OpOpDnn.CHANNEL_SUMS) {
+ else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
ret[0] = numChannels;
ret[1] = 1;
ret[2] = -1;
return ret;
}
+ else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+ long numChannels = Hop.computeSizeInformation(getInput().get(1));
+ long HW = Hop.computeSizeInformation(getInput().get(2));
+ ret[0] = numChannels;
+ ret[1] = HW;
+ ret[2] = -1;
+ return ret;
+ }
refreshSizeInformation();
ret[0] = _dim1; ret[1] = _dim2; ret[2] = _nnz;
@@ -739,7 +753,9 @@ public class DnnOp extends MultiThreadedHop
@Override
public void refreshSizeInformation()
{
- if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.UPDATE_NESTEROV_X) {
+ if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST ||
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
// Same dimension as the first input
Hop input1 = getInput().get(0);
setDim1(input1.getDim1());
@@ -747,13 +763,21 @@ public class DnnOp extends MultiThreadedHop
_nnz = -1; // cannot infer stats
return;
}
- else if(op == OpOpDnn.CHANNEL_SUMS) {
+ else if(op == OpOpDnn.CHANNEL_SUMS || op == OpOpDnn.UPDATE_EMA_VAR) {
long numChannels = Hop.computeSizeInformation(getInput().get(1));
setDim1(numChannels);
setDim2(1);
_nnz = -1; // cannot infer stats
return;
}
+ else if(op == OpOpDnn.RESHAPE_COLMEANS) {
+ long numChannels = Hop.computeSizeInformation(getInput().get(1));
+ long HW = Hop.computeSizeInformation(getInput().get(2));
+ setDim1(numChannels);
+ setDim2(HW);
+ _nnz = -1; // cannot infer stats
+ return;
+ }
// Reset the _cachedParams to avoid incorrect sizes
_cachedParams = new DnnParameters(-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, _maxNumThreads);
@@ -847,7 +871,9 @@ public class DnnOp extends MultiThreadedHop
*/
private long getDim(String dimString) {
if(op == OpOpDnn.BIASADD || op == OpOpDnn.BIASMULT || op == OpOpDnn.BATCH_NORM2D_TEST || op == OpOpDnn.CHANNEL_SUMS ||
- op == OpOpDnn.UPDATE_NESTEROV_X) {
+ op == OpOpDnn.UPDATE_NESTEROV_X || op == OpOpDnn.RESHAPE_COLMEANS ||
+ op == OpOpDnn.UPDATE_EMA_VAR || op == OpOpDnn.UPDATE_EMA || op == OpOpDnn.INV_VAR ||
+ op == OpOpDnn.BATCH_NORM2D_BACKWARD_DX) {
throw new RuntimeException("getDim method should not be invoked for " + op.name());
}
try {
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/FunctionOp.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index ea397db..5f177bd 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -181,21 +181,6 @@ public class FunctionOp extends Hop
// TODO: To allow for initial version to always run on the GPU
return 0;
}
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(3).getDim1(), getOutputs().get(3).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(4).getDim1(), getOutputs().get(4).getDim2(), 1.0);
- }
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_test") ) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
- }
- else if ( getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ) {
- return OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0) +
- OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(2).getDim1(), getOutputs().get(2).getDim2(), 1.0);
- }
else if ( getFunctionName().equalsIgnoreCase("svd") ) {
long outputU = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(0).getDim1(), getOutputs().get(0).getDim2(), 1.0);
long outputSigma = OptimizerUtils.estimateSizeExactSparsity(getOutputs().get(1).getDim1(), getOutputs().get(1).getDim2(), 1.0);
@@ -226,10 +211,6 @@ public class FunctionOp extends Hop
return OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), getInput().get(0).getDim2(), 1.0)
+ 3*OptimizerUtils.estimateSizeExactSparsity(getInput().get(0).getDim1(), 1, 1.0);
}
- else if (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test")) {
- return 0;
- }
else if ( getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ) {
// TODO: To allow for initial version to always run on the GPU
return 0;
@@ -251,9 +232,7 @@ public class FunctionOp extends Hop
@Override
public boolean isGPUEnabled() {
- if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward") ||
- getFunctionName().equalsIgnoreCase("batch_norm2d_train") || getFunctionName().equalsIgnoreCase("batch_norm2d_test"))
+ if(getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))
return true;
else
return false;
@@ -308,13 +287,6 @@ public class FunctionOp extends Hop
throw new RuntimeException("The function " + getFunctionName() + " is only supported on GPU.");
_etype = ExecType.GPU;
}
- else if(isBuiltinFunction && (getFunctionName().equalsIgnoreCase("batch_norm2d") || getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
- _etype = ConfigurationManager.isGPU() ? ExecType.GPU : ExecType.CP;
- }
- else if(isBuiltinFunction && getFunctionName().equalsIgnoreCase("batch_norm2d_train")) {
- // Only GPU implementation is supported
- _etype = ExecType.GPU;
- }
else {
// Since the memory estimate is only conservative, do not throw
// exception if the estimated memory is larger than the budget
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/Hop.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 3b461a1..c8356e0 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -1100,7 +1100,8 @@ public abstract class Hop implements ParseInfo
MAX_POOL, MAX_POOL_BACKWARD, AVG_POOL, AVG_POOL_BACKWARD,
CONV2D, CONV2D_BACKWARD_FILTER, CONV2D_BACKWARD_DATA,
BIASADD, BIASMULT, BATCH_NORM2D_TEST, CHANNEL_SUMS,
- UPDATE_NESTEROV_X
+ UPDATE_NESTEROV_X, RESHAPE_COLMEANS, UPDATE_EMA_VAR, UPDATE_EMA, INV_VAR,
+ BATCH_NORM2D_BACKWARD_DX
}
public enum DataGenMethod {
@@ -1174,8 +1175,13 @@ public abstract class Hop implements ParseInfo
HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_FILTER, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_FILTER);
HopsConv2Lops.put(OpOpDnn.CONV2D_BACKWARD_DATA, org.apache.sysml.lops.DnnTransform.OperationTypes.CONV2D_BACKWARD_DATA);
HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_TEST, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_TEST);
+ HopsConv2Lops.put(OpOpDnn.UPDATE_EMA_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA_VAR);
HopsConv2Lops.put(OpOpDnn.CHANNEL_SUMS, org.apache.sysml.lops.DnnTransform.OperationTypes.CHANNEL_SUMS);
HopsConv2Lops.put(OpOpDnn.UPDATE_NESTEROV_X, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_NESTEROV_X);
+ HopsConv2Lops.put(OpOpDnn.RESHAPE_COLMEANS, org.apache.sysml.lops.DnnTransform.OperationTypes.RESHAPE_COLMEANS);
+ HopsConv2Lops.put(OpOpDnn.UPDATE_EMA, org.apache.sysml.lops.DnnTransform.OperationTypes.UPDATE_EMA);
+ HopsConv2Lops.put(OpOpDnn.INV_VAR, org.apache.sysml.lops.DnnTransform.OperationTypes.INV_VAR);
+ HopsConv2Lops.put(OpOpDnn.BATCH_NORM2D_BACKWARD_DX, org.apache.sysml.lops.DnnTransform.OperationTypes.BATCH_NORM2D_BACKWARD_DX);
}
protected static final HashMap<Hop.Direction, org.apache.sysml.lops.PartialAggregate.DirectionTypes> HopsDirection2Lops;
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
new file mode 100644
index 0000000..7c70b7b
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopDagPatternMatcher.java
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.function.Function;
+import java.util.function.Predicate;
+import org.apache.sysml.api.DMLScript;
+import org.apache.sysml.conf.ConfigurationManager;
+import org.apache.sysml.hops.AggUnaryOp;
+import org.apache.sysml.hops.Hop;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.hops.Hop.AggOp;
+import org.apache.sysml.hops.Hop.Direction;
+import org.apache.sysml.hops.Hop.OpOp1;
+import org.apache.sysml.hops.Hop.OpOp2;
+import org.apache.sysml.hops.Hop.OpOpDnn;
+import org.apache.sysml.hops.Hop.ReOrgOp;
+import org.apache.sysml.parser.Expression.DataType;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.instructions.gpu.context.GPUContextPool;
+import org.apache.sysml.utils.Explain;
+
+/**
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopDagPatternMatcher {
+ static final HashSet<String> DEBUG_PATTERNS;
+ static {
+ // DEBUG_PATTERNS = new HashSet<>();
+ // DEBUG_PATTERNS.add("batchNormdX");
+ DEBUG_PATTERNS = null;
+ }
+
+ // Predicates for the current HOP
+ List<HopPredicate> _predicates = new ArrayList<>();
+ // Child matchers
+ List<HopDagPatternMatcher> _children = new ArrayList<>();
+ private boolean _isLeaf = false;
+
+ static boolean DEBUG_REWRITES = false; // This is set by HopPatternRewriter. Please use DEBUG_PATTERNS instead.
+
+ // Simple utility for debugging the rewrites
+ public static class HopPredicate implements Predicate<Hop> {
+ final String _name;
+ final Function<Hop, Boolean> _pred;
+ public HopPredicate(String name, Function<Hop, Boolean> pred) {
+ _name = name;
+ _pred = pred;
+ }
+ @Override
+ public boolean test(Hop h) {
+ return _pred.apply(h);
+ }
+ @Override
+ public String toString() {
+ return _name;
+ }
+ }
+
+ /**
+ * Adds a predicate to the pattern matcher
+ *
+ * @param name name of the pattern for debugging
+ * @param pred higher order function that takes as an input a hop and returns true if the pattern matches else false
+ * @return this
+ */
+ public HopDagPatternMatcher addPredicate(String name, Function<Hop, Boolean> pred) {
+ _predicates.add(new HopPredicate(name, pred));
+ return this;
+ }
+
+ /**
+  * Appends child pattern matchers in the given order.
+  * During matching, the i-th child matcher is applied to the i-th input of the HOP.
+  *
+  * @param children child matchers to append
+  * @return this
+  */
+ public HopDagPatternMatcher addChildMatcher(HopDagPatternMatcher... children) {
+     for(HopDagPatternMatcher child : children) {
+         _children.add(child);
+     }
+     return this;
+ }
+
+ /**
+ * Get the matched HOP DAGs
+ * @param varName variable names
+ * @return matched HOP
+ */
+ public Hop getMatchedHop(String varName) {
+
+ if(matchedHops == null || !matchedHops.containsKey(varName)) {
+ throw new RuntimeException("Incorrect usage: the variable " + varName + " is not registered as input.");
+ }
+ return matchedHops.get(varName);
+ }
+
+ /**
+ * Return the value
+ *
+ * @param varName variable name
+ * @return the value of the LiteralOp
+ */
+ public double getLiteralValue(String varName) {
+ return OptimizerUtils.rEvalSimpleDoubleExpression(getMatchedHop(varName), new HashMap<>());
+ }
+
+ @Override
+ public String toString() {
+     // The name of the first predicate (if any) identifies this matcher in debug output
+     if(_predicates.isEmpty()) {
+         return "";
+     }
+     return _predicates.get(0).toString();
+ }
+
+ /**
+ * Match the given HOP DAG against this pattern, using this matcher as the root.
+ * The bindings recorded by an earlier call (visited matchers and matched leaf hops)
+ * are discarded before matching starts.
+ *
+ * @param h root node of the HOP DAG
+ * @return true if HOP DAG matches
+ */
+ public boolean matches(Hop h) {
+ visited.clear(); // drop matcher-to-hop bindings of an earlier call
+ matchedHops.clear(); // drop variable-name-to-hop bindings of an earlier call
+ return matchHelper(this, h); // this matcher is the root whose maps collect the bindings
+ }
+
+ private HashMap<String, Hop> matchedHops = new HashMap<>();
+ private String variableName;
+ private HashMap<HopDagPatternMatcher, Hop> visited = new HashMap<>(); // Map of matched hops
+ private boolean matchHelper(HopDagPatternMatcher root, Hop h) { // root: matcher whose visited/matchedHops maps hold the state of the current match
+ if(h == null) {
+ return false;
+ }
+ else if(_children.size() > 0 && h.getInput().size() < _children.size()) { // fewer inputs than child matchers; extra inputs are tolerated
+ if(DEBUG_REWRITES) {
+ System.out.println("The expected number of children (" + _children.size() + ") didnot match the number of inputs (" + h.getInput().size() + ") " + this);
+ }
+ return false;
+ }
+ if(root.visited.containsKey(this)) { // this matcher object was already bound to a hop during the current match
+ Hop h1 = root.visited.get(this);
+ if(h == h1) {
+ if(DEBUG_REWRITES)
+ System.out.println("MATCHED: Early exit as the given HOP has been already matched by the matcher." + this);
+ return true; // Early exit as the given HOP has been already matched by the matcher
+ }
+ else if(_isLeaf) {
+ if(h.getDataType() == h1.getDataType() && h.getDataType() == DataType.SCALAR) { // two distinct scalar hops: compare their evaluated values
+ return OptimizerUtils.rEvalSimpleDoubleExpression(h, new HashMap<>()) == OptimizerUtils.rEvalSimpleDoubleExpression(h1, new HashMap<>());
+ }
+ return false; // Mismatched or unknown datatypes or matched with different hops
+ }
+ } // a non-leaf matcher that was bound to a different hop falls through and is matched again
+
+ for(HopPredicate p : _predicates) { // every predicate of this matcher has to hold for h
+ if(!p.test(h)) {
+ if(DEBUG_REWRITES) {
+ System.out.println("The predicate " + p.toString() + " failed.");
+ }
+ return false;
+ }
+ }
+ int index = 0;
+ for(HopDagPatternMatcher child : _children) { // the i-th child matcher is applied to the i-th input of h
+ if(!child.matchHelper(root, h.getInput().get(index))) {
+ return false;
+ }
+ index++;
+ }
+ if(_isLeaf) {
+ root.matchedHops.put(variableName, h); // makes the hop retrievable via getMatchedHop on the root
+ }
+
+ root.visited.put(this, h); // remember the binding for later encounters of this matcher
+ if(DEBUG_REWRITES)
+ System.out.println("MATCHED: " + this + " to " + Explain.explain(h));
+ return true;
+ }
+
+
+ // Simple helper utilities for adding predicates
+ private HopDagPatternMatcher isScalar() {
+ return this.addPredicate("isScalar", h -> h.getDataType() == DataType.SCALAR);
+ }
+ private HopDagPatternMatcher isMatrix() {
+ return this.addPredicate("isMatrix", h -> h.getDataType() == DataType.MATRIX);
+ }
+ public HopDagPatternMatcher fitsOnGPU(double constant) {
+ return this.addPredicate("fitsOnGPU", h -> _fitsOnGPU(h, constant));
+ }
+
+ // Factory methods:
+ public static HopDagPatternMatcher dummy = new HopDagPatternMatcher();
+ public static HopDagPatternMatcher rowMeans(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowMeans", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher rowVars(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowVars", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher colVars(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("colVars", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.VAR && ((AggUnaryOp)h).getDirection() == Direction.Col)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher leaf(String _variableName, DataType dt) {
+ HopDagPatternMatcher ret = new HopDagPatternMatcher();
+ ret._isLeaf = true;
+ ret.variableName = _variableName;
+ if(dt == DataType.MATRIX) {
+ return ret.isMatrix();
+ }
+ else if(dt == DataType.SCALAR) {
+ return ret.isScalar();
+ }
+ else if(dt == DataType.UNKNOWN) {
+ return ret;
+ }
+ else {
+ throw new DMLRuntimeException("Unsupported datatype in pattern matcher:" + dt.name());
+ }
+ }
+ public static HopDagPatternMatcher rowSums(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("rowSums", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Row)
+ .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher colSums(HopDagPatternMatcher child1) {
+ return new HopDagPatternMatcher().addPredicate("colSums", h ->
+ h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.SUM && ((AggUnaryOp)h).getDirection() == Direction.Col)
+ .addChildMatcher(child1);
+ }
+ /**
+  * Creates a matcher for an AggUnaryOp with aggregation MEAN and direction Col.
+  *
+  * @param child1 matcher for the input of the aggregate
+  * @return the new matcher
+  */
+ public static HopDagPatternMatcher colMeans(HopDagPatternMatcher child1) {
+     // The predicate label is what toString() and the debug output print for this matcher
+     return new HopDagPatternMatcher().addPredicate("colMeans", h ->
+         h instanceof AggUnaryOp && ((AggUnaryOp)h).getOp() == AggOp.MEAN && ((AggUnaryOp)h).getDirection() == Direction.Col)
+         .addChildMatcher(child1);
+ }
+ public static HopDagPatternMatcher matrix(HopDagPatternMatcher X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_reshape", h -> HopRewriteUtils.isReorg(h, ReOrgOp.RESHAPE))
+ .addChildMatcher(X, rows, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X))
+ .addChildMatcher(rows, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, HopDagPatternMatcher rows, long cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim2() == cols)
+ .addChildMatcher(rows, dummy);
+ }
+ public static HopDagPatternMatcher matrix(double X, long rows, HopDagPatternMatcher cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim1() == rows)
+ .addChildMatcher(dummy, cols);
+ }
+ public static HopDagPatternMatcher matrix(double X, long rows, long cols) {
+ return new HopDagPatternMatcher().addPredicate("matrix_datagen", h -> HopRewriteUtils.isDataGenOpWithConstantValue(h, X) &&
+ h.getDim1() == rows && h.getDim2() == cols)
+ .addChildMatcher(dummy, dummy);
+ }
+ public static HopDagPatternMatcher bias_add(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("bias_add", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASADD))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher bias_multiply(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("bias_multiply", h -> HopRewriteUtils.isDnn(h, OpOpDnn.BIASMULT))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher unaryMinus(HopDagPatternMatcher child) {
+ return new HopDagPatternMatcher().addPredicate("unaryMinus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS)
+ && HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), 0))
+ .addChildMatcher(dummy, child);
+ }
+ public static HopDagPatternMatcher sqrt(HopDagPatternMatcher child) {
+ return new HopDagPatternMatcher().addPredicate("sqrt", h -> HopRewriteUtils.isUnary(h, OpOp1.SQRT))
+ .addChildMatcher(child);
+ }
+ public static HopDagPatternMatcher div(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher div(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher div(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("div", h -> HopRewriteUtils.isBinary(h, OpOp2.DIV) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+
+ public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher pow(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("pow", h -> HopRewriteUtils.isBinary(h, OpOp2.POW) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ // Returns true if both hops report the same number of rows and the same number of columns
+ private static boolean matchDimensions(Hop h1, Hop h2) {
+     if(h1.getDim1() != h2.getDim1()) {
+         return false;
+     }
+     return h1.getDim2() == h2.getDim2();
+ }
+ // This is used to differentiate between matrix-matrix and matrix-vector operations.
+ public static HopDagPatternMatcher mm_plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS)
+ && matchDimensions(h.getInput().get(0), h.getInput().get(1)))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher plus(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher plus(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("plus", h -> HopRewriteUtils.isBinary(h, OpOp2.PLUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher minus(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher minus(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("minus", h -> HopRewriteUtils.isBinary(h, OpOp2.MINUS) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT))
+ .addChildMatcher(child1, child2);
+ }
+ public static HopDagPatternMatcher mult(double child1, HopDagPatternMatcher child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(0), child1))
+ .addChildMatcher(dummy, child2);
+ }
+ public static HopDagPatternMatcher mult(HopDagPatternMatcher child1, double child2) {
+ return new HopDagPatternMatcher().addPredicate("mult", h -> HopRewriteUtils.isBinary(h, OpOp2.MULT) &&
+ HopRewriteUtils.isLiteralOfValue(h.getInput().get(1), child2))
+ .addChildMatcher(child1, dummy);
+ }
+ // Checks whether the memory estimate of the hop, scaled by the multiplier, is below both
+ // the local memory budget and the initial GPU memory budget. Returns false if GPU mode is off,
+ // the dimensions are unknown or the optimization level is not memory-based.
+ private static boolean _fitsOnGPU(Hop h, double multiplier) {
+     double scaledEstimate = multiplier*h.getMemEstimate();
+     if(!ConfigurationManager.isGPU() || !h.dimsKnown() || !OptimizerUtils.isMemoryBasedOptLevel()) {
+         return false;
+     }
+     return scaledEstimate < OptimizerUtils.getLocalMemBudget()
+         && scaledEstimate < GPUContextPool.initialGPUMemBudget();
+ }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
new file mode 100644
index 0000000..02472ed
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopPatternRewriter.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.function.Function;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Pairs a HopDagPatternMatcher with a replacer function that is applied when the matcher succeeds.
+ * This class is used with HopRewriteRuleWithPatternMatcher to implement the following pattern matching logic:
+ * <pre>{@code
+ * ArrayList<HopPatternRewriter> patternRewriters = getPatternRewriter();
+ * for(HopPatternRewriter patternRewriter : patternRewriters) {
+ *     hi = patternRewriter.rewrite(hi);
+ * }
+ * }</pre>
+ *
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public class HopPatternRewriter {
+    private final HopDagPatternMatcher _matcher;
+    private final Function<Hop, Hop> _replacer;
+    private final String _name;
+
+    /**
+     * @param name name of the rewrite; debug output is printed if HopDagPatternMatcher.DEBUG_PATTERNS contains it
+     * @param matcher pattern that has to match the given HOP
+     * @param replacer function applied to the HOP when the pattern matches; its result is returned by rewrite
+     */
+    public HopPatternRewriter(String name, HopDagPatternMatcher matcher, Function<Hop, Hop> replacer) {
+        _name = name;
+        _matcher = matcher;
+        _replacer = replacer;
+    }
+
+    /**
+     * Applies the replacer to the given HOP if the pattern matches.
+     *
+     * @param root HOP to match and rewrite
+     * @return the result of the replacer if the pattern matches, else the given HOP
+     */
+    public Hop rewrite(Hop root) {
+        boolean printMessage = HopDagPatternMatcher.DEBUG_PATTERNS != null && HopDagPatternMatcher.DEBUG_PATTERNS.contains(_name);
+        if(!printMessage) {
+            return _matcher.matches(root) ? _replacer.apply(root) : root;
+        }
+
+        HopDagPatternMatcher.DEBUG_REWRITES = true;
+        try {
+            System.out.println("-----------------------------------");
+            System.out.println(org.apache.sysml.utils.Explain.explain(root));
+            if(_matcher.matches(root)) {
+                Hop newHop = _replacer.apply(root);
+                if(newHop == root)
+                    System.out.println("Initial pattern match for " + _name + " succeeded but replacer returned the same HOP.");
+                else
+                    System.out.println("Pattern match for " + _name + " succeeded.");
+                return newHop;
+            }
+            System.out.println("Pattern match for " + _name + " failed.");
+            return root;
+        }
+        finally {
+            // Reset the static flag even if the matcher or the replacer throws,
+            // otherwise debug output would stay enabled for all subsequent rewrites.
+            HopDagPatternMatcher.DEBUG_REWRITES = false;
+            System.out.println("-----------------------------------");
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
new file mode 100644
index 0000000..854eca3
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteRuleWithPatternMatcher.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.rewrite;
+
+import java.util.ArrayList;
+
+import org.apache.sysml.hops.Hop;
+
+/**
+ * Simple utility class that implements generic structure for HopRewriteRule.
+ * Please see org.apache.sysml.hops.rewrite.RewriteGPUSpecificOps class for usage and design documentation.
+ */
+public abstract class HopRewriteRuleWithPatternMatcher extends HopRewriteRule {
+
+    /**
+     * Provides the pattern rewriters of this rule.
+     *
+     * @return rewriters that are applied, in list order, to every input HOP that is visited
+     */
+    public abstract ArrayList<HopPatternRewriter> getPatternRewriter();
+
+    @Override
+    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
+        if( roots != null ) {
+            //first pass: rewrite before descending (picks up patterns created by a rewrite)
+            applyRulesToAll(roots, false);
+            //second pass: descend before rewriting (for rollup)
+            applyRulesToAll(roots, true);
+        }
+        return roots;
+    }
+
+    @Override
+    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
+        if( root != null ) {
+            //first pass: rewrite before descending (picks up patterns created by a rewrite)
+            applyRules(null, root, false);
+            root.resetVisitStatus();
+            //second pass: descend before rewriting (for rollup)
+            applyRules(null, root, true);
+        }
+        return root;
+    }
+
+    /**
+     * Runs a single pass over all given roots and resets the visit status afterwards.
+     *
+     * @param roots root operators
+     * @param descendFirst true if recursively process children first
+     */
+    private void applyRulesToAll(ArrayList<Hop> roots, boolean descendFirst) {
+        for( int pos = 0; pos < roots.size(); pos++ ) {
+            applyRules(roots, roots.get(pos), descendFirst);
+        }
+        Hop.resetVisitStatus(roots, true);
+    }
+
+    /**
+     * Applies all pattern rewriters to the inputs of the given operator and recurses into them.
+     *
+     * @param roots root operators
+     * @param hop high-level operator
+     * @param descendFirst true if recursively process children first
+     */
+    private void applyRules(ArrayList<Hop> roots, Hop hop, boolean descendFirst)
+    {
+        if( !hop.isVisited() ) {
+            //the input list is re-read on every iteration on purpose
+            for( int pos = 0; pos < hop.getInput().size(); pos++ ) {
+                Hop child = hop.getInput().get(pos);
+
+                //process children recursively first (to allow roll-up)
+                if( descendFirst ) {
+                    applyRules(roots, child, true);
+                }
+
+                //chain the rewriters: each one sees the result of the previous one
+                for( HopPatternRewriter rewriter : getPatternRewriter() ) {
+                    child = rewriter.rewrite(child);
+                }
+
+                //otherwise descend into the (possibly rewritten) child afterwards
+                if( !descendFirst ) {
+                    applyRules(roots, child, false);
+                }
+            }
+
+            hop.setVisited();
+        }
+    }
+}
http://git-wip-us.apache.org/repos/asf/systemml/blob/0f36780a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
index 271142d..2351f5f 100644
--- a/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysml/hops/rewrite/HopRewriteUtils.java
@@ -719,6 +719,26 @@ public class HopRewriteUtils
return ternOp;
}
+ public static DnnOp createDnnOp(OpOpDnn op, Hop... hops) {
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ for(Hop h : hops) {
+ inHops.add(h);
+ }
+ return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+ op, inHops);
+ }
+
+ public static DnnOp createDnnOp(HopDagPatternMatcher matcher, OpOpDnn op, String... varNames) {
+ ArrayList<Hop> inHops = new ArrayList<Hop>();
+ for(String v : varNames) {
+ inHops.add(matcher.getMatchedHop(v));
+ }
+ return new DnnOp("tmp", DataType.MATRIX, ValueType.DOUBLE,
+ op, inHops);
+ }
+
+
+
public static void setOutputParameters( Hop hop, long rlen, long clen, int brlen, int bclen, long nnz ) {
hop.setDim1( rlen );
hop.setDim2( clen );