You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2017/08/29 20:42:19 UTC
[42/50] [abbrv] incubator-madlib git commit: Sample: Add function to
split train/test
Sample: Add function to split train/test
JIRA: MADLIB-1119
Add utility to create train and test samples from an input table.
This function uses the stratified sampling to create the samples.
Closes #166
Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/f63adda6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/f63adda6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/f63adda6
Branch: refs/heads/latest_release
Commit: f63adda68a5883e7f2dbd2a6695a65e3ba89efce
Parents: 3e427d9
Author: Cooper Sloan <co...@gmail.com>
Authored: Fri Aug 18 06:23:29 2017 -0700
Committer: Rahul Iyer <ri...@apache.org>
Committed: Fri Aug 18 16:29:15 2017 -0700
----------------------------------------------------------------------
doc/mainpage.dox.in | 8 +-
.../modules/sample/stratified_sample.py_in | 2 +-
.../modules/sample/test/test_train_split.sql_in | 85 +++++
.../modules/sample/test_train_split.py_in | 319 +++++++++++++++++++
.../modules/sample/test_train_split.sql_in | 319 +++++++++++++++++++
5 files changed, 730 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/f63adda6/doc/mainpage.dox.in
----------------------------------------------------------------------
diff --git a/doc/mainpage.dox.in b/doc/mainpage.dox.in
index e27e14a..be45369 100644
--- a/doc/mainpage.dox.in
+++ b/doc/mainpage.dox.in
@@ -142,12 +142,15 @@ Contains graph algorithms.
@defgroup grp_wcc Weakly Connected Components
@}
-@defgroup grp_mdl Model Evaluation
-@{Contains functions for evaluating accuracy and validation of predictive methods. @}
+@defgroup grp_mdl Model Selection
+@{Contains functions for model selection and model evaluation. @}
@defgroup grp_validation Cross Validation
@ingroup grp_mdl
@defgroup grp_pred Prediction Metrics
@ingroup grp_mdl
+ @defgroup grp_test_train_split Test Train Split
+ @ingroup grp_mdl
+
@defgroup grp_stats Statistics
@{Contains statistics modules @}
@@ -264,6 +267,7 @@ Contains graph algorithms.
@defgroup grp_strs Stratified Sampling
@ingroup grp_sampling
+
@defgroup grp_sessionize Sessionize
@ingroup grp_utility_functions
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/f63adda6/src/ports/postgres/modules/sample/stratified_sample.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/stratified_sample.py_in b/src/ports/postgres/modules/sample/stratified_sample.py_in
index e7762ef..0d29b41 100644
--- a/src/ports/postgres/modules/sample/stratified_sample.py_in
+++ b/src/ports/postgres/modules/sample/stratified_sample.py_in
@@ -167,7 +167,7 @@ def validate_strs (source_table, output_table, proportion, glist, target_cols):
_assert(not table_is_empty(source_table),
"Sample: Source table ({source_table}) is empty!".format(**locals()))
- _assert(proportion > 0 and proportion < 1,
+ _assert(proportion > 0 and proportion <= 1,
"Sample: Proportion isn't in the range (0,1)!")
if glist is not None:
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/f63adda6/src/ports/postgres/modules/sample/test/test_train_split.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test/test_train_split.sql_in b/src/ports/postgres/modules/sample/test/test_train_split.sql_in
new file mode 100644
index 0000000..5ae0ade
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test/test_train_split.sql_in
@@ -0,0 +1,85 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ----------------------------------------------------------------------- */
+
+-- Fixture: 20 rows in three strata keyed by (gr1, gr2):
+--   (1,1) has 12 rows, (1,2) has 6 rows, (2,2) has 2 rows.
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+    id1 INTEGER,
+    id2 INTEGER,
+    gr1 INTEGER,
+    gr2 INTEGER
+);
+
+INSERT INTO test (id1, id2, gr1, gr2) VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2)
+;
+
+SELECT setseed(0);
+
+-- Case 1: no grouping, without replacement, separate output tables.
+-- 0.1 and 0.2 of the 20 rows: expect 2 train rows and 4 test rows.
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.1, 0.2, NULL, 'id1,id2,gr1,gr2', FALSE, TRUE);
+SELECT assert(count(*) = 2, 'Wrong number of samples') FROM out_train;
+SELECT assert(count(*) = 4, 'Wrong number of samples') FROM out_test;
+
+-- Case 2: same proportions, single output table tagged by the split column
+-- (1 for train, 0 for test).
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.1, 0.2, NULL, 'id1,id2,gr1,gr2', FALSE, FALSE);
+SELECT assert(count(*) = 2, 'Wrong number of samples') FROM out WHERE split=1;
+SELECT assert(count(*) = 4, 'Wrong number of samples') FROM out WHERE split=0;
+
+
+-- Case 3: with replacement, proportions summing to 1: the combined output
+-- must have as many rows as the source table.
+DROP TABLE IF EXISTS out_train,out_test,out;
+SELECT test_train_split('test', 'out', 0.5, 0.5, NULL, 'id1,id2,gr1,gr2', TRUE, FALSE);
+SELECT assert(count(*) = 20, 'Wrong number of samples') FROM out;
+
+-- Case 4: stratified by (gr1, gr2) with replacement: every stratum must be
+-- split half and half between test (split = 0) and train (split = 1).
+DROP TABLE IF EXISTS out;
+SELECT test_train_split('test', 'out', 0.5, 0.5, 'gr1,gr2', 'id1,id2', TRUE, FALSE);
+SELECT assert(count(*) = 6, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 1 AND split = 0;
+SELECT assert(count(*) = 6, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 1 AND split = 1;
+SELECT assert(count(*) = 3, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 2 AND split = 0;
+SELECT assert(count(*) = 3, 'Wrong number of samples')
+FROM out WHERE gr1 = 1 AND gr2 = 2 AND split = 1;
+SELECT assert(count(*) = 1, 'Wrong number of samples')
+FROM out WHERE gr1 = 2 AND gr2 = 2 AND split = 0;
+SELECT assert(count(*) = 1, 'Wrong number of samples')
+FROM out WHERE gr1 = 2 AND gr2 = 2 AND split = 1;
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/f63adda6/src/ports/postgres/modules/sample/test_train_split.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test_train_split.py_in b/src/ports/postgres/modules/sample/test_train_split.py_in
new file mode 100644
index 0000000..6056b2a
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test_train_split.py_in
@@ -0,0 +1,319 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import extract_keyvalue_params
+from utilities.utilities import add_postfix
+from utilities.utilities import unique_string
+from utilities.utilities import split_quoted_delimited_str
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import get_cols
+from graph.graph_utils import _check_groups
+from graph.graph_utils import _grp_from_table
+
+m4_changequote(` <!', `!>')
+
+
+def _get_sql_string(str):
+ if str:
+ return "'" + str + "'"
+ return "NULL"
+
+
+def test_train_split(schema_madlib, source_table, output_table, train_proportion,
+ test_proportion, grouping_cols, target_cols, with_replacement,
+ separate_output_tables, **kwargs):
+ """
+ test train split function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param train_proportion The ratio of training data to the entire
+ input table
+ @param test_proportion The ratio of test data to the entire
+ input table
+ @param grouping_cols (Default: NULL) The columns to distinguish
+ each strata.
+ @param target_cols (Default: NULL) The columns to include in
+ the output.
+ @param with_replacement (Default: FALSE) The sampling method.
+ @param separate_output_tables (Default: FALSE) Create two output tables,
+ <output_table>_train and <output_table>_test.
+ Otherwise one output table is created with
+ and additional column 'split' which takes the
+ value 0 for test and 1 for training.
+
+ """
+ with MinWarning("warning"):
+ if test_proportion is None:
+ test_proportion = 1 - train_proportion
+ validate_strs(source_table, output_table, train_proportion, test_proportion,
+ split_quoted_delimited_str(grouping_cols), target_cols,
+ with_replacement)
+ grouping_cols = _get_sql_string(grouping_cols)
+ target_cols = _get_sql_string(target_cols)
+ with_replacement = with_replacement or "False"
+ strat_query = """
+ SELECT {schema_madlib}.stratified_sample(
+ '{strat_source_table}',
+ '{strat_out_table}',
+ '{strat_proportion}',
+ {strat_grouping_cols},
+ {strat_target_cols},
+ {strat_with_replacement}
+ )
+ """
+ strat_out_table = unique_string()
+ q = strat_query.format(
+ schema_madlib=schema_madlib,
+ strat_source_table=source_table,
+ strat_out_table=strat_out_table,
+ strat_proportion=train_proportion + test_proportion,
+ strat_grouping_cols=grouping_cols,
+ strat_with_replacement=with_replacement,
+ strat_target_cols=target_cols
+ )
+ plpy.execute(q)
+ test_table = add_postfix(output_table, "_test")
+ train_table = add_postfix(output_table, "_train")
+ if not separate_output_tables:
+ test_table = unique_string()
+ train_table = unique_string()
+ test_query = strat_query.format(
+ schema_madlib=schema_madlib,
+ strat_source_table=strat_out_table,
+ strat_out_table=test_table,
+ strat_proportion=(test_proportion /
+ (train_proportion + test_proportion)),
+ strat_grouping_cols=grouping_cols,
+ strat_with_replacement=False,
+ strat_target_cols=target_cols
+ )
+ plpy.execute(test_query)
+ train_query = """
+ CREATE TABLE {train_table} AS
+ SELECT * FROM {strat_out_table}
+ EXCEPT ALL
+ SELECT * FROM {test_table}
+ """.format(train_table=train_table,
+ strat_out_table=strat_out_table,
+ test_table=test_table)
+ plpy.execute(train_query)
+ clean_up_tables = [strat_out_table]
+ if not separate_output_tables:
+ union_query = """
+ CREATE TABLE {output_table} AS
+ SELECT *,0 AS split FROM {test_table}
+ UNION ALL
+ SELECT *,1 AS split FROM {train_table}
+ """.format(output_table=output_table,
+ test_table=test_table,
+ train_table=train_table)
+ plpy.execute(union_query)
+ clean_up_tables += [train_table, test_table]
+ clean_up_query = """
+ DROP TABLE IF EXISTS {clean_up_tables}
+ """.format(clean_up_tables=",".join(clean_up_tables))
+ plpy.execute(clean_up_query)
+ return
+
+
+def validate_strs(source_table, output_table, train_proportion, test_proportion, glist, target_cols, with_replacement):
+
+ _assert(output_table and output_table.strip().lower() not in ('null', ''),
+ "Sample: Invalid output table name {output_table}!".format(**locals()))
+ _assert(not table_exists(output_table),
+ "Sample: Output table already exists!".format(**locals()))
+
+ _assert(source_table and source_table.strip().lower() not in ('null', ''),
+ "Sample: Invalid Source table name!".format(**locals()))
+ _assert(table_exists(source_table),
+ "Sample: Source table ({source_table}) is missing!".format(**locals()))
+ _assert(not table_is_empty(source_table),
+ "Sample: Source table ({source_table}) is empty!".format(**locals()))
+
+ for proportion in [train_proportion, test_proportion]:
+ _assert(proportion > 0 and proportion < 1,
+ "Sample: Proportions aren't in the range (0,1)!")
+ if not with_replacement:
+ _assert(train_proportion + test_proportion <= 1,
+ "Sample: Proportions add up to greater than 1!")
+
+ if glist is not None:
+ _assert(columns_exist_in_table(source_table, glist),
+ ("""Sample: Not all columns from {glist} are present in source""" +
+ """ table ({source_table}).""").format(**locals()))
+
+ if not (target_cols is None or target_cols is '*'):
+ tlist = split_quoted_delimited_str(target_cols)
+ _assert(columns_exist_in_table(source_table, tlist),
+ ("""Sample: Not all columns from {target_cols} are present in""" +
+ """ edge table ({source_table})""").format(**locals()))
+ return
+
+
+def test_train_split_help(schema_madlib, message, **kwargs):
+ """
+ Help function for test_train_split
+
+ Args:
+ @param schema_madlib
+ @param message: string, Help message string
+ @param kwargs
+
+ Returns:
+ String. Help/usage information
+ """
+ if not message:
+ help_string = """
+-----------------------------------------------------------------------
+ SUMMARY
+-----------------------------------------------------------------------
+
+Given a table, test_train_split returns a random sample of the
+table for testing and training. It is possible to use with or without
+replacement sampling methods, specify a set of target columns, and a
+set of grouping columns, in which case, stratified sampling will be
+performed.
+
+For more details on function usage:
+ SELECT {schema_madlib}.test_train_split('usage');
+ SELECT {schema_madlib}.test_train_split('example');
+ """
+ elif message.lower() in ['usage', 'help', '?']:
+ help_string = """
+
+Given a table, test train split returns a proportion of records for
+each group (strata). It is possible to use with or without replacement
+sampling methods, specify a set of target columns, and assume the
+whole table is a single strata.
+
+----------------------------------------------------------------------------
+ USAGE
+----------------------------------------------------------------------------
+
+ SELECT {schema_madlib}.test_train_split(
+ source_table TEXT, -- Name of the table containing the input data.
+ output_table TEXT, -- Output table name.
+ train_proportion FLOAT8, -- The ratio of train sample size to the
+ -- number of records.
+ test_proportion FLOAT8, -- The ratio of test sample size to the
+ -- number of records.
+ grouping_cols TEXT -- (Default: NULL) The columns to distinguish
+ -- each strata.
+ target_cols TEXT, -- (Default: NULL) The columns to include in
+ -- the output.
+ with_replacement BOOLEAN -- (Default: FALSE) The sampling method.
+ separate_output_tables
+ BOOLEAN -- (Default: FALSE) Separate the output table
+ -- into $output_table$_train and
+ -- $output_table$_test, otherwise, the split
+ -- column in output_table will identify 1 for
+ -- train set and 0 for test set.
+
+If grouping_cols is NULL, the whole table is treated as a single group and
+sampled accordingly.
+
+If target_cols is NULL or '*', all of the columns will be included in the
+output table.
+
+If with_replacement is TRUE, each sample is independent (the same row may
+be selected in the sample set more than once). Else (if with_replacement
+is FALSE), a row can be selected at most once.
+);
+"""
+ elif message.lower() in ("example", "examples"):
+ help_string = """
+----------------------------------------------------------------------------
+ EXAMPLES
+----------------------------------------------------------------------------
+
+-- Create an input table
+DROP TABLE IF EXISTS test;
+
+CREATE TABLE test(
+ id1 INTEGER,
+ id2 INTEGER,
+ gr1 INTEGER,
+ gr2 INTEGER
+);
+
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2)
+;
+
+-- Sample without replacement
+DROP TABLE IF EXISTS out;
+SELECT madlib.test_train_split(
+ 'test', -- Source table
+ 'out', -- Output table
+ 0.5, -- Sample proportion
+ 0.5, -- Sample proportion
+ 'gr1,gr2', -- Strata definition
+ 'id1,id2', -- Columns to output
+ FALSE, -- Sample without replacement
+ FALSE); -- Do not separate output tables
+SELECT * FROM out ORDER BY split,gr1,gr2,id1,id2;
+
+-- Sample with replacement
+DROP TABLE IF EXISTS out_train, out_test;
+SELECT madlib.test_train_split(
+ 'test', -- Source table
+ 'out', -- Output table
+ 0.5, -- train_proportion
+ NULL, -- Default = 1 - train_proportion = 0.5
+ 'gr1,gr2', -- Strata definition
+ 'id1,id2', -- Columns to output
+ TRUE, -- Sample with replacement
+ TRUE); -- Separate output tables
+SELECT * FROM out_train ORDER BY gr1,gr2,id1,id2;
+"""
+ else:
+ help_string = "No such option. Use {schema_madlib}.graph_sssp()"
+
+ return help_string.format(schema_madlib=schema_madlib)
http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/f63adda6/src/ports/postgres/modules/sample/test_train_split.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/sample/test_train_split.sql_in b/src/ports/postgres/modules/sample/test_train_split.sql_in
new file mode 100644
index 0000000..ba1adb3
--- /dev/null
+++ b/src/ports/postgres/modules/sample/test_train_split.sql_in
@@ -0,0 +1,319 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *
+ * @file test_train_split.sql_in
+ *
+ * @brief SQL functions for test train split.
+ * @date 07/19/2017
+ *
+ * @sa Given a table, test train split creates training and test
+ * samples from it, optionally stratified by grouping columns.
+ *
+ *//* ----------------------------------------------------------------------- */
+
+m4_include(`SQLCommon.m4')
+
+
+/**
+@addtogroup grp_test_train_split
+
+<div class="toc"><b>Contents</b>
+<ul>
+<li><a href="#strs">test train split</a></li>
+<li><a href="#examples">Examples</a></li>
+</ul>
+</div>
+
+@brief A method for splitting a data set into separate training and testing sets.
+
+test_train_split is a utility to create test and
+training data set as subsets of a single table.
+
+@anchor strs
+@par test train split
+
+<pre class="syntax">
+test_train_split( source_table,
+ output_table,
+ train_proportion,
+ test_proportion,
+ grouping_cols,
+ target_cols,
+ with_replacement,
+ separate_output_tables
+ )
+</pre>
+
+\b Arguments
+<dl class="arglist">
+<dt>source_table</dt>
+<dd>TEXT. Name of the table containing the input data.</dd>
+
+<dt>output_table</dt>
+<dd>TEXT. Name of output table. A new INTEGER column on the right
+called 'split' will identify 1 for train set and 0 for test set,
+unless the 'separate_output_tables' parameter below is TRUE,
+in which case two output tables will be created using
+the 'output_table' name with the suffixes '_train' and '_test'.
+The output table contains all the columns present in the source
+table unless otherwise specified in the 'target_cols' parameter below. </dd>
+
+<dt>train_proportion</dt>
+<dd>FLOAT8 in the range (0,1). Proportion of the dataset to include
+in the train split. If the 'grouping_cols' parameter is specified below,
+each group will be sampled independently using the
+train proportion, i.e., in a stratified fashion.</dd>
+
+<dt>test_proportion</dt>
+<dd>FLOAT8 in the range (0,1). Proportion of the dataset to include
+in the test split. Default is the complement to the train
+proportion (1-'train_proportion'). If the 'grouping_cols'
+parameter is specified below, each group will be sampled
+independently using the test proportion,
+i.e., in a stratified fashion.</dd>
+
+<dt>grouping_cols (optional)</dt>
+<dd>TEXT, default: NULL. A single column or a list of comma-separated columns
+ that defines how to stratify. When this parameter is NULL,
+the train-test split is not stratified.</dd>
+
+<dt>target_cols (optional)</dt>
+<dd>TEXT, default NULL. A comma-separated list of columns
+to appear in the 'output_table'. If NULL or '*', all
+columns from the 'source_table' will appear in
+the 'output_table'.</dd>
+
+@anchor note
+@note
+ Do not include 'grouping_cols' in the parameter 'target_cols',
+ because they are always included in the 'output_table'.
+
+<dt>with_replacement (optional)</dt>
+<dd>BOOLEAN, default FALSE. Determines whether to sample
+with replacement or without replacement (default).
+With replacement means that it is possible that the
+same row may appear in the sample set more than once.
+Without replacement means a given row can be selected
+only once.</dd>
+
+<dt>separate_output_tables (optional)</dt>
+<dd>BOOLEAN, default FALSE. If TRUE, two output tables will be created using
+the 'output_table' name with the suffixes '_train' and '_test'.</dd>
+</dl>
+
+
+@anchor examples
+@par Examples
+
+Please note that due to the random nature of sampling, your
+results may look different from those below.
+
+-# Create an input table:
+<pre class="syntax">
+DROP TABLE IF EXISTS test;
+CREATE TABLE test(
+ id1 INTEGER,
+ id2 INTEGER,
+ gr1 INTEGER,
+ gr2 INTEGER
+);
+INSERT INTO test VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2);
+</pre>
+
+-# Sample without replacement:
+<pre class="syntax">
+DROP TABLE IF EXISTS out;
+SELECT madlib.test_train_split(
+ 'test', -- Source table
+ 'out', -- Output table
+ 0.5, -- Sample proportion
+ 0.5, -- Sample proportion
+ 'gr1,gr2', -- Strata definition
+ 'id1,id2', -- Columns to output
+ FALSE, -- Sample without replacement
+ FALSE); -- Do not separate output tables
+SELECT * FROM out ORDER BY split,gr1,gr2,id1,id2;
+</pre>
+<pre class="result">
+ gr1 | gr2 | id1 | id2 | split
+-----+-----+-----+-----+-------
+ 1 | 1 | 1 | 0 | 0
+ 1 | 1 | 4 | 0 | 0
+ 1 | 1 | 6 | 0 | 0
+ 1 | 1 | 9 | 0 | 0
+ 1 | 1 | 9 | 0 | 0
+ 1 | 1 | 9 | 0 | 0
+ 1 | 2 | 0 | 3 | 0
+ 1 | 2 | 0 | 4 | 0
+ 1 | 2 | 0 | 5 | 0
+ 2 | 2 | 10 | 10 | 0
+ 2 | 2 | 30 | 30 | 0
+ 2 | 2 | 40 | 40 | 0
+ 2 | 2 | 60 | 60 | 0
+ 1 | 1 | 2 | 0 | 1
+ 1 | 1 | 3 | 0 | 1
+ 1 | 1 | 5 | 0 | 1
+ 1 | 1 | 7 | 0 | 1
+ 1 | 1 | 8 | 0 | 1
+ 1 | 1 | 9 | 0 | 1
+ 1 | 2 | 0 | 1 | 1
+ 1 | 2 | 0 | 2 | 1
+ 1 | 2 | 0 | 6 | 1
+ 2 | 2 | 20 | 20 | 1
+ 2 | 2 | 50 | 50 | 1
+ 2 | 2 | 70 | 70 | 1
+(25 rows)
+</pre>
+
+-# Sample with replacement:
+<pre class="syntax">
+DROP TABLE IF EXISTS out_train, out_test;
+SELECT madlib.test_train_split(
+ 'test', -- Source table
+ 'out', -- Output table
+ 0.5, -- train_proportion
+ NULL, -- Default = 1 - train_proportion = 0.5
+ 'gr1,gr2', -- Strata definition
+ 'id1,id2', -- Columns to output
+ TRUE, -- Sample with replacement
+ TRUE); -- Separate output tables
+SELECT * FROM out_train ORDER BY gr1,gr2,id1,id2;
+</pre>
+<pre class="result">
+ gr1 | gr2 | id1 | id2
+-----+-----+-----+-----
+ 1 | 1 | 1 | 0
+ 1 | 1 | 2 | 0
+ 1 | 1 | 4 | 0
+ 1 | 1 | 7 | 0
+ 1 | 1 | 8 | 0
+ 1 | 1 | 9 | 0
+ 1 | 2 | 0 | 4
+ 1 | 2 | 0 | 5
+ 1 | 2 | 0 | 6
+ 2 | 2 | 40 | 40
+ 2 | 2 | 50 | 50
+ 2 | 2 | 50 | 50
+(12 rows)
+</pre>
+*/
+
+-- Full signature: all eight arguments are supplied by the caller.
+-- The shorter overloads of this function all end up here.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table            TEXT,
+    output_table            TEXT,
+    train_proportion        FLOAT8,
+    test_proportion         FLOAT8,
+    grouping_cols           TEXT,
+    target_cols             TEXT,
+    with_replacement        BOOLEAN,
+    separate_output_tables  BOOLEAN
+) RETURNS VOID AS $$
+    PythonFunction(sample, test_train_split, test_train_split)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Seven-argument overload: separate_output_tables is passed as FALSE.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table      TEXT,
+    output_table      TEXT,
+    train_proportion  FLOAT8,
+    test_proportion   FLOAT8,
+    grouping_cols     TEXT,
+    target_cols       TEXT,
+    with_replacement  BOOLEAN
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, $6, $7, FALSE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+-- Six-argument overload: with_replacement is passed as FALSE.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table      TEXT,
+    output_table      TEXT,
+    train_proportion  FLOAT8,
+    test_proportion   FLOAT8,
+    grouping_cols     TEXT,
+    target_cols       TEXT
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, $6, FALSE);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+-- Five-argument overload: target_cols is passed as NULL.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table      TEXT,
+    output_table      TEXT,
+    train_proportion  FLOAT8,
+    test_proportion   FLOAT8,
+    grouping_cols     TEXT
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, $5, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+-- Four-argument overload: grouping_cols is passed as NULL.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    source_table      TEXT,
+    output_table      TEXT,
+    train_proportion  FLOAT8,
+    test_proportion   FLOAT8
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.test_train_split($1, $2, $3, $4, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-------------------------------------------------------------------------------
+
+-- Online help
+-- Help overload: takes a single message string and returns help text.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split(
+    message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunction(sample, test_train_split, test_train_split_help)
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+
+-------------------------------------------------------------------------------
+
+-- No-argument overload: calls the help overload with an empty message.
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.test_train_split()
+RETURNS VARCHAR AS $$
+    SELECT MADLIB_SCHEMA.test_train_split('');
+$$ LANGUAGE sql IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+-------------------------------------------------------------------------------