You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by GitBox <gi...@apache.org> on 2020/05/06 22:55:37 UTC

[GitHub] [madlib] reductionista commented on a change in pull request #496: DBSCAN: Add new module DBSCAN

reductionista commented on a change in pull request #496:
URL: https://github.com/apache/madlib/pull/496#discussion_r421099375



##########
File path: src/ports/postgres/modules/dbscan/dbscan.py_in
##########
@@ -0,0 +1,331 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.utilities import add_postfix
+from utilities.utilities import NUMERIC, ONLY_ARRAY
+from utilities.utilities import is_valid_psql_type
+from utilities.utilities import is_platform_pg
+from utilities.validate_args import input_tbl_valid, output_tbl_valid
+from utilities.validate_args import is_var_valid
+from utilities.validate_args import cols_in_tbl_valid
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import get_algorithm_name
+from graph.wcc import wcc
+
+# Algorithm names that dbscan() validates its `algorithm` argument against
+# (via get_algorithm_name).
+# NOTE(review): dbscan() as written always performs the brute-force pairwise
+# distance computation, whichever of these is selected.
+BRUTE_FORCE = 'brute_force'
+KD_TREE = 'kd_tree'
+
+def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm, **kwargs):
+
+    with MinWarning("warning"):
+
+        min_samples = 5 if not min_samples else min_samples
+        metric = 'squared_dist_norm2' if not metric else metric
+        algorithm = 'brute' if not algorithm else algorithm
+
+        algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
+            [BRUTE_FORCE, KD_TREE], 'DBSCAN')
+
+        _validate_dbscan(schema_madlib, source_table, output_table, id_column,
+                         expr_point, eps, min_samples, metric, algorithm)
+
+        dist_src_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY (__src__)'
+        dist_id_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY ({0})'.format(id_column)
+        dist_reach_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY (__reachable_id__)'
+
+        # Calculate pairwise distances
+        distance_table = unique_string(desp='distance_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))
+
+        sql = """
+            CREATE TABLE {distance_table} AS
+            SELECT __src__, __dest__ FROM (
+                SELECT  __t1__.{id_column} AS __src__,
+                        __t2__.{id_column} AS __dest__,
+                        {schema_madlib}.{metric}(
+                            __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__
+                FROM {source_table} AS __t1__, {source_table} AS __t2__) q1
+            WHERE __dist__ < {eps}
+            {dist_src_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Find core points
+        core_points_table = unique_string(desp='core_points_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_points_table))
+        sql = """
+            CREATE TABLE {core_points_table} AS
+            SELECT * FROM (SELECT __src__ AS {id_column}, count(*) AS __count__
+                           FROM {distance_table} GROUP BY __src__) q1
+            WHERE __count__ >= {min_samples}
+            {dist_id_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Find the connections between core points to form the clusters
+        core_edge_table = unique_string(desp='core_edge_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_edge_table))
+        sql = """
+            CREATE TABLE {core_edge_table} AS
+            SELECT __src__, __dest__
+            FROM {distance_table} AS __t1__, (SELECT array_agg({id_column}) AS arr
+                                              FROM {core_points_table}) __t2__
+            WHERE __t1__.__src__ = ANY(arr) AND __t1__.__dest__ = ANY(arr)
+            {dist_src_sql}
+        """.format(**locals())
+        plpy.execute(sql)
+
+        # Run wcc to get the min id for each cluster
+        wcc(schema_madlib, core_points_table, id_column, core_edge_table, 'src=__src__, dest=__dest__',
+            output_table, None)
+        plpy.execute("""
+            ALTER TABLE {0}
+            ADD COLUMN is_core_point BOOLEAN,
+            ADD COLUMN __points__ DOUBLE PRECISION[]
+            """.format(output_table))
+        plpy.execute("""
+            ALTER TABLE {0}
+            RENAME COLUMN component_id TO cluster_id
+            """.format(output_table))
+        plpy.execute("""
+            UPDATE {0}
+            SET is_core_point = TRUE
+        """.format(output_table))
+
+        # Find reachable points
+        reachable_points_table = unique_string(desp='reachable_points_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(reachable_points_table))
+        sql = """
+            CREATE TABLE {reachable_points_table} AS
+                SELECT array_agg(__src__) AS __src_list__,
+                       __dest__ AS __reachable_id__
+                FROM {distance_table} AS __t1__,
+                     (SELECT array_agg({id_column}) AS __arr__
+                      FROM {core_points_table}) __t2__
+                WHERE __src__ = ANY(__arr__) AND __dest__ != ALL(__arr__)
+                GROUP BY __dest__
+                {dist_reach_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        sql = """
+            INSERT INTO {output_table}
+            SELECT  __reachable_id__ as {id_column},
+                    cluster_id,
+                    FALSE AS is_core_point,
+                    NULL AS __points__
+            FROM {reachable_points_table} AS __t1__ INNER JOIN
+                 {output_table} AS __t2__
+                 ON (__src_list__[1] = {id_column})
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Add features of points to the output table to use them for prediction
+        sql = """
+            UPDATE {output_table} AS __t1__
+            SET __points__ = {expr_point}
+            FROM {source_table} AS __t2__
+            WHERE __t1__.{id_column} = __t2__.{id_column}
+        """.format(**locals())
+        plpy.execute(sql)
+
+        # Update the cluster ids to be consecutive
+        sql = """
+            UPDATE {output_table} AS __t1__
+            SET cluster_id = new_id-1
+            FROM (
+                SELECT cluster_id, row_number() OVER(ORDER BY cluster_id) AS new_id
+                FROM {output_table}
+                GROUP BY cluster_id) __t2__
+            WHERE __t1__.cluster_id = __t2__.cluster_id
+        """.format(**locals())
+        plpy.execute(sql)
+
+        output_summary_table = add_postfix(output_table, '_summary')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(output_summary_table))
+
+        sql = """
+            CREATE TABLE {output_summary_table} AS
+            SELECT  '{id_column}'::VARCHAR AS id_column,
+                    {eps}::DOUBLE PRECISION AS eps,
+                    '{metric}'::VARCHAR AS metric
+            """.format(**locals())
+        plpy.execute(sql)
+
+        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
+                     distance_table, core_points_table, reachable_points_table))
+
+
+def dbscan_predict(schema_madlib, dbscan_table, new_point, **kwargs):
+
+    with MinWarning("warning"):
+
+        dbscan_summary_table = add_postfix(dbscan_table, '_summary')
+        summary = plpy.execute("SELECT * FROM {0}".format(dbscan_summary_table))[0]
+
+        eps = summary['eps']
+        metric = summary['metric']
+        sql = """
+            SELECT cluster_id,
+                   {schema_madlib}.{metric}(__points__, ARRAY{new_point}) as dist
+            FROM {dbscan_table}
+            WHERE is_core_point = TRUE
+            ORDER BY dist LIMIT 1
+            """.format(**locals())
+        result = plpy.execute(sql)[0]
+        dist = result['dist']
+        if dist < eps:
+            return result['cluster_id']
+        else:
+            return None
+
+def _validate_dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm):
+
+    input_tbl_valid(source_table, 'dbscan')
+    output_tbl_valid(output_table, 'dbscan')
+    output_summary_table = add_postfix(output_table, '_summary')
+    output_tbl_valid(output_summary_table, 'dbscan')
+
+    cols_in_tbl_valid(source_table, [id_column], 'dbscan')
+
+    _assert(is_var_valid(source_table, expr_point),
+            "dbscan error: {0} is an invalid column name or "
+            "expression for expr_point param".format(expr_point))
+
+    point_col_type = get_expr_type(expr_point, source_table)
+    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
+            "dbscan Error: Feature column or expression '{0}' in train table is not"
+            " an array.".format(expr_point))

Review comment:
       Maybe this error message should say something like "... in train table should be a numeric array", so that it is accurate when someone accidentally passes a column with a non-numeric array type such as `TEXT[]`.

##########
File path: src/ports/postgres/modules/dbscan/dbscan.py_in
##########
@@ -0,0 +1,331 @@
+# coding=utf-8
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import plpy
+
+from utilities.control import MinWarning
+from utilities.utilities import _assert
+from utilities.utilities import unique_string
+from utilities.utilities import add_postfix
+from utilities.utilities import NUMERIC, ONLY_ARRAY
+from utilities.utilities import is_valid_psql_type
+from utilities.utilities import is_platform_pg
+from utilities.validate_args import input_tbl_valid, output_tbl_valid
+from utilities.validate_args import is_var_valid
+from utilities.validate_args import cols_in_tbl_valid
+from utilities.validate_args import get_expr_type
+from utilities.validate_args import get_algorithm_name
+from graph.wcc import wcc
+
+# Algorithm names that dbscan() validates its `algorithm` argument against
+# (via get_algorithm_name).
+# NOTE(review): dbscan() as written always performs the brute-force pairwise
+# distance computation, whichever of these is selected.
+BRUTE_FORCE = 'brute_force'
+KD_TREE = 'kd_tree'
+
+def dbscan(schema_madlib, source_table, output_table, id_column, expr_point, eps, min_samples, metric, algorithm, **kwargs):
+
+    with MinWarning("warning"):
+
+        min_samples = 5 if not min_samples else min_samples
+        metric = 'squared_dist_norm2' if not metric else metric
+        algorithm = 'brute' if not algorithm else algorithm
+
+        algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
+            [BRUTE_FORCE, KD_TREE], 'DBSCAN')
+
+        _validate_dbscan(schema_madlib, source_table, output_table, id_column,
+                         expr_point, eps, min_samples, metric, algorithm)
+
+        dist_src_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY (__src__)'
+        dist_id_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY ({0})'.format(id_column)
+        dist_reach_sql = ''  if is_platform_pg() else 'DISTRIBUTED BY (__reachable_id__)'
+
+        # Calculate pairwise distances
+        distance_table = unique_string(desp='distance_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(distance_table))
+
+        sql = """
+            CREATE TABLE {distance_table} AS
+            SELECT __src__, __dest__ FROM (
+                SELECT  __t1__.{id_column} AS __src__,
+                        __t2__.{id_column} AS __dest__,
+                        {schema_madlib}.{metric}(
+                            __t1__.{expr_point}, __t2__.{expr_point}) AS __dist__
+                FROM {source_table} AS __t1__, {source_table} AS __t2__) q1
+            WHERE __dist__ < {eps}
+            {dist_src_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Find core points
+        core_points_table = unique_string(desp='core_points_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_points_table))
+        sql = """
+            CREATE TABLE {core_points_table} AS
+            SELECT * FROM (SELECT __src__ AS {id_column}, count(*) AS __count__
+                           FROM {distance_table} GROUP BY __src__) q1
+            WHERE __count__ >= {min_samples}
+            {dist_id_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Find the connections between core points to form the clusters
+        core_edge_table = unique_string(desp='core_edge_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(core_edge_table))
+        sql = """
+            CREATE TABLE {core_edge_table} AS
+            SELECT __src__, __dest__
+            FROM {distance_table} AS __t1__, (SELECT array_agg({id_column}) AS arr
+                                              FROM {core_points_table}) __t2__
+            WHERE __t1__.__src__ = ANY(arr) AND __t1__.__dest__ = ANY(arr)
+            {dist_src_sql}
+        """.format(**locals())
+        plpy.execute(sql)
+
+        # Run wcc to get the min id for each cluster
+        wcc(schema_madlib, core_points_table, id_column, core_edge_table, 'src=__src__, dest=__dest__',
+            output_table, None)
+        plpy.execute("""
+            ALTER TABLE {0}
+            ADD COLUMN is_core_point BOOLEAN,
+            ADD COLUMN __points__ DOUBLE PRECISION[]
+            """.format(output_table))
+        plpy.execute("""
+            ALTER TABLE {0}
+            RENAME COLUMN component_id TO cluster_id
+            """.format(output_table))
+        plpy.execute("""
+            UPDATE {0}
+            SET is_core_point = TRUE
+        """.format(output_table))
+
+        # Find reachable points
+        reachable_points_table = unique_string(desp='reachable_points_table')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(reachable_points_table))
+        sql = """
+            CREATE TABLE {reachable_points_table} AS
+                SELECT array_agg(__src__) AS __src_list__,
+                       __dest__ AS __reachable_id__
+                FROM {distance_table} AS __t1__,
+                     (SELECT array_agg({id_column}) AS __arr__
+                      FROM {core_points_table}) __t2__
+                WHERE __src__ = ANY(__arr__) AND __dest__ != ALL(__arr__)
+                GROUP BY __dest__
+                {dist_reach_sql}
+            """.format(**locals())
+        plpy.execute(sql)
+
+        sql = """
+            INSERT INTO {output_table}
+            SELECT  __reachable_id__ as {id_column},
+                    cluster_id,
+                    FALSE AS is_core_point,
+                    NULL AS __points__
+            FROM {reachable_points_table} AS __t1__ INNER JOIN
+                 {output_table} AS __t2__
+                 ON (__src_list__[1] = {id_column})
+            """.format(**locals())
+        plpy.execute(sql)
+
+        # Add features of points to the output table to use them for prediction
+        sql = """
+            UPDATE {output_table} AS __t1__
+            SET __points__ = {expr_point}
+            FROM {source_table} AS __t2__
+            WHERE __t1__.{id_column} = __t2__.{id_column}
+        """.format(**locals())
+        plpy.execute(sql)
+
+        # Update the cluster ids to be consecutive
+        sql = """
+            UPDATE {output_table} AS __t1__
+            SET cluster_id = new_id-1
+            FROM (
+                SELECT cluster_id, row_number() OVER(ORDER BY cluster_id) AS new_id
+                FROM {output_table}
+                GROUP BY cluster_id) __t2__
+            WHERE __t1__.cluster_id = __t2__.cluster_id
+        """.format(**locals())
+        plpy.execute(sql)
+
+        output_summary_table = add_postfix(output_table, '_summary')
+        plpy.execute("DROP TABLE IF EXISTS {0}".format(output_summary_table))
+
+        sql = """
+            CREATE TABLE {output_summary_table} AS
+            SELECT  '{id_column}'::VARCHAR AS id_column,
+                    {eps}::DOUBLE PRECISION AS eps,
+                    '{metric}'::VARCHAR AS metric
+            """.format(**locals())
+        plpy.execute(sql)
+
+        plpy.execute("DROP TABLE IF EXISTS {0}, {1}, {2}".format(
+                     distance_table, core_points_table, reachable_points_table))

Review comment:
       Add `core_edge_table` to the list of tables to DROP; I noticed that it sticks around after running DBSCAN.
   Also, should we be using CREATE TEMP TABLE for these intermediate tables? (If so, I think they would be dropped automatically at the end of the session.)




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org