Posted to commits@systemds.apache.org by ba...@apache.org on 2022/08/17 16:30:24 UTC

[systemds] branch main updated: [SYSTEMDS-3303] NN Builtin Attention Layer

This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 6576bad751 [SYSTEMDS-3303] NN Builtin Attention Layer
6576bad751 is described below

commit 6576bad751bcc136183877592b120fcfd1157eb8
Author: Stefan Schörkmeier <s....@student.tugraz.at>
AuthorDate: Mon Apr 11 09:51:38 2022 +0200

    [SYSTEMDS-3303] NN Builtin Attention Layer
    
    This commit adds a new neural network builtin layer for attention.
    
    AMLS project SS2022
    
    Closes #1625
    Closes #1679
    
    Co-authored-by: Anton Postl <an...@student.tugraz.at>
    Co-authored-by: Stefan Schörkmeier <s....@student.tugraz.at>
---
 .gitignore                                         |   1 +
 scripts/nn/examples/AttentionExample.dml           | 479 +++++++++++++++++++++
 scripts/nn/examples/download_attentionExample.sh   |  26 ++
 scripts/nn/layers/attention.dml                    | 142 ++++++
 src/test/scripts/applications/nn/grad_check.dml    |  90 ++++
 .../applications/nn/run_tests_gradients.dml        |   1 +
 6 files changed, 739 insertions(+)

diff --git a/.gitignore b/.gitignore
index f008e4bb9b..f463565a38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,6 +43,7 @@ _site/
 
 # Tutorial data mnist
 src/main/python/systemds/examples/tutorials/*/
+scripts/nn/examples/data/*
 
 # User configuration files
 conf/SystemDS-config.xml
diff --git a/scripts/nn/examples/AttentionExample.dml b/scripts/nn/examples/AttentionExample.dml
new file mode 100644
index 0000000000..f0eb6c2e28
--- /dev/null
+++ b/scripts/nn/examples/AttentionExample.dml
@@ -0,0 +1,479 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#-------------------------------------------------------------
+# A simple example using the attention layer
+# for single-headed self-attention
+# in combination with the LSTM recurrent layer.
+#
+# It uses the clickbait dataset
+# (https://www.kaggle.com/datasets/amananandrai/clickbait-dataset?select=clickbait_data.csv),
+# a simple binary text classification task with 32,000 samples.
+#
+# To download the dataset and place it in the expected folder, simply run download_attentionExample.sh.
+#
+#-------------------------------------------------------------
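+
+# One possible way to run this example (the launcher path and working directory are
+# assumptions and may need adjustment for your setup):
+#   1. run download_attentionExample.sh so the data ends up in scripts/nn/examples/data/
+#   2. from the repository root: ./bin/systemds scripts/nn/examples/AttentionExample.dml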
+
+
+source("nn/layers/attention.dml") as attention
+source("nn/layers/affine.dml") as affine
+source("nn/layers/lstm.dml") as lstm
+source("nn/layers/relu.dml") as relu
+source("nn/layers/sigmoid.dml") as sigmoid
+source("nn/optim/adam.dml") as adam
+source("nn/layers/log_loss.dml") as log_loss
+
+
+# 1 get data
+data_loc = "scripts/nn/examples/data/"
+tableschema = "string,int"
+N=32000 # Samples of whole dataset
+n=1000  # Samples to use for training
+max_length = 32 # maximum sequence length
+epochs = 10
+batch_size = 32
+val_size = 100
+
+data = read(data_loc + "clickbait_data.csv", format="csv", header=TRUE, sep=",", data_type="frame", schema=tableschema, cols=2, rows=N)
+
+
+[x_train, y_train, vocab_size] = preprocess(data, max_length, N)
+
+x_train = x_train[1:n]
+y_train = y_train[1:n]
+
+# train network
+
+[biases, weights] = train(x_train, y_train, epochs, batch_size, max_length, vocab_size, val_size)
+
+
+preprocess = function(frame[unknown] data, integer max_length, integer n)
+  return (matrix[double] features, matrix[double] targets, integer vocab_size)
+{
+  /*
+   * Preprocesses the raw text data into integer tokens and shuffles features and targets.
+   *
+   * Inputs:
+   * - data: dataframe with [string, int] schema and n rows.
+   * - max_length: maximum sequence length we use for training
+   * - n: number of samples.
+   *
+   * Outputs:
+   * - features: feature matrix of shape (n, max_length)
+   * - targets: label vector of shape (n, 1)
+   * - vocab_size: vocabulary size, used to define the size of the embedding matrix during training.
+   */
+
+  # map to lowercase, remove non-alphanumeric characters
+  formatted = map(data[,1], "x -> x.toLowerCase().replaceAll(\"[^\\p{Alnum}]+\", \" \").replaceAll(\"[\\s]+\", \" \")")
+  ids = as.frame(seq(1,nrow(formatted),1))
+  formatted = cbind(ids, formatted)
+
+  # prepare feature matrix for the lstm: tokenize and recode into integer token sequences
+  spec = "{\"algo\" : \"split\", \"out\": \"position\", \"tokenize_col\": 2, \"id_cols\": [1]}"
+  tokenized = tokenize(target=formatted, spec=spec, max_tokens=max_length)
+  recode_spec = "{ \"recode\": [C3]}"
+  [tokens, mapping] = transformencode(target=tokenized, spec=recode_spec)
+  features = matrix(0, rows=n, cols=max_length)
+  row_old = as.scalar(tokens[1,1])
+  pos = 1
+  for(i in 1:nrow(tokens))
+  {
+    row = as.scalar(tokens[i,1])
+    if (row != row_old)
+    {
+      row_old = row
+      pos = 1
+    }
+    features[row,pos] = tokens[i,3]
+    pos += 1
+  }
+  features = replace(target=features, pattern = NaN, replacement = -1)
+  features = features + 2
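+  # the shift makes every position (missing values became -1 above, and are now 1)
+  # a valid positive row index into the trainable embedding matrix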
+  vocab_size = as.integer(max(features))
+
+  targets = as.matrix(data[,2])
+
+  #shuffle data
+  r = rand(rows=n, cols=1, min=0, max=1, pdf="uniform")
+  x = order(target=cbind(r,features), by=1)
+  y = order(target=cbind(r,targets), by=1)
+  features = x[,2:ncol(x)]
+  targets = y[,2:ncol(y)]
+}
+
+train = function( matrix[double] x_train,
+                  matrix[double] y_train,
+                  integer epochs,
+                  integer batch_size,
+                  integer max_sequence_length,
+                  integer vocab_size,
+                  integer val_size
+)
+  return(List[unknown] biases, List[unknown] weights)
+{
+  /*
+   * Trains our example model
+   *
+   * Inputs:
+   * - x_train: training features, matrix of shape (N,max_sequence_length).
+   * - y_train: training labels, matrix of shape (N,1).
+   * - epochs: number of epochs to train our model.
+   * - batch_size: batch size we use in each iteration.
+   * - max_sequence_length: maximum sequence length of the data.
+   * - vocab_size: Size of the considered vocabulary.
+   * - val_size: Size of the validation set, which is split off from x_train and y_train.
+   *
+   * Outputs:
+   * - biases: list of biases.
+   * - weights: list of weights.
+   */
+  samples = nrow(x_train)
+  print("Start Training")
+
+  #validation split
+  x_val = x_train[1:val_size]
+  y_val = y_train[1:val_size]
+
+  x_train = x_train[val_size+1:samples]
+  y_train = y_train[val_size+1:samples]
+
+  samples = nrow(x_train)
+  features = ncol(x_train)
+  output_size = 1
+
+  # We use a trainable embedding matrix; each row is the embedding of one word
+  embedding_size = 64
+  W_E = rand(rows=vocab_size, cols=embedding_size)
+
+  # 1 lstm layer
+  lstm_neurons = 150
+  [W_0, b_0, out0, c0] = lstm::init(batch_size, embedding_size, lstm_neurons)
+
+  # 2 attention layer: learnable query projection with half the dimensionality of the lstm output
+  [W_query, b_query] = affine::init(max_sequence_length*lstm_neurons,max_sequence_length*lstm_neurons/2,-1)
+  [W_key, b_key] = affine::init(max_sequence_length*lstm_neurons,max_sequence_length*lstm_neurons,-1)
+
+
+  # 3 dense layer -> (hidden_size)
+  hidden_neurons = 128
+
+  [W_1, b_1] = affine::init(max_sequence_length/2 * lstm_neurons, hidden_neurons, -1)
+
+  # 4 dense layer -> (output_size)
+  [W_2, b_2] = affine::init(hidden_neurons, output_size, -1)
+
+  # 5 sigmoid layer: no weights
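+
+  # resulting forward pass: embedding lookup -> lstm -> query/key affine projections
+  # -> attention -> affine -> relu -> affine -> sigmoid (see predict below)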
+
+  # put weights & biases into list
+  biases = list(b_0, b_1, b_2, b_query, b_key)
+  weights = list(W_0, W_1, W_2, W_E, W_query, W_key)
+
+  #optimizer init
+  [mW_E, vW_E] = adam::init(W_E)
+
+  [mW_query, vW_query] = adam::init(W_query)
+  [mb_query, vb_query] = adam::init(b_query)
+
+  [mW_key, vW_key] = adam::init(W_key)
+  [mb_key, vb_key] = adam::init(b_key)
+
+  [mW_0, vW_0] = adam::init(W_0)
+  [mW_1, vW_1] = adam::init(W_1)
+  [mW_2, vW_2] = adam::init(W_2)
+
+  [mb_0, vb_0] = adam::init(b_0)
+  [mb_1, vb_1] = adam::init(b_1)
+  [mb_2, vb_2] = adam::init(b_2)
+
+  #optimizer params
+  lr = 0.001
+  beta1 = 0.99
+  beta2 = 0.999
+  epsilon = 1e-8
+  t = 0
+
+  #allocate matrices for attention layer
+  dQuery = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
+  dValue = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
+  dKey = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
+  out2 = matrix(0, rows=batch_size, cols=max_sequence_length*embedding_size)
+
+  #training loop
+  iters = ceil(samples/batch_size)
+  for (ep in 1:epochs)
+  {
+    print("Start ep: " + ep)
+    for (i in 1:iters)
+    {
+      print("Iteration: " + i)
+      # 1 Get batch data
+      start = ((i-1) * batch_size) %% samples + 1
+      end = min(samples, start + batch_size -1)
+
+      x_batch = x_train[start:end,]
+      y_batch = y_train[start:end,]
+
+      # 2 predict
+      [y_hat, out5, out4, out3, out2, out1, query, key, emb, cache_out_out, cache_c_out, cache_ifog_out] =
+          predict(x_batch, biases, weights, max_sequence_length, embedding_size, lstm_neurons, out2)
+
+      # 3 backpropagation
+      dout = log_loss::backward(y_hat, y_batch)
+      dprobs = sigmoid::backward(dout, out5)
+      [dout_2, dW_2, db_2] = affine::backward(dprobs, out4, W_2, b_2)
+      drelu = relu::backward(dout_2, out3)
+      [dout_1, dW_1, db_1] = affine::backward(drelu, out2, W_1, b_1)
+      [dQuery, dValue, dKey] = attention::backward(dattention=dout_1,
+                                                    query=query,
+                                                    key=key,
+                                                    value=out1,
+                                                    D=max_sequence_length,
+                                                    dquery=dQuery,
+                                                    dvalue=dValue,
+                                                    dkey=dKey)
+      [dq, dW_query, db_query] = affine::backward(dQuery, out1, W_query, b_query)
+      [dk, dW_key, db_key] = affine::backward(dKey, out1, W_key, b_key)
+      dc = matrix(0, rows=nrow(x_batch), cols=lstm_neurons)
+      out0 = dc
+      c0 = dc
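+      # the forward pass used zero initial hidden/cell states, so zero matrices are
+      # passed as out0/c0 to the lstm backward call as well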
+      [dEmb, dW_0, db_0, dout0, dc0] = lstm::backward(dValue + dq + dk,
+                                                        dc,
+                                                        emb,
+                                                        W_0,
+                                                        b_0,
+                                                        max_sequence_length,
+                                                        embedding_size,
+                                                        TRUE,
+                                                        out0,
+                                                        c0,
+                                                        cache_out_out,
+                                                        cache_c_out,
+                                                        cache_ifog_out)
+
+      # 4 update weights & biases
+      t = ep * i - 1
+      # embedding
+      [W_E, mW_E, vW_E] = update_embeddings(x_batch, dEmb, W_E, mW_E, vW_E,
+        lr, beta1, beta2, epsilon, t, max_sequence_length, embedding_size)
+
+      # lstm
+      [b_0, mb_0, vb_0] = adam::update(b_0, db_0, lr, beta1, beta2, epsilon, t, mb_0, vb_0)
+      [W_0, mW_0, vW_0] = adam::update(W_0, dW_0, lr, beta1, beta2, epsilon, t, mW_0, vW_0)
+
+      # affine query
+      [W_query, mW_query, vW_query] = adam::update(W_query, dW_query, lr, beta1, beta2, epsilon, t, mW_query, vW_query)
+      [b_query, mb_query, vb_query] = adam::update(b_query, db_query, lr, beta1, beta2, epsilon, t, mb_query, vb_query)
+
+      # affine key
+      [W_key, mW_key, vW_key] = adam::update(W_key, dW_key, lr, beta1, beta2, epsilon, t, mW_key, vW_key)
+      [b_key, mb_key, vb_key] = adam::update(b_key, db_key, lr, beta1, beta2, epsilon, t, mb_key, vb_key)
+
+
+      # hidden affine
+      [b_1, mb_1, vb_1] = adam::update(b_1, db_1, lr, beta1, beta2, epsilon, t, mb_1, vb_1)
+      [W_1, mW_1, vW_1] = adam::update(W_1, dW_1, lr, beta1, beta2, epsilon, t, mW_1, vW_1)
+
+      # output affine
+      [b_2, mb_2, vb_2] = adam::update(b_2, db_2, lr, beta1, beta2, epsilon, t, mb_2, vb_2)
+      [W_2, mW_2, vW_2] = adam::update(W_2, dW_2, lr, beta1, beta2, epsilon, t, mW_2, vW_2)
+
+      # put weights & biases into list
+      biases = list(b_0,b_1,b_2,b_query, b_key)
+      weights = list(W_0,W_1,W_2,W_E,W_query, W_key)
+    }
+    [loss, accuracy] = evaluate(x_train, y_train, biases, weights, lstm_neurons, max_sequence_length, embedding_size, out2)
+    [val_loss, val_accuracy] = evaluate(x_val, y_val, biases, weights, lstm_neurons, max_sequence_length, embedding_size, out2)
+    print("Epoch: " + ep + "; Train Loss: " + loss + "; Train Acc: " + accuracy +"; Val. Loss: " + val_loss + "; Val. Accuracy: " + val_accuracy)
+  }
+}
+
+predict = function( matrix[double] x,
+                    List[unknown] biases,
+                    List[unknown] weights,
+                    integer max_sequence_length,
+                    integer embedding_size,
+                    integer lstm_neurons,
+                    matrix[double] out2
+)
+  return (matrix[double] y_hat, matrix[double] out5, matrix[double] out4, matrix[double] out3,
+          matrix[double] out2, matrix[double] out1, matrix[double] query, matrix[double] key, matrix[double] emb, matrix[double] cache_out_out,
+          matrix[double] cache_c_out, matrix[double] cache_ifog_out)
+{
+  /*
+   * Predicts an output y_hat for given samples x.
+   *
+   * Inputs:
+   * - x: sample features of shape(batch_size, max_sequence_length)
+   * - biases: list of biases of length 3 (lstm, affine, affine)
+   * - weights: list of weights of length 4 (lstm, affine, affine, embedding)
+   * - max_sequence_length: number of words per sample.
+   * - embedding_size: size of embedding vector for 1 word
+   * - lstm_neurons: number of neurons in lstm layer.
+   * - out2: matrix of shape (batch_size, max_sequence_length*embedding_size), preallocated output buffer for the attention layer.
+   *
+   * Outputs:
+   * - y_hat: matrix of shape(batch_size, 1), prediction for log-loss, output of sigmoid layer
+   * - out5: output of final affine layer, shape(batch_size, 1)
+   * - out4: output of relu layer
+   * - out3: output of hidden affine layer
+   * - out2: output of attention layer
+   * - out1: output states from lstm layer, of shape(batch_size, max_sequence_length * lstm_neurons)
+   * - query: transformation of out1 by affine layer, of shape(batch_size, max_sequence_length * lstm_neurons/2)
+   * - key: transformation of out1 by affine layer, of shape(batch_size, max_sequence_length * lstm_neurons)
+   * - cache_out_out: cache_out output from lstm layer
+   * - cache_c_out: cache_c output from lstm layer
+   * - cache_ifog_out: cache_ifog output from lstm layer
+   */
+
+  # unpack weights & biases
+  W_0 = as.matrix(weights[1])
+  W_1 = as.matrix(weights[2])
+  W_2 = as.matrix(weights[3])
+  W_E = as.matrix(weights[4])
+  W_query = as.matrix(weights[5])
+  W_key = as.matrix(weights[6])
+
+  b_0 = as.matrix(biases[1])
+  b_1 = as.matrix(biases[2])
+  b_2 = as.matrix(biases[3])
+  b_query = as.matrix(biases[4])
+  b_key = as.matrix(biases[5])
+
+  # fetch embedding
+  emb = fetch_embeddings(x, W_E, max_sequence_length, embedding_size)
+  # put input through layers
+  batch_size = nrow(x)
+  out0 = matrix(0, batch_size, lstm_neurons)
+  c0 = out0
+  [out1, c_out, cache_out_out, cache_c_out, cache_ifog_out]=
+    lstm::forward(emb, W_0, b_0, max_sequence_length, embedding_size, TRUE, out0, c0)
+  query = affine::forward(out1, W_query, b_query)
+  key = affine::forward(out1, W_key, b_key)
+  out2 = attention::forward(query=query,key=key, value=out1, D=max_sequence_length, attention=out2)
+  out3 = affine::forward(out2, W_1, b_1)
+  out4 = relu::forward(out3)
+  out5 = affine::forward(out4, W_2, b_2)
+  y_hat = sigmoid::forward(out5)
+}
+
+fetch_embeddings = function(matrix[double] indexes, matrix[double] W_E,
+  integer max_sequence_length, integer embedding_size)
+  return(matrix[double] emb)
+{
+  /*
+   * Fetches embeddings for given tokens (indexes).
+   *
+   * Inputs:
+   * - indexes: tokens for fetching embeddings, shape(batch_size, max_sequence_length)
+   * - W_E: trainable embedding matrix of shape(vocab_size, embedding_size)
+   * - max_sequence_length: number of words per sample.
+   * - embedding_size: size of an embedding vector for 1 word.
+   *
+   * Outputs:
+   * - emb: embedded version of indexes of shape(batch_size, max_sequence_length * embedding_size)
+   */
+
+  emb = matrix(0, rows=nrow(indexes), cols=embedding_size*max_sequence_length)
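+  # token j of sample i fills columns (j-1)*embedding_size+1 : j*embedding_size of emb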
+  for (i in 1:nrow(indexes))
+  {
+    for (j in 1:max_sequence_length)
+    {
+      index = as.integer(as.scalar(indexes[i,j]))
+      emb[i,(j-1)*embedding_size+1:j*embedding_size] = W_E[index]
+    }
+  }
+}
+
+update_embeddings = function(matrix[double] indexes, matrix[double] dEmb, matrix[double] W_E,
+  matrix[double] mW_E, matrix[double] vW_E, double lr, double beta1, double beta2,
+  double epsilon, integer t, integer max_sequence_length, integer embedding_size)
+  return (matrix[double] W_E, matrix[double] mW_E, matrix[double] vW_E)
+{
+  /*
+   * Updates embedding matrix for given tokens (indexes).
+   *
+   * Inputs:
+   * - indexes: tokens for fetching embeddings, shape(batch_size, max_sequence_length)
+   * - dEmb: gradient from upstream of shape(batch_size, max_sequence_length * embedding_size)
+   * - W_E: trainable embedding matrix of shape(vocab_size, embedding_size)
+   * - mW_E: m variable (1st moment estimate) for adam optimizer
+   * - vW_E: v variable (2nd moment estimate) for adam optimizer
+   * - lr: learning rate
+   * - beta1: exponential decay rate for 1st moment estimate for adam optimizer
+   * - beta2: exponential decay rate for 2nd moment estimate for adam optimizer
+   * - epsilon: for numerical stability of adam optimizer
+   * - t: timestep for adam optimizer
+   * - max_sequence_length: number of words per sample.
+   * - embedding_size: size of an embedding vector for 1 word.
+   *
+   * Outputs:
+   * - W_E: updated trainable embedding matrix of shape(vocab_size, embedding_size)
+   * - mW_E: updated m variable (1st moment estimate) for adam optimizer
+   * - vW_E: updated v variable (2nd moment estimate) for adam optimizer
+   */
+  for (i in 1:nrow(indexes))
+  {
+    for (j in 1:max_sequence_length)
+    {
+      index = as.integer(as.scalar(indexes[i,j]))
+      [W_Ei, mW_Ei, vW_Ei] = adam::update(
+        W_E[index], dEmb[i,(j-1)*embedding_size+1:j*embedding_size], lr, beta1, beta2,
+        epsilon, t, mW_E[index], vW_E[index])
+      W_E[index] = W_Ei
+      mW_E[index] = mW_Ei
+      vW_E[index] = vW_Ei
+    }
+  }
+}
+
+evaluate = function(matrix[double] x, matrix[double] y,
+  list[unknown] biases, list[unknown] weights, integer lstm_neurons,
+  integer max_sequence_length, integer embedding_size, matrix[double] out2)
+  return(double loss, double accuracy)
+{
+  /*
+  * Evaluates the fit by calculating the log loss and accuracy using the predict function.
+  *
+  * Inputs:
+  * - x: feature matrix of shape(batch_size, max_sequence_length)
+  * - y: target matrix for x of shape(batch_size, 1)
+  * - biases: list of biases
+  * - weights: list of weights
+  * - lstm_neurons: number of neurons used in lstm layer
+  * - max_sequence_length: number of words per sample.
+  * - embedding_size: size of an embedding vector for 1 word.
+  * - out2: matrix of shape (batch_size, max_sequence_length*embedding_size), preallocated output buffer for the attention layer.
+  *
+  * Outputs:
+  * - loss: log_loss of prediction
+  * - accuracy: percentage of correct classifications
+  */
+  batch_size = nrow(x)
+  [y_hat, out5, out4, out3, out2, out1, query, key, emb, cache_out_out, cache_c_out, cache_ifog_out] =
+    predict(x, biases, weights, max_sequence_length, embedding_size, lstm_neurons, out2)
+  loss = log_loss::forward(y_hat, y)
+
+  z = y_hat >= 0.5
+  accuracy = 1 - sum(abs(z - y)) / batch_size
+}
+
+
diff --git a/scripts/nn/examples/download_attentionExample.sh b/scripts/nn/examples/download_attentionExample.sh
new file mode 100755
index 0000000000..bee982a23a
--- /dev/null
+++ b/scripts/nn/examples/download_attentionExample.sh
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+mkdir -p ../nn/examples/data
+cd ../nn/examples/data
+wget https://systemds.apache.org/assets/datasets/clickbait/clickbait.7z
+7z e clickbait.7z
+mv clickbait.csv clickbait_data.csv
diff --git a/scripts/nn/layers/attention.dml b/scripts/nn/layers/attention.dml
new file mode 100644
index 0000000000..3d80324ddf
--- /dev/null
+++ b/scripts/nn/layers/attention.dml
@@ -0,0 +1,142 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+source("nn/layers/softmax.dml") as softmax
+source("scripts/nn/util.dml") as util
+
+
+forward = function(matrix[double] query, matrix[double] value,
+       matrix[double] key = matrix("",rows=0, cols=0), integer D,
+       matrix[double] attention)
+    return (matrix[double] attention) {
+  /*
+   * Computes the forward pass for the attention layer.
+   *
+   * Inputs:
+   * - query: Input queries of shape (N,J*D).
+   * - value: Values for the keys, of shape (N,K*D).
+   * - key: *optional* Keys for the values, of shape (N,K*D).
+   *        If key is a matrix of length 0, value is used as the key.
+   * - D: Dimensionality of a single query, value, and key.
+   * - attention: Matrix of shape (N,J*D) into which the output is written.
+   *
+   * Outputs:
+   * - attention: Attention on value(s) for given query(s), of shape (N,J*D).
+   */
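+  # Per sample n, this computes scaled dot-product attention:
+  #   scores_n    = query_n %*% t(key_n) / sqrt(D)    # shape (J,K)
+  #   probs_n     = column-wise softmax of scores_n   # normalized over the J queries per key
+  #   attention_n = probs_n %*% value_n               # shape (J,D), flattened to a row of length J*D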
+  N = nrow(value)
+  K = ncol(value) / D
+  J = ncol(query) / D
+
+  norm = 1/D^0.5
+  if (!length(key))
+  {
+    key = value
+  }
+  key_norm = key * norm
+  attention = matrix(0, rows=N, cols=J*D)
+  query_n = matrix(0, rows=J, cols=D)
+  key_norm_n = matrix(0, rows=K, cols=D)
+  value_n = matrix(0, rows=K, cols=D)
+  probs = matrix(0, rows=J, cols=K)
+  scores = matrix(0, rows=J, cols=K)
+  for (n in 1:N)
+  {
+    #reshape
+    query_n = matrix(query[n], rows=J, cols=D)
+    key_norm_n = matrix(key_norm[n],rows=K, cols=D)
+    value_n = matrix(value[n], rows=K, cols=D)
+
+    scores = query_n %*% t(key_norm_n)
+    # column-wise softmax
+    probs = t(softmax::forward(t(scores)))
+    attention[n] = matrix(probs %*% value_n, rows=1, cols=J*D)
+  }
+}
+
+backward = function(matrix[double] dattention,
+                  matrix[double] query, matrix[double] value, matrix[double] key=matrix("",rows=0,cols=0),
+                  integer D, matrix[double] dquery, matrix[double] dvalue, matrix[double] dkey)
+    return (matrix[double] dquery, matrix[double] dvalue, matrix[double] dkey)
+{
+  /*
+   * Computes the backward pass for the attention layer.
+   *
+   * Inputs:
+   * - dattention: Gradient wrt `attention` of shape (N,J*D).
+   * - query: Query input of shape (N,J*D).
+   * - key: *optional* Keys for the values, of shape (N,K*D).
+   *        If key is of length 0, the gradient w.r.t. the key is folded into dvalue and dkey stays 0.
+   * - value: Values for the given keys, of shape (N,K*D).
+   * - D: Dimensionality of a single query, key, and value.
+   * - dquery: Matrix of shape (N,J*D) for output allocation.
+   * - dvalue: Matrix of shape (N,K*D) for output allocation.
+   * - dkey: Matrix of shape (N,K*D) for output allocation.
+   *
+   * Outputs:
+   * - dquery: Gradient wrt `query`, of shape (N,J*D).
+   * - dvalue: Gradient wrt `value`, of shape (N,K*D).
+   * - dkey: Gradient wrt `key`, of shape (N,K*D).
+   */
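+  # Per sample n, with probs and scores as in the forward pass:
+  #   dprobs  = dattention_n %*% t(value_n)
+  #   dscores = column-wise softmax backward applied to dprobs
+  #   dquery  = dscores %*% key_n / sqrt(D)
+  #   dkey    = t(dscores) %*% query_n / sqrt(D)
+  #   dvalue  = t(probs_n) %*% dattention_n (plus the key gradient when value doubles as key)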
+
+  N = nrow(value)
+  K = ncol(value) / D
+  J = ncol(query) / D
+
+  norm = 1 / D^0.5
+  use_key = length(key) > 0
+  if (!use_key){
+    key = value
+  }
+  key_norm = key * norm
+
+  dquery = matrix(0, rows=N, cols=J*D)
+  dkey = matrix(0, rows=N, cols=K*D)
+  dvalue = matrix(0, rows=N, cols=K*D)
+
+  query_n = matrix(0, rows=J, cols=D)
+  key_norm_n = matrix(0, rows=K, cols=D)
+  value_n = matrix(0, rows=K, cols=D)
+  dvalue_n = matrix(0, rows=K, cols=D)
+  probs = matrix(0, rows=J, cols=K)
+  scores = matrix(0, rows=J, cols=K)
+  for (n in 1:N)
+  {
+    #reshape
+    query_n = matrix(query[n], rows=J, cols=D)
+    key_norm_n = matrix(key_norm[n], rows=K, cols=D)
+    value_n = matrix(value[n], rows=K, cols=D)
+    dattention_n = matrix(dattention[n], rows=J, cols=D)
+
+    scores = query_n %*% t(key_norm_n)
+    probs = t(softmax::forward(t(scores)))
+
+    dscore = t(softmax::backward(value_n %*% t(dattention_n), t(scores)))
+    dquery[n] = matrix(dscore %*% key_norm_n, rows=1, cols=J*D)
+    if (use_key){
+      dkey[n] = matrix(t(dscore) %*% query_n * norm, rows=1, cols=K*D)
+      dvalue[n] = matrix(t(probs) %*% dattention_n, rows=1, cols=K*D)
+    }
+    else{
+      dvalue[n] = matrix(t(probs) %*% dattention_n + t(dscore) %*% query_n * norm, rows=1, cols=K*D)
+    }
+
+  }
+}
diff --git a/src/test/scripts/applications/nn/grad_check.dml b/src/test/scripts/applications/nn/grad_check.dml
index e0a3740539..b5353af8fd 100644
--- a/src/test/scripts/applications/nn/grad_check.dml
+++ b/src/test/scripts/applications/nn/grad_check.dml
@@ -59,6 +59,7 @@ source("src/test/scripts/applications/nn/max_pool2d_simple.dml") as max_pool2d_s
 source("src/test/scripts/applications/nn/util.dml") as test_util
 source("scripts/nn/util.dml") as util
 source("scripts/nn/layers/elu.dml") as elu
+source("scripts/nn/layers/attention.dml") as attention
 
 affine = function() {
   /*
@@ -2537,3 +2538,92 @@ elu = function() {
      }
    }
 }
+
+attention = function() {
+  /*
+  * Gradient check for Attention layer
+  */
+  print("Grad checking Attention layer with L2 loss.")
+  # initialize random Query, Key and Value
+  N = 5 # number of samples
+  J = 2 # queries
+  K = 4 # keys,values
+  D = 3 # key,query,value dimension
+
+  query = rand(rows=N, cols=J*D, min=-5, max=5)
+  key = rand(rows=N, cols=K*D, min=-5, max=5)
+  value = rand(rows=N, cols=K*D, min=-5, max=5)
+
+  # initialize random target
+  y = rand(rows=N, cols=J*D, min=-5, max=5)
+
+  # cycle 1: using key
+  # cycle 2: using value as key
+  for (cycle in 1:2)
+  {
+    print("Cycle " + cycle)
+    if (cycle == 2){
+      key = matrix("", rows=0, cols=0)
+    }
+    # compute gradient analytically
+    att = matrix(0, rows=N, cols=J*D)
+    att = attention::forward(query, value, key, D, att)
+    datt = l2_loss::backward(att, y)
+    dQ = matrix(0, rows=N, cols=J*D)
+    dV = matrix(0, rows=N, cols=K*D)
+    dK = matrix(0, rows=N, cols=K*D)
+    [dQ, dV, dK] = attention::backward(datt, query, value, key, D, dQ, dV, dK)
+
+
+    # check gradient numerically
+    h = 1e-4
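+    # central-difference estimate of each partial derivative, (L(x+h) - L(x-h)) / (2*h),
+    # compared against the analytic gradient via check_rel_grad_error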
+
+    print(" - Grad checking query.")
+    for (i in 1:nrow(query)){
+      for (j in 1:ncol(query)){
+        old = as.scalar(query[i,j])
+        query[i,j] = old - h
+        outmh = attention::forward(query, value, key, D, att)
+        lossmh = l2_loss::forward(outmh, y)
+        query[i,j] = old + h
+        outph = attention::forward(query, value, key, D, att)
+        lossph = l2_loss::forward(outph, y)
+        query[i,j] = old  # reset
+        dQ_num = (lossph-lossmh) / (2*h)
+        rel_error = test_util::check_rel_grad_error(as.scalar(dQ[i,j]), dQ_num, lossph, lossmh)
+      }
+    }
+    if (cycle == 1){
+      print(" - Grad checking key.")
+      for (i in 1:nrow(key)){
+        for (j in 1:ncol(key)){
+          old = as.scalar(key[i,j])
+          key[i,j] = old - h
+          outmh = attention::forward(query, value, key, D, att)
+          lossmh = l2_loss::forward(outmh, y)
+          key[i,j] = old + h
+          outph = attention::forward(query, value, key, D, att)
+          lossph = l2_loss::forward(outph, y)
+          key[i,j] = old  # reset
+          dK_num = (lossph-lossmh) / (2*h)
+          rel_error = test_util::check_rel_grad_error(as.scalar(dK[i,j]), dK_num, lossph, lossmh)
+        }
+      }
+    }
+    print(" - Grad checking value.")
+    for (i in 1:nrow(value)){
+      for (j in 1:ncol(value)){
+        old = as.scalar(value[i,j])
+        value[i,j] = old - h
+        outmh = attention::forward(query, value, key, D, att)
+        lossmh = l2_loss::forward(outmh, y)
+        value[i,j] = old + h
+        outph = attention::forward(query, value, key, D, att)
+        lossph = l2_loss::forward(outph, y)
+        value[i,j] = old  # reset
+        dV_num = (lossph-lossmh) / (2*h)
+        rel_error = test_util::check_rel_grad_error(as.scalar(dV[i,j]), dV_num, lossph, lossmh)
+      }
+    }
+  }
+}
diff --git a/src/test/scripts/applications/nn/run_tests_gradients.dml b/src/test/scripts/applications/nn/run_tests_gradients.dml
index f81f500eda..e827bef262 100644
--- a/src/test/scripts/applications/nn/run_tests_gradients.dml
+++ b/src/test/scripts/applications/nn/run_tests_gradients.dml
@@ -33,6 +33,7 @@ grad_check::log_loss()
 
 # Core layers
 grad_check::affine()
+grad_check::attention()
 grad_check::low_rank_affine()
 grad_check::batch_norm1d()
 grad_check::batch_norm2d()