Posted to commits@systemml.apache.org by mb...@apache.org on 2020/06/02 21:40:23 UTC

[systemml] branch master updated: [SYSTEMDS-397] Neural Collaborative Filtering (NCF) algorithm script

This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new 5e0d5c4  [SYSTEMDS-397] Neural Collaborative Filtering (NCF) algorithm script
5e0d5c4 is described below

commit 5e0d5c45162c7f26e6003659d58f091a3a794f11
Author: Patrick Deutschmann <pa...@student.tugraz.at>
AuthorDate: Tue Jun 2 23:29:57 2020 +0200

    [SYSTEMDS-397] Neural Collaborative Filtering (NCF) algorithm script
    
    AMLS project SS2020.
    Closes #925.
---
 dev/docs/Tasks.txt                                 |   1 +
 .../Example - Neural Collaborative Filtering.ipynb | 347 +++++++++++++++++++++
 scripts/nn/examples/README.md                      |   7 +
 scripts/nn/examples/ncf-dummy-data.dml             |  57 ++++
 scripts/nn/examples/ncf-real-data.dml              |  65 ++++
 scripts/staging/NCF.dml                            | 330 ++++++++++++++++++++
 6 files changed, 807 insertions(+)

diff --git a/dev/docs/Tasks.txt b/dev/docs/Tasks.txt
index f3d4acd..8c6b306 100644
--- a/dev/docs/Tasks.txt
+++ b/dev/docs/Tasks.txt
@@ -311,6 +311,7 @@ SYSTEMDS-390 New Builtin Functions IV
  * 394 Builtin for one-hot encoding of matrix (not frame), see table  OK
  * 395 SVM rework and utils (confusionMatrix, msvmPredict)            OK
  * 396 Builtin for counting number of distinct values                 OK
+ * 397 Algorithm for neural collaborative filtering (NCF)             OK
 
 SYSTEMDS-400 Spark Backend Improvements
  * 401 Fix output block indexes of rdiag (diagM2V)                    OK
diff --git a/scripts/nn/examples/Example - Neural Collaborative Filtering.ipynb b/scripts/nn/examples/Example - Neural Collaborative Filtering.ipynb
new file mode 100644
index 0000000..5c047fd
--- /dev/null
+++ b/scripts/nn/examples/Example - Neural Collaborative Filtering.ipynb	
@@ -0,0 +1,347 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Neural Collaborative Filtering (NCF)\n",
+    "\n",
+    "This example trains a neural network on the MovieLens data set using [Neural Collaborative Filtering (NCF)](https://dl.acm.org/doi/abs/10.1145/3038912.3052569), an approach that tackles recommendation problems with deep neural networks instead of conventional matrix factorization."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup and Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.model_selection import train_test_split"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Data - MovieLens"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The MovieLens data set is provided by the University of Minnesota and the GroupLens Research Group:\n",
+    "\n",
+    "> This dataset (ml-latest-small) describes 5-star rating and free-text tagging activity from [MovieLens](http://movielens.org/), a movie recommendation service. It contains 100836 ratings and 3683 tag applications across 9742 movies. These data were created by 610 users between March 29, 1996 and September 24, 2018. This dataset was generated on September 26, 2018.<br/>\n",
+    "Users were selected at random for inclusion. All selected users had rated at least 20 movies. No demographic information is included. Each user is represented by an id, and no other information is provided.<br/>\n",
+    "The data are contained in the files links.csv, movies.csv, ratings.csv and tags.csv. More details about the contents and use of all these files follows.<br/>\n",
+    "This is a development dataset. As such, it may change over time and is not an appropriate dataset for shared research results. See available benchmark datasets if that is your intent.<br/>\n",
+    "This and other GroupLens data sets are publicly available for download at http://grouplens.org/datasets/."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Archive:  ml-latest-small.zip\n",
+      "   creating: ml-latest-small/\n",
+      "  inflating: ml-latest-small/links.csv  \n",
+      "  inflating: ml-latest-small/tags.csv  \n",
+      "  inflating: ml-latest-small/ratings.csv  \n",
+      "  inflating: ml-latest-small/README.txt  \n",
+      "  inflating: ml-latest-small/movies.csv  \n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
+      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
+      "\r",
+      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0\r",
+      "  5  955k    5 50411    0     0  68679      0  0:00:14 --:--:--  0:00:14 68586\r",
+      "100  955k  100  955k    0     0   640k      0  0:00:01  0:00:01 --:--:--  640k\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%sh\n",
+    "DATASET=ml-latest-small\n",
+    "\n",
+    "mkdir -p data/$DATASET/\n",
+    "cd data/$DATASET\n",
+    "curl -O http://files.grouplens.org/datasets/movielens/$DATASET.zip\n",
+    "unzip $DATASET.zip"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prepare Data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_loc = \"data/ml-latest-small/ml-latest-small/\"\n",
+    "negative_split = 1.5  # number of negatives sampled per positive\n",
+    "\n",
+    "# load interactions from MovieLens\n",
+    "raw_ratings = pd.read_csv(data_loc + \"ratings.csv\")\n",
+    "positives = pd.DataFrame(raw_ratings, columns=['userId', 'movieId'])\n",
+    "\n",
+    "# sample negatives\n",
+    "positive_pairs = set(zip(positives[\"userId\"], positives[\"movieId\"]))\n",
+    "negatives = []\n",
+    "while len(negatives) < len(positives) * negative_split:\n",
+    "    user = positives[\"userId\"].sample().values[0]\n",
+    "    movie = positives[\"movieId\"].sample().values[0]\n",
+    "    if (user, movie) not in positive_pairs:  # set lookup instead of a DataFrame scan\n",
+    "        negatives.append((user, movie))  # list append; DataFrame.append was removed in pandas 2.0\n",
+    "\n",
+    "# write out final data\n",
+    "targets = np.hstack([np.ones(len(positives)), np.zeros(len(negatives))])\n",
+    "all_ratings = np.vstack([positives, negatives])\n",
+    "\n",
+    "user_item_targets = np.hstack([all_ratings, targets[:, np.newaxis]])\n",
+    "\n",
+    "np.random.shuffle(user_item_targets)\n",
+    "\n",
+    "split = train_test_split(user_item_targets, train_size=0.8)\n",
+    "\n",
+    "np.savetxt(data_loc + \"sampled-train.csv\", split[0], delimiter=\",\")\n",
+    "np.savetxt(data_loc + \"sampled-test.csv\", split[1], delimiter=\",\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## SystemDS NCF implementation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Train"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### with synthetic dummy data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using user supplied systemds jar file target/SystemDS.jar\n",
+      "###############################################################################\n",
+      "#  SYSTEMDS_ROOT= .\n",
+      "#  SYSTEMDS_JAR_FILE= target/SystemDS.jar\n",
+      "#  CONFIG_FILE= --config ./target/testTemp/org/apache/sysds/api/mlcontext/MLContext/SystemDS-config.xml\n",
+      "#  LOG4JPROP= -Dlog4j.configuration=file:conf/log4j-silent.properties\n",
+      "#  CLASSPATH= target/SystemDS.jar:./lib/*:./target/lib/*\n",
+      "#  HADOOP_HOME= /Users/patrick/Uni Offline/Architectures of Machine Learning Systems (AMLS)/systemml/target/lib/hadoop\n",
+      "#\n",
+      "#  Running script scripts/nn/examples/ncf-dummy-data.dml locally with opts: \n",
+      "###############################################################################\n",
+      "Executing command:     java       -Xmx4g      -Xms4g      -Xmn400m   -cp target/SystemDS.jar:./lib/*:./target/lib/*   -Dlog4j.configuration=file:conf/log4j-silent.properties   org.apache.sysds.api.DMLScript   -f scripts/nn/examples/ncf-dummy-data.dml   -exec singlenode   --config ./target/testTemp/org/apache/sysds/api/mlcontext/MLContext/SystemDS-config.xml   \n",
+      "\n",
+      "NCF training starting with 1000 training samples, 100 validation samples, 50 items and 60 users...\n",
+      "Epoch: 1, Iter: 1, Train Loss: 0.6953457411615849, Train Accuracy: 0.5, Val Loss: 0.6995101788248107, Val Accuracy: 0.47\n",
+      "Epoch: 2, Iter: 1, Train Loss: 0.6667911468574823, Train Accuracy: 0.6875, Val Loss: 0.6992050630414124, Val Accuracy: 0.47\n",
+      "Epoch: 3, Iter: 1, Train Loss: 0.6570450250431727, Train Accuracy: 0.6875, Val Loss: 0.7014387912966833, Val Accuracy: 0.47\n",
+      "Epoch: 4, Iter: 1, Train Loss: 0.6521926651745862, Train Accuracy: 0.6875, Val Loss: 0.7053126102214489, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 5, Iter: 1, Train Loss: 0.6431405119563119, Train Accuracy: 0.6875, Val Loss: 0.7115121778198469, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 6, Iter: 1, Train Loss: 0.6353498336109219, Train Accuracy: 0.6875, Val Loss: 0.7193490066131873, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 7, Iter: 1, Train Loss: 0.6308046978859394, Train Accuracy: 0.6875, Val Loss: 0.7306240107462888, Val Accuracy: 0.48\n",
+      "Epoch: 8, Iter: 1, Train Loss: 0.6260145322748087, Train Accuracy: 0.75, Val Loss: 0.7435853055111923, Val Accuracy: 0.49\n",
+      "Epoch: 9, Iter: 1, Train Loss: 0.6163475345953953, Train Accuracy: 0.6875, Val Loss: 0.757023909929672, Val Accuracy: 0.5\n",
+      "Epoch: 10, Iter: 1, Train Loss: 0.6029424406867099, Train Accuracy: 0.6875, Val Loss: 0.7749021987872134, Val Accuracy: 0.51\n",
+      "Epoch: 11, Iter: 1, Train Loss: 0.5791958103856243, Train Accuracy: 0.8125, Val Loss: 0.7921418272873325, Val Accuracy: 0.51\n",
+      "Epoch: 12, Iter: 1, Train Loss: 0.5543597535155846, Train Accuracy: 0.8125, Val Loss: 0.8131440342665028, Val Accuracy: 0.5\n",
+      "Epoch: 13, Iter: 1, Train Loss: 0.5342062981571314, Train Accuracy: 0.8125, Val Loss: 0.8340415360672659, Val Accuracy: 0.45999999999999996\n",
+      "Epoch: 14, Iter: 1, Train Loss: 0.5156903349054259, Train Accuracy: 0.875, Val Loss: 0.8534000391024407, Val Accuracy: 0.47\n",
+      "Epoch: 15, Iter: 1, Train Loss: 0.5042912981017884, Train Accuracy: 0.8125, Val Loss: 0.873901869293276, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 16, Iter: 1, Train Loss: 0.48722704019844537, Train Accuracy: 0.8125, Val Loss: 0.898510539121238, Val Accuracy: 0.47\n",
+      "Epoch: 17, Iter: 1, Train Loss: 0.47048381704431463, Train Accuracy: 0.875, Val Loss: 0.9284775525937294, Val Accuracy: 0.48\n",
+      "Epoch: 18, Iter: 1, Train Loss: 0.45151030675588855, Train Accuracy: 0.875, Val Loss: 0.9574504971357228, Val Accuracy: 0.47\n",
+      "Epoch: 19, Iter: 1, Train Loss: 0.43940495503523824, Train Accuracy: 0.875, Val Loss: 0.9937811553464448, Val Accuracy: 0.45999999999999996\n",
+      "Epoch: 20, Iter: 1, Train Loss: 0.42553379542786246, Train Accuracy: 0.875, Val Loss: 1.0231502880025147, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 21, Iter: 1, Train Loss: 0.4163223594480222, Train Accuracy: 0.875, Val Loss: 1.0595479122098816, Val Accuracy: 0.45999999999999996\n",
+      "Epoch: 22, Iter: 1, Train Loss: 0.4050461773338017, Train Accuracy: 0.875, Val Loss: 1.0944624240337406, Val Accuracy: 0.48\n",
+      "Epoch: 23, Iter: 1, Train Loss: 0.3957080838041942, Train Accuracy: 0.875, Val Loss: 1.1315613394576827, Val Accuracy: 0.47\n",
+      "Epoch: 24, Iter: 1, Train Loss: 0.39252816032717697, Train Accuracy: 0.8125, Val Loss: 1.1608315131205158, Val Accuracy: 0.47\n",
+      "Epoch: 25, Iter: 1, Train Loss: 0.38656611677400526, Train Accuracy: 0.8125, Val Loss: 1.2010764396137235, Val Accuracy: 0.45999999999999996\n",
+      "Epoch: 26, Iter: 1, Train Loss: 0.3910140006546419, Train Accuracy: 0.8125, Val Loss: 1.2394434665872176, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 27, Iter: 1, Train Loss: 0.39012809759646405, Train Accuracy: 0.8125, Val Loss: 1.267704284952889, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 28, Iter: 1, Train Loss: 0.3986668930898999, Train Accuracy: 0.8125, Val Loss: 1.3134788291583197, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 29, Iter: 1, Train Loss: 0.39096586484137014, Train Accuracy: 0.8125, Val Loss: 1.3457368548231847, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 30, Iter: 1, Train Loss: 0.3913665786483714, Train Accuracy: 0.8125, Val Loss: 1.395200160764677, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 31, Iter: 1, Train Loss: 0.39306020872450564, Train Accuracy: 0.8125, Val Loss: 1.4547617764166234, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 32, Iter: 1, Train Loss: 0.3961123079325197, Train Accuracy: 0.8125, Val Loss: 1.4988918781732432, Val Accuracy: 0.45999999999999996\n",
+      "Epoch: 33, Iter: 1, Train Loss: 0.39167597788728836, Train Accuracy: 0.875, Val Loss: 1.5580225154760752, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 34, Iter: 1, Train Loss: 0.3936826951721131, Train Accuracy: 0.875, Val Loss: 1.592168642509798, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 35, Iter: 1, Train Loss: 0.39446093556125095, Train Accuracy: 0.8125, Val Loss: 1.6504423270813886, Val Accuracy: 0.43000000000000005\n",
+      "Epoch: 36, Iter: 1, Train Loss: 0.3917767876760818, Train Accuracy: 0.8125, Val Loss: 1.6894229810333048, Val Accuracy: 0.43000000000000005\n",
+      "Epoch: 37, Iter: 1, Train Loss: 0.3936299068718723, Train Accuracy: 0.8125, Val Loss: 1.7342536990495687, Val Accuracy: 0.43000000000000005\n",
+      "Epoch: 38, Iter: 1, Train Loss: 0.4086856463043926, Train Accuracy: 0.8125, Val Loss: 1.7709575584324264, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 39, Iter: 1, Train Loss: 0.3946728895715752, Train Accuracy: 0.8125, Val Loss: 1.8323990419424212, Val Accuracy: 0.43000000000000005\n",
+      "Epoch: 40, Iter: 1, Train Loss: 0.4092882424416999, Train Accuracy: 0.8125, Val Loss: 1.8647938002160964, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 41, Iter: 1, Train Loss: 0.4050641439255627, Train Accuracy: 0.8125, Val Loss: 1.891264442380163, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 42, Iter: 1, Train Loss: 0.4170644006779869, Train Accuracy: 0.8125, Val Loss: 1.9423174900115594, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 43, Iter: 1, Train Loss: 0.3923480753991977, Train Accuracy: 0.8125, Val Loss: 1.9731695043639572, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 44, Iter: 1, Train Loss: 0.40490676281916327, Train Accuracy: 0.8125, Val Loss: 2.010804834458905, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 45, Iter: 1, Train Loss: 0.40181821707001014, Train Accuracy: 0.8125, Val Loss: 2.051962004205519, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 46, Iter: 1, Train Loss: 0.40355348381441153, Train Accuracy: 0.8125, Val Loss: 2.0891022279849456, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 47, Iter: 1, Train Loss: 0.38715605504077866, Train Accuracy: 0.8125, Val Loss: 2.117280026954698, Val Accuracy: 0.44999999999999996\n",
+      "Epoch: 48, Iter: 1, Train Loss: 0.39836973023268446, Train Accuracy: 0.8125, Val Loss: 2.141835697116999, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 49, Iter: 1, Train Loss: 0.3901144594871556, Train Accuracy: 0.8125, Val Loss: 2.176511579483428, Val Accuracy: 0.43999999999999995\n",
+      "Epoch: 50, Iter: 1, Train Loss: 0.3917649057215277, Train Accuracy: 0.8125, Val Loss: 2.2288326304130806, Val Accuracy: 0.43999999999999995\n",
+      "NCF training completed after 50 epochs\n",
+      "SystemDS Statistics:\n",
+      "Total execution time:\t\t9.206 sec.\n",
+      "\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "20/05/28 15:04:03 INFO api.DMLScript: BEGIN DML run 05/28/2020 15:04:03\n",
+      "20/05/28 15:04:03 WARN util.NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
+      "20/05/28 15:04:13 INFO api.DMLScript: END DML run 05/28/2020 15:04:13\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%bash\n",
+    "cd ../../..\n",
+    "bin/systemds target/SystemDS.jar scripts/nn/examples/ncf-dummy-data.dml > scripts/nn/examples/run_log.txt && cat scripts/nn/examples/run_log.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### with real data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%bash\n",
+    "cd ../../..\n",
+    "bin/systemds target/SystemDS.jar scripts/nn/examples/ncf-real-data.dml > scripts/nn/examples/run_log.txt && cat scripts/nn/examples/run_log.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Plot training results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/png": "[base64-encoded PNG of the training and validation loss/accuracy curves; truncated in the mail archive]",
+      "text/plain": [
+       "<Figure size 432x288 with 1 Axes>"
+      ]
+     },
+     "metadata": {
+      "needs_background": "light"
+     },
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "log_name = \"run_log\"\n",
+    "txt_name = log_name + \".txt\"\n",
+    "csv_name = log_name + \".csv\"\n",
+    "\n",
+    "# convert to CSV\n",
+    "with open(txt_name, \"r\") as txt_file:\n",
+    "    data = txt_file.readlines()\n",
+    "    csv_lines = list(map(lambda x: x.replace(\"Epoch: \", \"\")\n",
+    "                         .replace(\", Iter: \", \",\")\n",
+    "                         .replace(\", Train Loss: \", \",\")\n",
+    "                         .replace(\", Train Accuracy: \", \",\")\n",
+    "                         .replace(\", Val Loss: \", \",\")\n",
+    "                         .replace(\", Val Accuracy: \", \",\"),\n",
+    "                            filter(lambda x: \"Epoch: \" in x, data)))\n",
+    "    with open(csv_name, \"w\") as csv_file:\n",
+    "        csv_file.write(\"epoch,iter,train_loss,train_acc,val_loss,val_acc\\n\")\n",
+    "        for item in csv_lines:\n",
+    "            csv_file.write(\"%s\" % item)\n",
+    "\n",
+    "# plot\n",
+    "log = pd.read_csv(csv_name)\n",
+    "plot_log = log[log[\"iter\"] == 1]\n",
+    "\n",
+    "for val in [\"train_loss\", \"train_acc\", \"val_loss\", \"val_acc\"]:\n",
+    "    plt.plot(plot_log[\"epoch\"], plot_log[val], label=val)\n",
+    "\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
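The notebook's rejection-sampling loop proposes one (user, movie) pair per iteration. The sketch below keeps the same idea (draw users and movies from the empirical columns of the positives, reject observed positives, allow duplicate negatives) but proposes all still-missing pairs in one vectorized NumPy draw. The helper `sample_negatives` and its parameters are hypothetical and not part of this commit; it assumes the interaction matrix is not complete, otherwise no negative can ever be accepted.

```python
import numpy as np


def sample_negatives(users, items, ratio=1.5, seed=0):
    """Hypothetical helper: return ceil(ratio * len(users)) (user, item)
    pairs that do not occur among the observed positive pairs.

    users, items: 1-D integer arrays of equal length (positive interactions).
    As in the notebook, candidates are drawn from the empirical user and item
    columns, so frequent users/items are proposed more often, and duplicate
    negatives are allowed. Assumes not every (user, item) pair is positive,
    otherwise the loop cannot terminate.
    """
    rng = np.random.default_rng(seed)
    users = np.asarray(users)
    items = np.asarray(items)
    positives = set(zip(users.tolist(), items.tolist()))
    target = int(np.ceil(len(users) * ratio))

    negatives = []
    while len(negatives) < target:
        # Propose all still-missing pairs in one vectorized draw.
        n = target - len(negatives)
        cand_u = rng.choice(users, size=n)
        cand_i = rng.choice(items, size=n)
        # Reject candidates that are observed positives (O(1) set lookup).
        negatives.extend(
            (u, i) for u, i in zip(cand_u.tolist(), cand_i.tolist())
            if (u, i) not in positives
        )
    return np.array(negatives, dtype=np.int64).reshape(-1, 2)
```

With the notebook's variables this would be called as `sample_negatives(positives["userId"], positives["movieId"], negative_split)`. Because real interaction data is sparse, most proposals are accepted, so the loop typically finishes after very few vectorized rounds instead of one Python iteration per negative.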
diff --git a/scripts/nn/examples/README.md b/scripts/nn/examples/README.md
index dd30eb2..fd3c1a2 100644
--- a/scripts/nn/examples/README.md
+++ b/scripts/nn/examples/README.md
@@ -40,6 +40,13 @@ limitations under the License.
 * Training script: `mnist_lenet-train.dml`
 * Prediction script: `mnist_lenet-predict.dml`
 
+### Neural Collaborative Filtering
+
+* This example trains a neural network on the MovieLens data set using [Neural Collaborative Filtering (NCF)](https://dl.acm.org/doi/abs/10.1145/3038912.3052569), an approach that tackles recommendation problems with deep neural networks instead of conventional matrix factorization.
+* As in the original paper, the targets are binary and only indicate whether a user has rated a movie. This makes the problem harder than predicting the rating values, but such implicit interaction data is easier to collect in practice.
+* MovieLens only provides positive interactions in the form of ratings, so negative interactions are sampled at random, as suggested by the original paper.
+* The implementation uses a fixed architecture: two embedding layers for users and items at the input, three dense layers with ReLU activations, and a final sigmoid output for binary classification.
+
 ---
 
 # Setup
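The architecture described in the README can be illustrated with a NumPy forward pass. This is a minimal sketch assuming one-hot user/item inputs, with each embedding realized as an affine layer on the one-hot vector (the staging script initializes both embeddings with `affine::init`). The function names, the (biases, weights) ordering chosen to mirror the DML `predict` call, and the He-style initialization are illustrative assumptions, not the SystemDS API or necessarily what `affine::init` does.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_params(M, N, E=8, D1=64, D2=32, D3=16, seed=0):
    """Illustrative He-style init; layer order: user emb, item emb, D1, D2, D3, final."""
    rng = np.random.default_rng(seed)
    shapes = [(M, E), (N, E), (2 * E, D1), (D1, D2), (D2, D3), (D3, 1)]
    weights = [rng.normal(0.0, np.sqrt(2.0 / d_in), size=(d_in, d_out))
               for d_in, d_out in shapes]
    biases = [np.zeros((1, d_out)) for _, d_out in shapes]
    return biases, weights


def ncf_forward(users, items, biases, weights):
    """users: K x M one-hot, items: K x N one-hot -> K x 1 probabilities."""
    b_U, b_I, b_D1, b_D2, b_D3, b_F = biases
    W_U, W_I, W_D1, W_D2, W_D3, W_F = weights
    emb_u = users @ W_U + b_U            # user embedding (affine on one-hot)
    emb_i = items @ W_I + b_I            # item embedding
    h = np.hstack([emb_u, emb_i])        # concatenate -> K x 2E
    h = relu(h @ W_D1 + b_D1)            # dense layer 1
    h = relu(h @ W_D2 + b_D2)            # dense layer 2
    h = relu(h @ W_D3 + b_D3)            # dense layer 3
    return sigmoid(h @ W_F + b_F)        # predicted interaction probability
```

Multiplying a one-hot row by `W_U` simply selects one row of `W_U`, which is why an affine layer on one-hot inputs behaves like an embedding lookup (plus a shared bias).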
diff --git a/scripts/nn/examples/ncf-dummy-data.dml b/scripts/nn/examples/ncf-dummy-data.dml
new file mode 100644
index 0000000..fff5f63
--- /dev/null
+++ b/scripts/nn/examples/ncf-dummy-data.dml
@@ -0,0 +1,57 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Imports
+source("staging/NCF.dml") as NCF
+
+K_train = 1000; # number of training samples
+K_val = 100; # number of validation samples
+
+N = 50; # number of items
+M = 60; # number of users
+
+# targets
+targets_train = round(rand(rows=K_train, cols=1));
+targets_val = round(rand(rows=K_val, cols=1));
+
+# integer-encoded user/item vectors
+items_train_int_encoded = round(rand(rows=K_train, cols=1, min=1, max=N));
+users_train_int_encoded = round(rand(rows=K_train, cols=1, min=1, max=M));
+items_val_int_encoded = round(rand(rows=K_val, cols=1, min=1, max=N));
+users_val_int_encoded = round(rand(rows=K_val, cols=1, min=1, max=M));
+
+# user/item matrices via one-hot encoding
+items_train = toOneHot(items_train_int_encoded, N);
+items_val = toOneHot(items_val_int_encoded, N);
+users_train = toOneHot(users_train_int_encoded, M);
+users_val = toOneHot(users_val_int_encoded, M);
+
+# Train
+epochs = 50;
+batch_size = 16;
+
+# layer dimensions
+E = 8; # embedding
+D1 = 64; # dense layer 1
+D2 = 32; # dense layer 2
+D3 = 16; # dense layer 3
+
+[biases, weights] = NCF::train(users_train, items_train, targets_train, users_val, items_val, targets_val, epochs, batch_size, E, D1, D2, D3);
\ No newline at end of file
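The dummy-data script draws IDs with `round(rand(min=1, max=N))` and then one-hot encodes them with `toOneHot`. A NumPy sketch of the same two steps follows, assuming 1-based IDs as in DML; the helper `to_one_hot` is hypothetical. Rounding a uniform draw on [1, N] gives IDs 1 and N only half the probability mass of the interior IDs (each endpoint collects a half-width interval), which is harmless for smoke-testing but can be avoided by sampling integers directly.

```python
import numpy as np


def to_one_hot(ids, num_classes):
    """Hypothetical helper: map 1-based integer IDs (K,) to a K x num_classes
    0/1 matrix, mirroring how toOneHot is used in the DML scripts."""
    ids = np.asarray(ids, dtype=np.int64).ravel()
    if ids.min() < 1 or ids.max() > num_classes:
        raise ValueError("IDs must lie in 1..num_classes")
    out = np.zeros((ids.size, num_classes))
    out[np.arange(ids.size), ids - 1] = 1.0   # shift to 0-based column index
    return out


rng = np.random.default_rng(0)
K_train, N = 1000, 50

# Uniform over 1..N: integers(low, high) excludes `high`, hence N + 1.
items_int = rng.integers(1, N + 1, size=K_train)
items_train = to_one_hot(items_int, N)

# Rounding a continuous draw, as the DML script does, under-samples the endpoints.
rounded = np.round(rng.uniform(1, N, size=100_000))
share_first = np.mean(rounded == 1)     # roughly half of an interior ID's share
share_middle = np.mean(rounded == 25)
```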
diff --git a/scripts/nn/examples/ncf-real-data.dml b/scripts/nn/examples/ncf-real-data.dml
new file mode 100644
index 0000000..2e0a2e8
--- /dev/null
+++ b/scripts/nn/examples/ncf-real-data.dml
@@ -0,0 +1,65 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Imports
+source("staging/NCF.dml") as NCF
+
+# prepare input data
+
+data_loc = "scripts/nn/examples/data/ml-latest-small/ml-latest-small/"
+
+# - read integer-encoded user/item vectors
+train = read(data_loc + "sampled-train.csv", format="csv", header=FALSE, sep=",");
+val = read(data_loc + "sampled-test.csv", format="csv", header=FALSE, sep=",");
+
+users_train_int_encoded = train[, 1];
+items_train_int_encoded = train[, 2];
+targets_train = train[, 3];
+
+users_val_int_encoded = val[, 1];
+items_val_int_encoded = val[, 2];
+targets_val = val[, 3];
+
+N = max(max(items_train_int_encoded), max(items_val_int_encoded)); # number of items
+M = max(max(users_train_int_encoded), max(users_val_int_encoded)); # number of users
+
+print("Done reading.");
+
+# - create user/item matrices via one-hot encoding
+items_train = toOneHot(items_train_int_encoded, N);
+items_val = toOneHot(items_val_int_encoded, N);
+users_train = toOneHot(users_train_int_encoded, M);
+users_val = toOneHot(users_val_int_encoded, M);
+
+print("Done encoding.");
+
+# Train
+
+epochs = 20;
+batch_size = 16;
+
+# layer dimensions
+E = 8; # embedding
+D1 = 64; # dense layer 1
+D2 = 32; # dense layer 2
+D3 = 16; # dense layer 3
+
+[biases, weights] = NCF::train(users_train, items_train, targets_train, users_val, items_val, targets_val, epochs, batch_size, E, D1, D2, D3);
\ No newline at end of file
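The real-data script sets the one-hot widths to the largest ID it sees (`N = max(...)`, `M = max(...)`). MovieLens `movieId` values are not contiguous in general, so many one-hot columns stay empty and the item embedding matrix (N x E) grows with the largest ID rather than with the number of distinct movies. One option, sketched here as an assumption rather than what the commit does, is to remap IDs to a dense 1..n range during data preparation, before the CSV files are written; `densify_ids` is a hypothetical helper.

```python
import numpy as np


def densify_ids(train_ids, test_ids):
    """Hypothetical helper: map arbitrary integer IDs to 1..n using the union
    of train and test IDs, so the one-hot width n equals the number of
    distinct IDs and both splits share one consistent mapping."""
    train_ids = np.asarray(train_ids)
    test_ids = np.asarray(test_ids)
    all_ids = np.concatenate([train_ids, test_ids])
    uniques, inverse = np.unique(all_ids, return_inverse=True)
    dense = inverse + 1                       # 1-based, as DML indexing expects
    return dense[: len(train_ids)], dense[len(train_ids):], len(uniques)


train_movies = np.array([1, 500, 9000, 500])
test_movies = np.array([9000, 7])
tr, te, n_items = densify_ids(train_movies, test_movies)
```

Using the union of both splits matters: mapping train and test separately would assign different dense IDs to the same movie.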
diff --git a/scripts/staging/NCF.dml b/scripts/staging/NCF.dml
new file mode 100644
index 0000000..0719b58
--- /dev/null
+++ b/scripts/staging/NCF.dml
@@ -0,0 +1,330 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+#
+# Neural Collaborative Filtering
+#
+
+# Imports
+source("nn/optim/adam.dml") as adam
+source("nn/layers/relu.dml") as relu
+source("nn/layers/sigmoid.dml") as sigmoid
+source("nn/layers/affine.dml") as affine
+source("nn/layers/log_loss.dml") as log_loss
+source("nn/layers/l2_reg.dml") as l2_reg
+
+train = function( matrix[double] users_train, 
+                  matrix[double] items_train, 
+                  matrix[double] targets_train, 
+                  matrix[double] users_val, 
+                  matrix[double] items_val, 
+                  matrix[double] targets_val,
+                  integer epochs,
+                  integer batch_size,
+                  integer E,
+                  integer D1,
+                  integer D2,
+                  integer D3)
+    return (List[unknown] biases, List[unknown] weights) {
+  # /*
+  #  * Train NCF model
+  #  *
+  #  * Inputs:
+  #  *  - users_train: matrix of shape K_train × M with K_train samples of one-hot encoded users
+  #  *  - items_train: matrix of shape K_train × N with K_train samples of one-hot encoded items
+  #  *  - targets_train: vector with K_train entries containing either 0 or 1 indicating whether the user interacted with the item
+  #  *  - users_val: matrix of shape K_val × M with K_val samples of one-hot encoded users
+  #  *  - items_val: matrix of shape K_val × N with K_val samples of one-hot encoded items
+  #  *  - targets_val: vector with K_val entries containing either 0 or 1 indicating whether the user interacted with the item
+  #  *  - epochs: number of training epochs
+  #  *  - batch_size: size of the training batches
+  #  *  - E:  dimension of embedding layers
+  #  *  - D1: dimension of dense layer 1
+  #  *  - D2: dimension of dense layer 2
+  #  *  - D3: dimension of dense layer 3
+  #  *
+  #  * Outputs:
+  #  *  - biases: list of biases
+  #  *  - weights: list of weights
+  #  *
+  #  * Network Architecture:
+  #
+  # +----------------------+    +----------------------+
+  # |User Embedding [users]|    |Item Embedding [items]|
+  # +----------+-----------+    +---------+------------+
+  #           |                          |
+  #           |                          |
+  #           |       +-----------+      |
+  #           +------>+Concatenate+<-----+
+  #                   +-----+-----+
+  #                         |
+  #                         v
+  #                     +----+----+
+  #                     | Dense 1 |
+  #                     | (ReLU)  |
+  #                     +----+----+
+  #                         |
+  #                         v
+  #                     +----+----+
+  #                     | Dense 2 |
+  #                     | (ReLU)  |
+  #                     +----+----+
+  #                         |
+  #                         v
+  #                     +----+----+
+  #                     | Dense 3 |
+  #                     | (ReLU)  |
+  #                     +----+----+
+  #                         |
+  #                         v
+  #                   +-----+-----+
+  #                   |Prediction |
+  #                   |(Sigmoid)  |
+  #                   +-----------+
+  #  *
+  #  */
+
+  # sanity checks
+  assert(nrow(items_train) == nrow(users_train));
+  assert(nrow(users_train) == nrow(targets_train));
+  assert(nrow(items_val) == nrow(users_val));
+  assert(nrow(users_val) == nrow(targets_val));
+  assert(ncol(items_val) == ncol(items_train));
+  assert(ncol(users_val) == ncol(users_train));
+  
+  assert(ncol(targets_val) == ncol(targets_train));
+  assert(ncol(targets_train) == 1);
+
+  K_train = nrow(targets_train); # number of training samples
+  K_val = nrow(targets_val); # number of validation samples
+
+  N = ncol(items_train); # number items
+  M = ncol(users_train); # number users
+
+  print("NCF training starting with " 
+          + K_train + " training samples, " 
+          + K_val + " validation samples, " 
+          + N + " items and "
+          + M + " users...");
+
+  # 1. Initialize layers
+  [W_U,  b_U] = affine::init(M, E); # user embedding
+  [W_I,  b_I] = affine::init(N, E); # item embedding
+
+  [W_D1, b_D1] = affine::init(2 * E, D1); # dense layer 1
+  [W_D2, b_D2] = affine::init(D1, D2);    # dense layer 2
+  [W_D3, b_D3] = affine::init(D2, D3);    # dense layer 3
+
+  [W_F,  b_F] = affine::init(D3, 1); # final prediction
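+
+  # Parameter shapes (a sketch, assuming affine::init(D, M) returns W of shape (D, M)
+  # and b of shape (1, M)):
+  #   W_U (M, E), W_I (N, E), W_D1 (2E, D1), W_D2 (D1, D2), W_D3 (D2, D3), W_F (D3, 1)
+  # Total number of trainable parameters:
+  #   (M+1)*E + (N+1)*E + (2E+1)*D1 + (D1+1)*D2 + (D2+1)*D3 + (D3+1)
+  # e.g. for hypothetical sizes M = N = 100, E = 8, D1 = 64, D2 = 32, D3 = 16:
+  #   808 + 808 + 1088 + 2080 + 528 + 17 = 5329 parameters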
+
+  # initialize bias and weight lists
+  biases = list(b_U, b_I, b_D1, b_D2, b_D3, b_F);
+  weights = list(W_U, W_I, W_D1, W_D2, W_D3, W_F);
+
+  # 2. Initialize the Adam optimizer
+  ## Default hyperparameters
+  lr      = 0.001;
+  beta1   = 0.9;       # [0, 1)
+  beta2   = 0.999;     # [0, 1)
+  epsilon = 0.0000001;
+  t       = 0;
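+
+  # Adam update applied to each parameter X with gradient dX (a sketch of the standard
+  # formulation; assumed to match adam::update, which is expected to increment t
+  # internally, so t starts at 0 here):
+  #   m = beta1*m + (1-beta1)*dX
+  #   v = beta2*v + (1-beta2)*dX^2
+  #   X = X - lr * (m / (1-beta1^(t+1))) / (sqrt(v / (1-beta2^(t+1))) + epsilon)
+  # With beta1 = 0.9, the first step divides m by 1 - 0.9 = 0.1 (a factor of 10);
+  # this bias correction decays toward 1 only if t keeps increasing across iterations.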
+
+  # (1) user embedding
+  [mW_U, vW_U] = adam::init(W_U);
+  [mb_U, vb_U] = adam::init(b_U);
+
+  # (1) item embedding
+  [mW_I, vW_I] = adam::init(W_I);
+  [mb_I, vb_I] = adam::init(b_I);
+
+  # (2) Dense 1
+  [mW_D1, vW_D1] = adam::init(W_D1);
+  [mb_D1, vb_D1] = adam::init(b_D1);
+
+  # (3) Dense 2
+  [mW_D2, vW_D2] = adam::init(W_D2);
+  [mb_D2, vb_D2] = adam::init(b_D2);
+
+  # (4) Dense 3
+  [mW_D3, vW_D3] = adam::init(W_D3);
+  [mb_D3, vb_D3] = adam::init(b_D3);
+
+  # (N) final prediction
+  [mW_F, vW_F] = adam::init(W_F);
+  [mb_F, vb_F] = adam::init(b_F);
+
+  # Optimize: iterate over the training data in mini-batches
+  iters = ceil(K_train / batch_size);
+
+  for (e in 1:epochs) {
+    for (i in 1:iters) {
+      # Get the next batch
+      beg = ((i-1) * batch_size) %% K_train + 1;
+      end = min(K_train, beg + batch_size - 1);
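+      # Worked example (hypothetical sizes): K_train = 1000 and batch_size = 256 give
+      # iters = ceil(1000/256) = 4; the batches cover rows 1:256, 257:512, 513:768 and
+      # 769:1000, i.e. the last batch is smaller (232 rows).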
+      
+      items_batch = items_train[beg:end,];
+      users_batch = users_train[beg:end,];
+      y_batch = targets_train[beg:end,];
+
+      # 3. Forward pass: send the batch through all layers and keep the intermediate outputs
+      [out_FA, out_F, out_D1A, out_D1, out_D2A, out_D2, out_D3A, out_D3, out_concat, out_U, out_I] = 
+        predict(users_batch, items_batch, biases, weights);
+
+      # 4. Compute the gradient of the log loss w.r.t. the predictions
+      # params: (predictions, targets)
+      dout = log_loss::backward(out_FA, y_batch);
+
+      # Every 100 iterations (starting with the first), compute loss & accuracy on the current training batch and on the validation set
+      if (i %% 100 == 1) {
+        # Compute training loss & accuracy
+        [loss, accuracy] = eval(out_FA, y_batch);
+
+        # Compute validation loss & accuracy
+        out_FA_val = predict(users_val, items_val, biases, weights);
+        [loss_val, accuracy_val] = eval(out_FA_val, targets_val);
+
+        # Output results
+        print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: "
+                + accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
+      }
+
+      # 5. Backpropagation
+      # params: (upstream gradient, layer input, [weights, biases])
+      dout_FA = sigmoid::backward(dout, out_F);
+      [dout_F, dW_F, db_F] = affine::backward(dout_FA, out_D3A, W_F, b_F);
+
+      dout_D3A = relu::backward(dout_F, out_D3);
+      [dout_D3, dW_D3, db_D3] = affine::backward(dout_D3A, out_D2A, W_D3, b_D3); 
+
+      dout_D2A = relu::backward(dout_D3, out_D2);
+      [dout_D2, dW_D2, db_D2] = affine::backward(dout_D2A, out_D1A, W_D2, b_D2); 
+
+      dout_D1A = relu::backward(dout_D2, out_D1);
+      [dout_D1, dW_D1, db_D1] = affine::backward(dout_D1A, out_concat, W_D1, b_D1); 
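+
+      # relu::backward(dout, X) passes dout through where the pre-activation X > 0 and
+      # zeroes it elsewhere. This is why predict() returns both out_Dk (pre-activation)
+      # and out_DkA (activation): affine::backward needs the activations as layer inputs,
+      # while relu::backward needs the pre-activations.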
+
+      # backprop concatenation: split the gradients up
+      dout_U = dout_D1[,1:E];
+      dout_I = dout_D1[,(E+1):(2*E)];
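+      # cbind(out_U, out_I) places the E user-embedding columns before the E item-embedding
+      # columns, so the gradient w.r.t. the concatenation splits column-wise in the same
+      # order; e.g. for E = 8, columns 1:8 belong to out_U and columns 9:16 to out_I.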
+
+      [dUsers, dW_U, db_U] = affine::backward(dout_U, users_batch, W_U, b_U); 
+      [dItems, dW_I, db_I] = affine::backward(dout_I, items_batch, W_I, b_I); 
+
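+      # 6. Update the timestep: a global iteration counter starting at 0, so that Adam's
+      #    bias correction decays monotonically across batches and epochs
+      t = (e - 1) * iters + i - 1;
+
+      # 7. Call adam::update for all parameters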
+      [b_U, mb_U, vb_U] = adam::update(b_U, db_U, lr, beta1, beta2, epsilon, t, mb_U, vb_U);
+      [W_U, mW_U, vW_U] = adam::update(W_U, dW_U, lr, beta1, beta2, epsilon, t, mW_U, vW_U);
+
+      [b_I, mb_I, vb_I] = adam::update(b_I, db_I, lr, beta1, beta2, epsilon, t, mb_I, vb_I);
+      [W_I, mW_I, vW_I] = adam::update(W_I, dW_I, lr, beta1, beta2, epsilon, t, mW_I, vW_I);
+
+      [b_D1, mb_D1, vb_D1] = adam::update(b_D1, db_D1, lr, beta1, beta2, epsilon, t, mb_D1, vb_D1);
+      [W_D1, mW_D1, vW_D1] = adam::update(W_D1, dW_D1, lr, beta1, beta2, epsilon, t, mW_D1, vW_D1);
+      
+      [b_D2, mb_D2, vb_D2] = adam::update(b_D2, db_D2, lr, beta1, beta2, epsilon, t, mb_D2, vb_D2);
+      [W_D2, mW_D2, vW_D2] = adam::update(W_D2, dW_D2, lr, beta1, beta2, epsilon, t, mW_D2, vW_D2);
+      
+      [b_D3, mb_D3, vb_D3] = adam::update(b_D3, db_D3, lr, beta1, beta2, epsilon, t, mb_D3, vb_D3);
+      [W_D3, mW_D3, vW_D3] = adam::update(W_D3, dW_D3, lr, beta1, beta2, epsilon, t, mW_D3, vW_D3);
+      
+      [b_F, mb_F, vb_F] = adam::update(b_F, db_F, lr, beta1, beta2, epsilon, t, mb_F, vb_F);
+      [W_F, mW_F, vW_F] = adam::update(W_F, dW_F, lr, beta1, beta2, epsilon, t, mW_F, vW_F);
+
+      # 8. Update lists
+      biases = list(b_U, b_I, b_D1, b_D2, b_D3, b_F);
+      weights = list(W_U, W_I, W_D1, W_D2, W_D3, W_F);
+    }
+  }
+
+  print("NCF training completed after " + epochs + " epochs")
+}
+
+predict = function(matrix[double] users, matrix[double] items, List[unknown] biases, List[unknown] weights)
+    return (matrix[double] out_FA, matrix[double] out_F, 
+            matrix[double] out_D1A, matrix[double] out_D1, 
+            matrix[double] out_D2A, matrix[double] out_D2, 
+            matrix[double] out_D3A, matrix[double] out_D3, 
+            matrix[double] out_concat, matrix[double] out_U, matrix[double] out_I) {
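+  #
+  # Computes the forward pass of the NCF network for the given inputs.
+  #
+  # Inputs:
+  #  - users : K one-hot encoded users, of shape (K, M) for M users.
+  #  - items : K one-hot encoded items, of shape (K, N) for N items.
+  #  - biases, weights : lists of model parameters, ordered as
+  #      (user embedding, item embedding, dense 1, dense 2, dense 3, final prediction).
+  #
+  # Outputs:
+  #  - out_FA : predicted probabilities, of shape (K, 1).
+  #  - out_F, out_D1A, out_D1, out_D2A, out_D2, out_D3A, out_D3, out_concat, out_U, out_I :
+  #      intermediate (pre-)activations, returned for use in backpropagation.
+  #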
+
+  # parse parameters
+  b_U = as.matrix(biases[1]);
+  b_I = as.matrix(biases[2]);
+  b_D1 = as.matrix(biases[3]);
+  b_D2 = as.matrix(biases[4]);
+  b_D3 = as.matrix(biases[5]);
+  b_F = as.matrix(biases[6]);
+
+  W_U = as.matrix(weights[1]);
+  W_I = as.matrix(weights[2]);
+  W_D1 = as.matrix(weights[3]);
+  W_D2 = as.matrix(weights[4]);
+  W_D3 = as.matrix(weights[5]);
+  W_F = as.matrix(weights[6]);
+
+  # send inputs through layers
+
+  # (1) User and item embeddings + concatenation
+  out_U = affine::forward(users, W_U, b_U);
+  out_I = affine::forward(items, W_I, b_I);
+
+  out_concat = cbind(out_U, out_I); 
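+  # Shapes for a batch of K samples: users (K, M) %*% W_U (M, E) -> out_U (K, E);
+  # items (K, N) %*% W_I (N, E) -> out_I (K, E); out_concat is (K, 2E). With one-hot
+  # inputs, each row of out_U is simply the matching row of W_U plus b_U, i.e. the
+  # affine layer acts as an embedding lookup.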
+
+  # (2) Dense layers
+  out_D1 = affine::forward(out_concat, W_D1, b_D1);
+  out_D1A = relu::forward(out_D1); # keep the pre-activation out_D1 for relu::backward
+
+  out_D2 = affine::forward(out_D1A, W_D2, b_D2);
+  out_D2A = relu::forward(out_D2); # keep the pre-activation out_D2 for relu::backward
+
+  out_D3 = affine::forward(out_D2A, W_D3, b_D3);
+  out_D3A = relu::forward(out_D3); # keep the pre-activation out_D3 for relu::backward
+
+  # (N) final prediction
+  out_F = affine::forward(out_D3A, W_F, b_F);
+  out_FA = sigmoid::forward(out_F);
+}
+
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
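+   /*
+    * Computes the log loss and the classification accuracy of predicted
+    * probabilities against binary targets, using a 0.5 threshold.
+    */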
+
+    # compute the log loss
+    loss = log_loss::forward(probs, y);
+
+    # compute accuracy
+    Z = probs >= 0.5;
+    accuracy = 1 - sum(abs(Z - y)) / nrow(y);
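+
+    # Worked example: probs = (0.9, 0.2, 0.6, 0.4) and y = (1, 0, 0, 0) give Z = (1, 0, 1, 0),
+    # sum(abs(Z - y)) = 1 and accuracy = 1 - 1/4 = 0.75. Assuming log_loss::forward computes
+    # the mean binary cross-entropy -mean(y*log(p) + (1-y)*log(1-p)), the loss is
+    # (0.105 + 0.223 + 0.916 + 0.511) / 4, roughly 0.439.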
+}