You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by du...@apache.org on 2017/08/18 18:07:17 UTC

systemml git commit: [SYSTEMML-1813] Breast cancer preprocessing simplification and cleanup

Repository: systemml
Updated Branches:
  refs/heads/master c18352f29 -> ae4c00682


[SYSTEMML-1813] Breast cancer preprocessing simplification and cleanup

In anticipation of near-future algorithmic changes to the preprocessing
aimed at improving model training, this commit simplifies and cleans up
the preprocessing code as follows.

- Previously, we were processing all slides into one large saved
DataFrame, and then splitting that DataFrame into train and validation
DataFrames.  This commit simplifies this by splitting the slide numbers
into train and validation sets, and then processing those slides
separately.  This effectively skips the creation of the large DataFrame,
and removes the need to split that large DataFrame into train/val ones,
which should provide a large performance benefit.  The DataFrame `union`
method can be used to combine two DataFrames row-wise.
- Previously, we maintained a list of "broken" slides that were manually
removed.  This commit removes that manual list, and instead adds a
filtering step that automatically drops slides OpenSlide cannot open
(`open_slide` now returns `None` on an `OpenSlideError`).
- This commit moves ad-hoc sampling code into a new `sample` function.
- This commit moves code to add row indices to a DataFrame into a new
`add_row_indices` function.

The benefit is that near-future algorithmic improvements to the
preprocessing code will be much easier to incorporate.

Closes #597.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ae4c0068
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ae4c0068
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ae4c0068

Branch: refs/heads/master
Commit: ae4c00682361148284f37579f756ea3ee993f272
Parents: c18352f
Author: Mike Dusenberry <mw...@us.ibm.com>
Authored: Fri Aug 18 11:06:18 2017 -0700
Committer: Mike Dusenberry <mw...@us.ibm.com>
Committed: Fri Aug 18 11:06:18 2017 -0700

----------------------------------------------------------------------
 .../MachineLearning-Keras-ResNet50.ipynb        |  55 +++++--
 .../Preprocessing-Save-JPEGs.ipynb              |   2 +-
 .../breast_cancer/breastcancer/preprocessing.py | 153 ++++++++-----------
 projects/breast_cancer/preprocess.py            | 127 ++++++---------
 4 files changed, 156 insertions(+), 181 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ae4c0068/projects/breast_cancer/MachineLearning-Keras-ResNet50.ipynb
----------------------------------------------------------------------
diff --git a/projects/breast_cancer/MachineLearning-Keras-ResNet50.ipynb b/projects/breast_cancer/MachineLearning-Keras-ResNet50.ipynb
index 331b666..bafa74a 100644
--- a/projects/breast_cancer/MachineLearning-Keras-ResNet50.ipynb
+++ b/projects/breast_cancer/MachineLearning-Keras-ResNet50.ipynb
@@ -10,7 +10,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "%load_ext autoreload\n",
@@ -54,7 +56,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "# os.environ['CUDA_VISIBLE_DEVICES'] = \"\"\n",
@@ -100,7 +104,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "def get_run_dir(path, new_run):\n",
@@ -178,7 +184,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "K.image_data_format()"
@@ -187,7 +195,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "train_save_dir = \"images/{stage}/{p}\".format(stage=train_dir, p=p)\n",
@@ -198,7 +208,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "# Create train & val image generators\n",
@@ -243,6 +255,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "collapsed": true,
     "scrolled": false
    },
    "outputs": [],
@@ -280,7 +293,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "class_counts = np.bincount(train_generator_orig.classes)\n",
@@ -320,7 +335,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "def plot(gen):\n",
@@ -482,6 +499,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "collapsed": true,
     "scrolled": true
    },
    "outputs": [],
@@ -501,6 +519,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "collapsed": true,
     "scrolled": true
    },
    "outputs": [],
@@ -522,6 +541,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "collapsed": true,
     "scrolled": true
    },
    "outputs": [],
@@ -551,6 +571,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
+    "collapsed": true,
     "scrolled": true
    },
    "outputs": [],
@@ -589,7 +610,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "print(model.summary())"
@@ -609,7 +632,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "initial_epoch = epochs\n",
@@ -631,7 +656,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "raw_metrics = model.evaluate_generator(val_generator, steps=val_batches) #,\n",
@@ -656,7 +683,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "filename = \"{acc:.5}_acc_{loss:.5}_loss_model.hdf5\".format(**metrics)\n",
@@ -709,7 +738,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.1"
+   "version": "3.6.2"
   }
  },
  "nbformat": 4,

http://git-wip-us.apache.org/repos/asf/systemml/blob/ae4c0068/projects/breast_cancer/Preprocessing-Save-JPEGs.ipynb
----------------------------------------------------------------------
diff --git a/projects/breast_cancer/Preprocessing-Save-JPEGs.ipynb b/projects/breast_cancer/Preprocessing-Save-JPEGs.ipynb
index 7e893f7..d60e622 100644
--- a/projects/breast_cancer/Preprocessing-Save-JPEGs.ipynb
+++ b/projects/breast_cancer/Preprocessing-Save-JPEGs.ipynb
@@ -602,7 +602,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.1"
+   "version": "3.6.2"
   }
  },
  "nbformat": 4,

http://git-wip-us.apache.org/repos/asf/systemml/blob/ae4c0068/projects/breast_cancer/breastcancer/preprocessing.py
----------------------------------------------------------------------
diff --git a/projects/breast_cancer/breastcancer/preprocessing.py b/projects/breast_cancer/breastcancer/preprocessing.py
index cfedf94..763bde5 100644
--- a/projects/breast_cancer/breastcancer/preprocessing.py
+++ b/projects/breast_cancer/breastcancer/preprocessing.py
@@ -32,6 +32,7 @@ import os
 
 import numpy as np
 import openslide
+from openslide import OpenSlideError
 from openslide.deepzoom import DeepZoomGenerator
 import pandas as pd
 from pyspark.ml.linalg import Vectors
@@ -66,7 +67,10 @@ def open_slide(slide_num, folder, training):
     # Testing images
     filename = os.path.join(folder, "testing_image_data",
                             "TUPAC-TE-{}.svs".format(str(slide_num).zfill(3)))
-  slide = openslide.open_slide(filename)
+  try:
+    slide = openslide.open_slide(filename)
+  except OpenSlideError:
+    slide = None
   return slide
 
 
@@ -235,7 +239,7 @@ def keep_tile(tile_tuple, tile_size, tissue_threshold):
   Args:
     tile_tuple: A (slide_num, tile) tuple, where slide_num is an
       integer, and tile is a 3D NumPy array of shape
-      (tile_size, tile_size, channels) in RGB format.
+      (tile_size, tile_size, channels).
     tile_size: The width and height of a square tile to be generated.
     tissue_threshold: Tissue percentage threshold.
 
@@ -485,7 +489,7 @@ def flatten_sample(sample_tuple):
 
 # Get Ground Truth Labels
 
-def get_labels_df(folder):
+def get_labels_df(folder, filename="training_ground_truth.csv"):
   """
   Create a DataFrame with the ground truth labels for each slide.
 
@@ -498,7 +502,7 @@ def get_labels_df(folder):
     A Pandas DataFrame containing the ground truth labels for each
     slide.
   """
-  filepath = os.path.join(folder, "training_ground_truth.csv")
+  filepath = os.path.join(folder, filename)
   labels_df = pd.read_csv(filepath, names=["tumor_score", "molecular_score"], header=None)
   labels_df["slide_num"] = labels_df.index + 1  # slide numbering starts at 1
   labels_df.set_index("slide_num", drop=False, inplace=True)  # use the slide num as index
@@ -546,7 +550,13 @@ def preprocess(spark, slide_nums, folder="data", training=True, tile_size=1024,
     A Spark DataFrame in which each row contains the slide number, tumor
     score, molecular score, and the sample stretched out into a Vector.
   """
-  slides = spark.sparkContext.parallelize(slide_nums)
+  # Filter out broken slides
+  # Note: "Broken" here is due to a "version of OpenJPEG with broken support for chroma-subsampled
+  # images".
+  slides = (spark.sparkContext
+      .parallelize(slide_nums)
+      .filter(lambda slide: open_slide(slide, folder, training) is not None))
+
   # Create DataFrame of all tile locations and increase number of partitions
   # to avoid OOM during subsequent processing.
   tile_indices = (slides.flatMap(
@@ -561,7 +571,8 @@ def preprocess(spark, slide_nums, folder="data", training=True, tile_size=1024,
   #num_parts = rows / rows_per_part
   tile_indices = tile_indices.repartition(num_partitions)
   tile_indices.cache()
-  # Extract all tiles into a DataFrame, filter, cut into smaller samples, apply stain
+
+  # Extract all tiles into an RDD, filter, cut into smaller samples, apply stain
   # normalization, and flatten.
   tiles = tile_indices.map(lambda tile_index: process_tile_index(tile_index, folder, training))
   filtered_tiles = tiles.filter(lambda tile: keep_tile(tile, tile_size, tissue_threshold))
@@ -569,6 +580,8 @@ def preprocess(spark, slide_nums, folder="data", training=True, tile_size=1024,
   if normalize_stains:
     samples = samples.map(lambda sample: normalize_staining(sample))
   samples = samples.map(lambda sample: flatten_sample(sample))
+
+  # Convert to a DataFrame
   if training:
     # Append labels
     labels_df = get_labels_df(folder)
@@ -584,87 +597,6 @@ def preprocess(spark, slide_nums, folder="data", training=True, tile_size=1024,
   return df
 
 
-# Split Into Separate Train & Validation DataFrames Based On Slide Number
-
-def train_val_split(spark, df, slide_nums, folder, train_frac=0.8, add_row_indices=True, seed=None,
-                    debug=False):
-  """
-  Split a DataFrame of slide samples into training and validation sets.
-
-  Args:
-    spark: SparkSession.
-    df: A Spark DataFrame in which each row contains the slide number,
-    tumor score, molecular score, and the sample stretched out into
-    a Vector.
-    slide_nums: A list of slide numbers to sample from.
-    folder: Directory containing a `training_ground_truth.csv` file
-      containing the ground truth "tumor_score" and "molecular_score"
-      labels for each slide.
-    train_frac: Fraction of the data to assign to the training set, with
-      `1-frac` assigned to the valiation set.
-    add_row_indices: Boolean for whether or not to prepend an index
-      column contain the row index for use downstream by SystemML.
-      The column name will be "__INDEX".
-
-  Returns:
-    A Spark DataFrame in which each row contains the slide number, tumor
-    score, molecular score, and the sample stretched out into a Vector.
-  """
-  # Create DataFrame of labels for the given slide numbers.
-  labels_df = get_labels_df(folder)
-  labels_df = labels_df.loc[slide_nums]
-
-  # Randomly split slides 80%/20% into train and validation sets.
-  train_nums_df = labels_df.sample(frac=train_frac, random_state=seed)
-  val_nums_df = labels_df.drop(train_nums_df.index)
-
-  train_nums = (spark.createDataFrame(train_nums_df)
-                     .selectExpr("cast(slide_num as int)")
-                     .coalesce(1))
-  val_nums = (spark.createDataFrame(val_nums_df)
-                   .selectExpr("cast(slide_num as int)")
-                   .coalesce(1))
-
-  # Note: Explicitly mark the smaller DataFrames as able to be broadcasted
-  # in order to have Catalyst choose the more efficient BroadcastHashJoin,
-  # rather than the costly SortMergeJoin.
-  train = df.join(F.broadcast(train_nums), on="slide_num")
-  val = df.join(F.broadcast(val_nums), on="slide_num")
-
-  if debug:
-    # DEBUG: Sanity checks.
-    assert len(pd.merge(train_nums_df, val_nums_df, on="slide_num")) == 0
-    assert train_nums.join(val_nums, on="slide_num").count() == 0
-    assert train.join(val, on="slide_num").count() == 0
-    #  - Check distributions.
-    for pdf in train_nums_df, val_nums_df:
-      print(pdf.count())
-      print(pdf["tumor_score"].value_counts(sort=False))
-      print(pdf["tumor_score"].value_counts(normalize=True, sort=False), "\n")
-    #  - Check total number of examples in each.
-    print(train.count(), val.count())
-    #  - Check physical plans for broadcast join.
-    print(train.explain(), val.explain())
-
-  # Add row indices for use with SystemML.
-  if add_row_indices:
-    train = (train.rdd
-                  .zipWithIndex()
-                  .map(lambda r: (r[1] + 1, *r[0]))  # flatten & convert index to 1-based indexing
-                  .toDF(['__INDEX', 'slide_num', 'tumor_score', 'molecular_score', 'sample']))
-    train = train.select(train["__INDEX"].astype("int"), train.slide_num.astype("int"),
-                         train.tumor_score.astype("int"), train.molecular_score, train["sample"])
-
-    val = (val.rdd
-              .zipWithIndex()
-              .map(lambda r: (r[1] + 1, *r[0]))  # flatten & convert index to 1-based indexing
-              .toDF(['__INDEX', 'slide_num', 'tumor_score', 'molecular_score', 'sample']))
-    val = val.select(val["__INDEX"].astype("int"), val.slide_num.astype("int"),
-                     val.tumor_score.astype("int"), val.molecular_score, val["sample"])
-
-  return train, val
-
-
 # Save DataFrame
 
 def save(df, filepath, sample_size, grayscale, mode="error", format="parquet", file_size=128):
@@ -694,3 +626,50 @@ def save(df, filepath, sample_size, grayscale, mode="error", format="parquet", f
   rows_per_file = round(file_size / row_mb)
   df.write.option("maxRecordsPerFile", rows_per_file).mode(mode).save(filepath, format=format)
 
+
+# Utilities
+
+def add_row_indices(df, training=True):
+  """
+  Prepend a 1-based row index column for faster data ingestion with SystemML.
+
+  Args:
+    df: A Spark DataFrame in which each row contains the slide number,
+      tumor score, molecular score, and the sample stretched out into a
+      Vector (testing DataFrames contain only the slide number and sample).
+    training: Boolean for training (labeled) or testing (unlabeled) datasets.
+
+  Returns:
+    A new Spark DataFrame with an integer row index column called "__INDEX".
+  """
+  rdd = (df.rdd  # NOTE: assumes df's columns are in the order named in toDF below
+           .zipWithIndex()
+           .map(lambda r: (r[1] + 1, *r[0])))  # flatten & convert index to 1-based indexing
+  if training:
+    df = rdd.toDF(['__INDEX', 'slide_num', 'tumor_score', 'molecular_score', 'sample'])
+    df = df.select(df["__INDEX"].astype("int"), df.slide_num.astype("int"),
+                   df.tumor_score.astype("int"), df.molecular_score, df["sample"])
+  else:  # testing data -- no labels
+    df = rdd.toDF(["__INDEX", "slide_num", "sample"])
+    df = df.select(df["__INDEX"].astype("int"), df.slide_num.astype("int"), df["sample"])
+  return df
+
+
+def sample(df, frac, training=True, seed=None):
+  """
+  Sample the DataFrame, stratified on the class for training data.
+
+  Args:
+    df: A Spark DataFrame in which each row contains the slide number,
+      (if training) tumor & molecular scores, and the sample as a Vector.
+    frac: Fraction of rows to keep.
+    training: Boolean for training (stratified) or testing (uniform) data.
+    seed: Random seed used for the sampling.
+
+  Returns:
+    A sample of the original Spark DataFrame.
+  """
+  if not training:  # testing data has no "tumor_score" to stratify on
+    return df.sample(withReplacement=False, fraction=frac, seed=seed)
+  return df.sampleBy("tumor_score", fractions={1: frac, 2: frac, 3: frac}, seed=seed)
+

http://git-wip-us.apache.org/repos/asf/systemml/blob/ae4c0068/projects/breast_cancer/preprocess.py
----------------------------------------------------------------------
diff --git a/projects/breast_cancer/preprocess.py b/projects/breast_cancer/preprocess.py
index 167fa61..e90fe8c 100644
--- a/projects/breast_cancer/preprocess.py
+++ b/projects/breast_cancer/preprocess.py
@@ -29,10 +29,12 @@ import os
 import shutil
 
 import numpy as np
-
-from breastcancer.preprocessing import preprocess, save, train_val_split
+import pandas as pd
+from sklearn.model_selection import train_test_split
 from pyspark.sql import SparkSession
 
+from breastcancer.preprocessing import add_row_indices, get_labels_df, preprocess, sample, save
+
 
 # Create new SparkSession
 spark = (SparkSession.builder
@@ -56,93 +58,58 @@ spark.sparkContext.addPyFile(zipname)
 # procedure of the paper.  Look into simply selecting tiles of the
 # desired size to begin with.
 
-# Get list of image numbers, minus the broken ones.
-broken = {2, 45, 91, 112, 242, 256, 280, 313, 329, 467}
-slide_nums = sorted(set(range(1,501)) - broken)
-
 # Settings
-training = True
+# TODO: Convert this to a set of parsed command line arguments
 tile_size = 256
 sample_size = 256
 grayscale = False
 num_partitions = 20000
-add_row_indices = True
+training = True  # NOTE(review): not forwarded to preprocess() below; it uses its default
+row_indices = False  # prepend a "__INDEX" row index column for SystemML
 train_frac = 0.8
-split_seed = 24
-folder = "data"  # Linux-filesystem directory to read raw data
+sample_frac = 0.01  # fraction of rows to keep in sampled DataFrames; 0 disables sampling
+seed = 42
+folder = "data"  # Linux-filesystem directory to read raw WSI data
 save_folder = "data"  # Hadoop-supported directory in which to save DataFrames
-df_path = os.path.join(save_folder, "samples_{}_{}{}.parquet".format(
-    "labels" if training else "testing", sample_size, "_grayscale" if grayscale else ""))
 train_df_path = os.path.join(save_folder, "train_{}{}.parquet".format(sample_size,
     "_grayscale" if grayscale else ""))
 val_df_path = os.path.join(save_folder, "val_{}{}.parquet".format(sample_size,
     "_grayscale" if grayscale else ""))
-
-# Process all slides.
-df = preprocess(spark, slide_nums, tile_size=tile_size, sample_size=sample_size,
-                grayscale=grayscale, training=training, num_partitions=num_partitions,
-                folder=folder)
-
-# Save DataFrame of samples.
-save(df, df_path, sample_size, grayscale)
-
-# Load full DataFrame from disk.
-df = spark.read.load(df_path)
-
-# Split into train and validation DataFrames based On slide number
-train, val = train_val_split(spark, df, slide_nums, folder, train_frac, add_row_indices,
-                             seed=split_seed)
-
-# Save train and validation DataFrames.
-save(train, train_df_path, sample_size, grayscale)
-save(val, val_df_path, sample_size, grayscale)
-
-
-# ---
-#
-# Sample Data
-## TODO: Wrap this in a function with appropriate default arguments
-
-# Load train and validation DataFrames from disk.
-train = spark.read.load(train_df_path)
-val = spark.read.load(val_df_path)
-
-# Take a stratified sample.
-p=0.01
-train_sample = train.drop("__INDEX").sampleBy("tumor_score", fractions={1: p, 2: p, 3: p}, seed=42)
-val_sample = val.drop("__INDEX").sampleBy("tumor_score", fractions={1: p, 2: p, 3: p}, seed=42)
-
-# Reassign row indices.
-# TODO: Wrap this in a function with appropriate default arguments.
-train_sample = (
-  train_sample.rdd
-              .zipWithIndex()
-              .map(lambda r: (r[1] + 1, *r[0]))
-              .toDF(['__INDEX', 'slide_num', 'tumor_score', 'molecular_score', 'sample']))
-train_sample = train_sample.select(train_sample["__INDEX"].astype("int"),
-                                   train_sample.slide_num.astype("int"),
-                                   train_sample.tumor_score.astype("int"),
-                                   train_sample.molecular_score,
-                                   train_sample["sample"])
-
-val_sample = (
-  val_sample.rdd
-            .zipWithIndex()
-            .map(lambda r: (r[1] + 1, *r[0]))
-            .toDF(['__INDEX', 'slide_num', 'tumor_score', 'molecular_score', 'sample']))
-val_sample = val_sample.select(val_sample["__INDEX"].astype("int"),
-                               val_sample.slide_num.astype("int"),
-                               val_sample.tumor_score.astype("int"),
-                               val_sample.molecular_score,
-                               val_sample["sample"])
-
-# Save train and validation DataFrames.
-tr_sample_filename = "train_{}_sample_{}{}.parquet".format(p, sample_size,
-    "_grayscale" if grayscale else "")
-val_sample_filename = "val_{}_sample_{}{}.parquet".format(p, sample_size,
-    "_grayscale" if grayscale else "")
-train_sample_path = os.path.join(save_folder, tr_sample_filename)
-val_sample_path = os.path.join(save_folder, val_sample_filename)
-save(train_sample, train_sample_path, sample_size, grayscale)
-save(val_sample, val_sample_path, sample_size, grayscale)
+train_sample_path = os.path.join(save_folder, "train_{}_sample_{}{}.parquet".format(sample_frac,
+    sample_size, "_grayscale" if grayscale else ""))
+val_sample_path = os.path.join(save_folder, "val_{}_sample_{}{}.parquet".format(sample_frac,
+    sample_size, "_grayscale" if grayscale else ""))
+
+# Get ground-truth labels for all training slides, indexed by slide number
+labels_df = get_labels_df(folder)
+
+# Split slides into train & val sets, stratified by tumor score, so no slide spans both sets
+train, val = train_test_split(labels_df, train_size=train_frac, stratify=labels_df['tumor_score'],
+                              random_state=seed)
+
+# Process train & val slides separately; the label index holds the slide numbers
+train_df = preprocess(spark, train.index, tile_size=tile_size, sample_size=sample_size,
+                      grayscale=grayscale, num_partitions=num_partitions, folder=folder)
+val_df = preprocess(spark, val.index, tile_size=tile_size, sample_size=sample_size,
+                    grayscale=grayscale, num_partitions=num_partitions, folder=folder)
+
+if row_indices:
+  # Add 1-based "__INDEX" row indices for SystemML
+  train_df = add_row_indices(train_df)
+  val_df = add_row_indices(val_df)
+
+# Save train & val DataFrames
+save(train_df, train_df_path, sample_size, grayscale)
+save(val_df, val_df_path, sample_size, grayscale)
+
+if sample_frac > 0:
+  # Sample data; reload from disk and drop any row indices, which sampling would leave gappy.
+  train_df = spark.read.load(train_df_path).drop("__INDEX")
+  val_df = spark.read.load(val_df_path).drop("__INDEX")
+  train_sample = sample(train_df, sample_frac, seed=seed)  # seed by keyword, not as `training`
+  val_sample = sample(val_df, sample_frac, seed=seed)
+  if row_indices:  # reassign contiguous 1-based row indices after sampling
+    train_sample, val_sample = add_row_indices(train_sample), add_row_indices(val_sample)
+  save(train_sample, train_sample_path, sample_size, grayscale)
+  save(val_sample, val_sample_path, sample_size, grayscale)