Posted to dev@madlib.apache.org by Swatisoni <gi...@git.apache.org> on 2017/12/18 23:28:36 UTC
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
GitHub user Swatisoni opened a pull request:
https://github.com/apache/madlib/pull/218
Balanced Datasets: Random undersampling with/without replacement
JIRA:MADLIB-1168
Additional Authors:
Orhan Kislal <ok...@pivotal.io>
This commit implements random undersampling to create a dataset
with balanced classes.
Both with- and without-replacement methods are available.
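As a rough illustration of the without-replacement case (an in-memory sketch, not the SQL in this PR): each class is shuffled and only as many rows as the smallest class has are kept. The helper name `undersample_without_replacement` and the `class_of` callback are assumptions made for this example.

```python
import random
from collections import defaultdict


def undersample_without_replacement(rows, class_of, rng=None):
    """Keep the same number of distinct rows from every class.

    Hypothetical helper for illustration; `class_of` maps a row to its class.
    """
    rng = rng or random.Random()
    by_class = defaultdict(list)
    for row in rows:
        by_class[class_of(row)].append(row)
    if not by_class:
        return []
    # Every class is cut down to the size of the smallest class.
    target = min(len(members) for members in by_class.values())
    sample = []
    for members in by_class.values():
        shuffled = members[:]
        rng.shuffle(shuffled)             # random order within the class
        sample.extend(shuffled[:target])  # keep the first `target` rows
    return sample
```

The with-replacement variant differs only in that each row is drawn independently, so the same row can appear more than once in the output.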
You can merge this pull request into a Git repository by running:
$ git pull https://github.com/Swatisoni/madlib feature/balanced_sets
Alternatively you can review and apply these changes as the patch at:
https://github.com/apache/madlib/pull/218.patch
To close this pull request, make a commit to your master/trunk branch
with (at least) the following in the commit message:
This closes #218
----
commit 010199cbd2d14f13eca76330d54fc6e29fb9ecee
Author: Swatisoni <so...@gmail.com>
Date: 2017-12-18T23:30:41Z
Balanced Datasets: Random undersampling with/without replacement
JIRA:MADLIB-1168
Additional Authors:
Orhan Kislal <ok...@pivotal.io>
This commit implements random undersampling to create a dataset
with balanced classes.
Both with- and without-replacement methods are available.
----
---
[GitHub] madlib issue #218: Balanced Datasets: Random undersampling with/without repl...
Posted by Swatisoni <gi...@git.apache.org>.
Github user Swatisoni commented on the issue:
https://github.com/apache/madlib/pull/218
Balance Sample Phase 1 and Phase 2 are implemented in the new PR.
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157794392
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
--- End diff --
It might be helpful to outline the "with replacement" algorithm in a comment to make the code easier to understand and read. Maybe even make a function out of each logical SQL module.
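One possible shape for that refactoring, sketched under assumptions: each logical SQL fragment gets its own small builder, and its docstring says which step of the algorithm it covers. The function names are hypothetical and the SQL is condensed from the diff, so treat this as an outline rather than the module's actual code.

```python
def row_numbering_sql(source_table, class_col):
    """Step 1: number the rows of each class 1..class_count."""
    return (
        "SELECT *, row_number() OVER (PARTITION BY {class_col}) AS __row_no "
        "FROM {source_table}"
    ).format(source_table=source_table, class_col=class_col)


def random_row_no_sql(class_counts):
    """Step 2: for every non-minority class, draw minority_class_size
    random row numbers."""
    return (
        "SELECT classes, generate_series(1, minority_class_size) AS _i, "
        "((random() * ({cc}.class_count - 1) + 1)::int) AS __row_no "
        "FROM (SELECT min(class_count) AS minority_class_size FROM {cc}) "
        "AS foo, {cc} "
        "WHERE {cc}.class_count != minority_class_size"
    ).format(cc=class_counts)


def undersample_with_replacement_sql(source_table, class_col, class_counts,
                                     columns):
    """Step 3: join the drawn row numbers back to the numbered rows."""
    return (
        "SELECT {columns} FROM ({numbered}) AS f1 "
        "RIGHT JOIN ({draws}) AS f2 "
        "ON f1.__row_no = f2.__row_no AND f1.{class_col} = f2.classes"
    ).format(columns=columns,
             numbered=row_numbering_sql(source_table, class_col),
             draws=random_row_no_sql(class_counts),
             class_col=class_col)
```

Splitting the query this way would also let each fragment be printed and inspected on its own while debugging.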
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157892828
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
--- End diff --
I think the comments could be more detailed. The implementation would be easier to follow if they were consolidated and explained each step; from the current comments alone it is hard to tell what the algorithm is doing.
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157799149
--- Diff: src/ports/postgres/modules/sample/test/balance_sample.sql_in ---
@@ -0,0 +1,103 @@
+/* ----------------------------------------------------------------------- *//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ----------------------------------------------------------------------- */
+
+DROP TABLE IF EXISTS "TEST_s" cascade;
+
+CREATE TABLE "TEST_s"(
+ id1 INTEGER,
+ "ID2" INTEGER,
+ gr1 INTEGER,
+ gr2 INTEGER
+);
+
+INSERT INTO "TEST_s" VALUES
+(1,0,1,1),
+(2,0,1,1),
+(3,0,1,1),
+(4,0,1,1),
+(5,0,1,1),
+(6,0,1,1),
+(7,0,1,1),
+(8,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(9,0,1,1),
+(0,1,1,2),
+(0,2,1,2),
+(0,3,1,2),
+(0,4,1,2),
+(0,5,1,2),
+(0,6,1,2),
+(10,10,2,2),
+(20,20,2,2),
+(30,30,2,2),
+(40,40,2,2),
+(50,50,2,2),
+(60,60,2,2),
+(70,70,2,2),
+(10,10,5,5),
+(50,50,5,5),
+(88,88,5,5),
+(40,40,5,6),
+(50,50,5,6),
+(60,60,5,6),
+(70,70,5,6),
+(10,10,6,6),
+(60,60,6,6),
+(30,30,6,6),
+(40,40,6,6),
+(50,50,6,6),
+(60,60,6,6),
+(70,70,6,6),
+(50,50,4,2),
+(60,60,4,2),
+(70,70,4,2),
+(50,50,3,2),
+(60,60,3,2),
+(70,70,3,2)
+;
+
+--- Test for random undersampling without replacement
+DROP TABLE IF EXISTS out_s;
+SELECT balance_sample('"TEST_s"', 'out_s', 'gr1', 'undersample', NULL, NULL, FALSE);
+SELECT assert(count(*) = 18, 'Wrong number of samples') FROM out_s;
+
+DROP TABLE IF EXISTS out_s1;
+SELECT balance_sample('"TEST_s"', 'out_s1', 'gr2', 'undersample', NULL, NULL, FALSE);
+SELECT assert(count(*) = 12, 'Wrong number of samples') FROM out_s1;
+
+--- Test for random undersampling with replacement
+DROP TABLE IF EXISTS out_sr2;
+SELECT balance_sample('"TEST_s"', 'out_sr2', 'gr1', 'undersample', NULL, NULL, TRUE);
+SELECT assert(sum(c) <= 18, 'Wrong number of samples') FROM
--- End diff --
Instead of sum, it might be better to check that all the classes have exactly 3 tuples. You can change the query to something like
```
select count(*) from (select gr1, count(*) as c from out_sr2 group by gr1) as foo where foo.c != 3;
```
This query should return a count of 0. The same applies to all the tests, both with and without replacement.
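To see where the per-class figure of 3 comes from, here is a small Python sketch that recomputes it from the test data. It mirrors only the (gr1, gr2) columns of the 45 rows inserted into "TEST_s"; the helper names are made up for this example.

```python
from collections import Counter

# (gr1, gr2) pairs for the 45 rows inserted into "TEST_s", written as runs.
ROWS = (
    [(1, 1)] * 12 + [(1, 2)] * 6 + [(2, 2)] * 7 + [(5, 5)] * 3
    + [(5, 6)] * 4 + [(6, 6)] * 7 + [(4, 2)] * 3 + [(3, 2)] * 3
)


def expected_undersample_size(rows, col):
    """Return (rows per class, total rows) after undersampling on column col."""
    counts = Counter(row[col] for row in rows)
    per_class = min(counts.values())
    return per_class, per_class * len(counts)


def unbalanced_classes(sample, col, per_class):
    """Classes whose row count differs from per_class; empty when balanced."""
    counts = Counter(row[col] for row in sample)
    return sorted(c for c, n in counts.items() if n != per_class)


print(expected_undersample_size(ROWS, 0))  # gr1 -> (3, 18)
print(expected_undersample_size(ROWS, 1))  # gr2 -> (3, 12)
```

`unbalanced_classes` plays the role of the suggested query: an empty result means every class present in the sample has exactly the expected number of rows.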
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by Swatisoni <gi...@git.apache.org>.
Github user Swatisoni commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157885094
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
--- End diff --
The three comments that follow explain the steps of the sampling-with-replacement algorithm.
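Read together, those three comments describe roughly the following steps. The Python below is an in-memory paraphrase under assumptions (the function and variable names are invented here, and it is not the SQL from the diff).

```python
import random
from collections import defaultdict


def undersample_with_replacement(rows, class_of, rng=None):
    """Hypothetical in-memory version of the three commented steps."""
    rng = rng or random.Random()

    # Step 1: give each row an identifier (1..n) within its class.
    numbered = defaultdict(dict)
    for row in rows:
        members = numbered[class_of(row)]
        members[len(members) + 1] = row
    if not numbered:
        return []
    minority_size = min(len(m) for m in numbered.values())

    sample = []
    for members in numbered.values():
        if len(members) == minority_size:
            # Minority classes are copied unchanged.
            sample.extend(members.values())
            continue
        # Step 2: draw minority_size independent row numbers for this class.
        draws = [rng.randint(1, len(members)) for _ in range(minority_size)]
        # Step 3: match each drawn number to the row carrying that identifier.
        sample.extend(members[n] for n in draws)
    return sample
```

Note that `randint` is uniform over 1..n. The SQL expression in the diff casts a scaled `random()` value to `int`, which rounds to the nearest integer in PostgreSQL, so the first and last row numbers may be drawn less often than the others; that could be worth checking.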
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157893105
--- Diff: src/ports/postgres/modules/sample/test/balance_sample.sql_in ---
@@ -93,8 +93,8 @@ SELECT assert(sum(c) <= 18, 'Wrong number of samples') FROM
DROP TABLE IF EXISTS out_sr3;
SELECT balance_sample('"TEST_s"', 'out_sr3', 'gr2', 'undersample', NULL, NULL, TRUE);
-select assert(sum(c) <= 12, 'Wrong number of samples') from
- (select gr2, count(*) as c from out_sr3 group by gr2) as foo;
+select assert(count(*) = 0, 'Wrong number of samples') from
--- End diff --
The same changes can be made to all the other tests as well.
---
[GitHub] madlib issue #218: Balanced Datasets: Random undersampling with/without repl...
Posted by asfgit <gi...@git.apache.org>.
Github user asfgit commented on the issue:
https://github.com/apache/madlib/pull/218
Refer to this link for build results (access rights to CI server needed):
https://builds.apache.org/job/madlib-pr-build/315/
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157773474
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
+ """
+ """
+ Create row identifiers for each row wrt the class
+ """
+ classwise_row_numbering_sql = """
+ SELECT
+ *,
+ row_number() OVER(PARTITION BY {class_col})
+ AS __row_no
+ FROM
+ {source_table}
+ """.format(**locals())
+ """
+ Create independent random values
+ for each class that has more than the min number of rows
+ """
+ random_minorityclass_size_sample_number_gen_sql = """
+ SELECT
+ classes,
+ generate_series(1,minority_class_size) AS _i,
+ ((random()*({class_counts}.class_count-1)+1)::int)
+ AS __row_no
+ FROM
+ (SELECT
+ min(class_count) AS minority_class_size
+ FROM {class_counts})
+ AS foo,
+ {class_counts}
+ WHERE {class_counts}.class_count != minority_class_size
+ """.format(**locals())
+ """
+ Match random values with the row identifiers
+ """
+ undersample_otherclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM
+ ({classwise_row_numbering_sql}) AS f1
+ RIGHT JOIN
+ ({random_minorityclass_size_sample_number_gen_sql}) AS
+ f2
+ ON (f1.__row_no = f2.__row_no) AND
+ (f1.{class_col} = f2.classes)
+ """.format(**locals())
+ """
+ Find classes with minimum number of rows
+ """
+ minorityclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM {source_table}
+ WHERE {class_col} IN
+ (SELECT
+ classes AS minority_class
+ FROM {class_counts}
+ WHERE class_count in
+ (SELECT min(class_count) FROM {class_counts}))
+ """.format(**locals())
+ """
+ Combine minority and other undersampled classes
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT {source_table_columns}
+ FROM
+ ({minorityclass_set}) AS a
+ UNION ALL
+ ({undersample_otherclass_set}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+
+ plpy.execute("DROP VIEW IF EXISTS {0}".format(class_counts))
+ return
+
+def _validate_strs (source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols):
+
+ _assert(source_table and source_table.strip().lower() not in ('null', ''),
--- End diff --
Do we have any common functions that check the validity of the source table and the output table? I would think that this validation stays the same across modules.
---
[GitHub] madlib issue #218: Balanced Datasets: Random undersampling with/without repl...
Posted by asfgit <gi...@git.apache.org>.
Github user asfgit commented on the issue:
https://github.com/apache/madlib/pull/218
Refer to this link for build results (access rights to CI server needed):
https://builds.apache.org/job/madlib-pr-build/314/
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by Swatisoni <gi...@git.apache.org>.
Github user Swatisoni closed the pull request at:
https://github.com/apache/madlib/pull/218
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by Swatisoni <gi...@git.apache.org>.
Github user Swatisoni commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157885756
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
+ """
+ """
+ Create row identifiers for each row wrt the class
+ """
+ classwise_row_numbering_sql = """
+ SELECT
+ *,
+ row_number() OVER(PARTITION BY {class_col})
+ AS __row_no
+ FROM
+ {source_table}
+ """.format(**locals())
+ """
+ Create independent random values
+ for each class that has more than the min number of rows
+ """
+ random_minorityclass_size_sample_number_gen_sql = """
+ SELECT
+ classes,
+ generate_series(1,minority_class_size) AS _i,
+ ((random()*({class_counts}.class_count-1)+1)::int)
+ AS __row_no
+ FROM
+ (SELECT
+ min(class_count) AS minority_class_size
+ FROM {class_counts})
+ AS foo,
+ {class_counts}
+ WHERE {class_counts}.class_count != minority_class_size
+ """.format(**locals())
+ """
+ Match random values with the row identifiers
+ """
+ undersample_otherclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM
+ ({classwise_row_numbering_sql}) AS f1
+ RIGHT JOIN
+ ({random_minorityclass_size_sample_number_gen_sql}) AS
+ f2
+ ON (f1.__row_no = f2.__row_no) AND
+ (f1.{class_col} = f2.classes)
+ """.format(**locals())
+ """
+ Find classes with minimum number of rows
+ """
+ minorityclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM {source_table}
+ WHERE {class_col} IN
+ (SELECT
+ classes AS minority_class
+ FROM {class_counts}
+ WHERE class_count in
+ (SELECT min(class_count) FROM {class_counts}))
+ """.format(**locals())
+ """
+ Combine minority and other undersampled classes
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT {source_table_columns}
+ FROM
+ ({minorityclass_set}) AS a
+ UNION ALL
+ ({undersample_otherclass_set}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+
+ plpy.execute("DROP VIEW IF EXISTS {0}".format(class_counts))
+ return
+
+def _validate_strs (source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols):
+
+ _assert(source_table and source_table.strip().lower() not in ('null', ''),
--- End diff --
Thanks for the pointer, Nikhil. There is a table_exists function in validate_args.py_in that we can use here.
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by Swatisoni <gi...@git.apache.org>.
Github user Swatisoni commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157887416
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
+ """
+ """
+ Create row identifiers for each row wrt the class
+ """
+ classwise_row_numbering_sql = """
+ SELECT
+ *,
+ row_number() OVER(PARTITION BY {class_col})
+ AS __row_no
+ FROM
+ {source_table}
+ """.format(**locals())
+ """
+ Create independent random values
+ for each class that has more than the min number of rows
+ """
+ random_minorityclass_size_sample_number_gen_sql = """
+ SELECT
+ classes,
+ generate_series(1,minority_class_size) AS _i,
+ ((random()*({class_counts}.class_count-1)+1)::int)
+ AS __row_no
+ FROM
+ (SELECT
+ min(class_count) AS minority_class_size
+ FROM {class_counts})
+ AS foo,
+ {class_counts}
+ WHERE {class_counts}.class_count != minority_class_size
+ """.format(**locals())
+ """
+ Match random values with the row identifiers
+ """
+ undersample_otherclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM
+ ({classwise_row_numbering_sql}) AS f1
+ RIGHT JOIN
+ ({random_minorityclass_size_sample_number_gen_sql}) AS
+ f2
+ ON (f1.__row_no = f2.__row_no) AND
+ (f1.{class_col} = f2.classes)
+ """.format(**locals())
+ """
+ Find classes with minimum number of rows
+ """
+ minorityclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM {source_table}
+ WHERE {class_col} IN
+ (SELECT
+ classes AS minority_class
+ FROM {class_counts}
+ WHERE class_count in
+ (SELECT min(class_count) FROM {class_counts}))
+ """.format(**locals())
+ """
+ Combine minority and other undersampled classes
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT {source_table_columns}
+ FROM
+ ({minorityclass_set}) AS a
+ UNION ALL
+ ({undersample_otherclass_set}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+
+ plpy.execute("DROP VIEW IF EXISTS {0}".format(class_counts))
+ return
+
+def _validate_strs (source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols):
+
+ _assert(source_table and source_table.strip().lower() not in ('null', ''),
--- End diff --
This assert seems redundant since the next one calls table_exists which handles the same validation tasks.
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157769789
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
--- End diff --
Consider adding an "@param" entry for every argument of the balance_sample function.
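Something along these lines, as a rough sketch. The descriptions for source_table, output_table, class_col and with_replacement are taken from the existing docstring; the others are my assumptions based on what is visible in this diff and should be corrected by the author. undocumented_params is a hypothetical helper (Python 3, standard library only) that lists arguments still missing an entry.

```python
import inspect


def balance_sample(schema_madlib, source_table, output_table, class_col,
                   class_sizes, output_table_size, grouping_cols,
                   with_replacement, **kwargs):
    """
    Balance sampling function
    Args:
        @param schema_madlib      Schema in which MADlib is installed
                                  (assumed).
        @param source_table       Input table name.
        @param output_table       Output table name.
        @param class_col          Name of the column containing the class to
                                  be balanced.
        @param class_sizes        Sampling strategy; this diff only handles
                                  'undersample'.
        @param output_table_size  Passed to _validate_strs; not otherwise
                                  used in the visible diff.
        @param grouping_cols      Passed to _validate_strs; not otherwise
                                  used in the visible diff.
        @param with_replacement   (Default: FALSE) The sampling method.
    """
    # Body omitted; only the docstring matters for this sketch.


def undocumented_params(func):
    """Return the named arguments of func that lack an @param entry."""
    doc = func.__doc__ or ""
    return [name for name, p in inspect.signature(func).parameters.items()
            if p.kind is not inspect.Parameter.VAR_KEYWORD
            and "@param {0} ".format(name) not in doc]
```

Calling undocumented_params(balance_sample) on the stub above returns an empty list; on the docstring in the current diff it would flag the four arguments that have no entry.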
---
[GitHub] madlib pull request #218: Balanced Datasets: Random undersampling with/witho...
Posted by kaknikhil <gi...@git.apache.org>.
Github user kaknikhil commented on a diff in the pull request:
https://github.com/apache/madlib/pull/218#discussion_r157770698
--- Diff: src/ports/postgres/modules/sample/balance_sample.py_in ---
@@ -0,0 +1,322 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import plpy
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.validate_args import table_exists
+from utilities.validate_args import columns_exist_in_table
+from utilities.validate_args import table_is_empty
+from utilities.validate_args import get_cols
+from utilities.utilities import py_list_to_sql_string
+
+m4_changequote(`<!', `!>')
+
+def balance_sample(schema_madlib, source_table, output_table, class_col,
+ class_sizes, output_table_size, grouping_cols, with_replacement, **kwargs):
+
+ """
+ Balance sampling function
+ Args:
+ @param source_table Input table name.
+ @param output_table Output table name.
+ @param class_col Name of the column containing the class to be
+ balanced.
+ @param with_replacement (Default: FALSE) The sampling method.
+
+ """
+ with MinWarning("warning"):
+
+ class_counts = unique_string(desp='class_counts')
+
+ _validate_strs(source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols)
+ source_table_columns = ','.join(get_cols(source_table))
+ grp_by = "GROUP BY {0}".format(class_col)
+ """
+ Frequency table for classes
+ """
+ plpy.execute(""" CREATE VIEW {class_counts} AS (
+ SELECT
+ {class_col} AS classes,
+ count(*) AS class_count
+ FROM {source_table}
+ {grp_by})
+ """.format(**locals()))
+
+ if class_sizes.lower() == 'undersample':
+
+ if not with_replacement:
+ """
+ Random undersample without replacement
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT
+ {source_table_columns}
+ FROM
+ (SELECT
+ *,
+ row_number() OVER(PARTITION BY
+ {class_col} ORDER BY random())
+ AS __row_no
+ FROM {source_table}) AS foo
+ WHERE __row_no <=
+ (SELECT
+ MIN(class_count)
+ FROM {class_counts}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+ else:
+ """
+ Random undersample with replacement
+ """
+ """
+ Create row identifiers for each row wrt the class
+ """
+ classwise_row_numbering_sql = """
+ SELECT
+ *,
+ row_number() OVER(PARTITION BY {class_col})
+ AS __row_no
+ FROM
+ {source_table}
+ """.format(**locals())
+ """
+ Create independent random values
+ for each class that has more than the min number of rows
+ """
+ random_minorityclass_size_sample_number_gen_sql = """
+ SELECT
+ classes,
+ generate_series(1,minority_class_size) AS _i,
+ ((random()*({class_counts}.class_count-1)+1)::int)
+ AS __row_no
+ FROM
+ (SELECT
+ min(class_count) AS minority_class_size
+ FROM {class_counts})
+ AS foo,
+ {class_counts}
+ WHERE {class_counts}.class_count != minority_class_size
+ """.format(**locals())
+ """
+ Match random values with the row identifiers
+ """
+ undersample_otherclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM
+ ({classwise_row_numbering_sql}) AS f1
+ RIGHT JOIN
+ ({random_minorityclass_size_sample_number_gen_sql}) AS
+ f2
+ ON (f1.__row_no = f2.__row_no) AND
+ (f1.{class_col} = f2.classes)
+ """.format(**locals())
+ """
+ Find classes with minimum number of rows
+ """
+ minorityclass_set = """
+ SELECT
+ {source_table_columns}
+ FROM {source_table}
+ WHERE {class_col} IN
+ (SELECT
+ classes AS minority_class
+ FROM {class_counts}
+ WHERE class_count in
+ (SELECT min(class_count) FROM {class_counts}))
+ """.format(**locals())
+ """
+ Combine minority and other undersampled classes
+ """
+ output_sql = """
+ CREATE TABLE {output_table} AS (
+ SELECT {source_table_columns}
+ FROM
+ ({minorityclass_set}) AS a
+ UNION ALL
+ ({undersample_otherclass_set}))
+ """.format(**locals())
+ plpy.execute(output_sql)
+
+ plpy.execute("DROP VIEW IF EXISTS {0}".format(class_counts))
+ return
+
+def _validate_strs (source_table, output_table, class_col, class_sizes,
+ output_table_size, grouping_cols):
+
+ _assert(source_table and source_table.strip().lower() not in ('null', ''),
+ "Sample: Invalid Source table name!".format(**locals()))
+ _assert(table_exists(source_table),
+ "Sample: Source table ({source_table}) is missing!".format(**locals()))
+ _assert(not table_is_empty(source_table),
+ "Sample: Source table ({source_table}) is empty!".format(**locals()))
+
+ _assert(output_table and output_table.strip().lower() not in ('null', ''),
+ "Sample: Invalid output table name {output_table}!".format(**locals()))
+ _assert(not table_exists(output_table),
+ "Sample: Output table already exists!".format(**locals()))
--- End diff --
Maybe print the name of the output table in this message as well.
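For example, as a sketch (the _assert here is a hypothetical stand-in so the snippet runs on its own, and check_output_table_is_new is a made-up wrapper name). Note that the current message has no placeholders, so the .format(**locals()) call on it does nothing.

```python
# Hypothetical stand-in for utilities.utilities._assert.
def _assert(condition, message):
    if not condition:
        raise ValueError(message)


def check_output_table_is_new(output_table, already_exists):
    # Include the table name so the user can see which table is in the way.
    _assert(not already_exists,
            "Sample: Output table ({0}) already exists!".format(output_table))
```

In the actual module, already_exists would be the result of table_exists(output_table).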
---
[GitHub] madlib issue #218: Balanced Datasets: Random undersampling with/without repl...
Posted by asfgit <gi...@git.apache.org>.
Github user asfgit commented on the issue:
https://github.com/apache/madlib/pull/218
Refer to this link for build results (access rights to CI server needed):
https://builds.apache.org/job/madlib-pr-build/313/
---