You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@madlib.apache.org by ok...@apache.org on 2019/09/19 17:24:14 UTC

[madlib] 01/02: Kmeans: Add simple silhouette score for every point

This is an automated email from the ASF dual-hosted git repository.

okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git

commit 7a130ec548f5e4f9e362e42de47cf04739cce64c
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Fri Sep 13 12:38:27 2019 -0400

    Kmeans: Add simple silhouette score for every point
    
    JIRA: MADLIB-1382
    
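    This commit adds a function, simple_silhouette_points, to calculate the
    simple silhouette score for every input data point. It also makes
    closest_columns return NULL, rather than fail, when the input point
    contains a NULL value, so such points get a NULL silhouette score.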
    
    Closes #441
    
    Co-authored-by: Domino Valdano <dv...@pivotal.io>
---
 src/modules/linalg/metric.cpp                      |  55 +++++++-----
 src/ports/postgres/modules/kmeans/kmeans.py_in     |  90 +++++++++++++++++--
 src/ports/postgres/modules/kmeans/kmeans.sql_in    |  53 +++++++++++
 .../postgres/modules/kmeans/test/kmeans.sql_in     | 100 ++++++++++++++++++---
 src/ports/postgres/modules/linalg/linalg.sql_in    |   2 +-
 .../postgres/modules/utilities/utilities.py_in     |   2 +-
 6 files changed, 259 insertions(+), 43 deletions(-)

diff --git a/src/modules/linalg/metric.cpp b/src/modules/linalg/metric.cpp
index e3835ee..4809762 100644
--- a/src/modules/linalg/metric.cpp
+++ b/src/modules/linalg/metric.cpp
@@ -365,31 +365,40 @@ closest_column::run(AnyType& args) {
  * This function calls a user-supplied function, for which it does not do
  * garbage collection. It is therefore meant to be called only constantly many
  * times before control is returned to the backend.
- */
+  */
 AnyType
 closest_columns::run(AnyType& args) {
-    MappedMatrix M = args[0].getAs<MappedMatrix>();
-    MappedColumnVector x = args[1].getAs<MappedColumnVector>();
-    uint32_t num = args[2].getAs<uint32_t>();
-    FunctionHandle dist = args[3].getAs<FunctionHandle>()
-        .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall);
-    string dist_fname = args[4].getAs<char *>();
-
-    std::string fname = dist_fn_name(dist_fname);
-
-    std::vector<std::tuple<Index, double> > result(num);
-    closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(),
-        result.end());
-
-    MutableArrayHandle<int32_t> indices = allocateArray<int32_t,
-        dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
-    MutableArrayHandle<double> distances = allocateArray<double,
-        dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
-    for (uint32_t i = 0; i < num; ++i)
-        std::tie(indices[i], distances[i]) = result[i];
-
-    AnyType tuple;
-    return tuple << indices << distances;
+
+    /* If the input has a null value, we want to return nothing for that
+    *  particular data point (because we cannot calculate the distance)
+    *  instead of failing.
+    */
+    try{
+        MappedMatrix M = args[0].getAs<MappedMatrix>();
+        MappedColumnVector x = args[1].getAs<MappedColumnVector>();
+        uint32_t num = args[2].getAs<uint32_t>();
+        FunctionHandle dist = args[3].getAs<FunctionHandle>()
+            .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall);
+        string dist_fname = args[4].getAs<char *>();
+
+        std::string fname = dist_fn_name(dist_fname);
+
+        std::vector<std::tuple<Index, double> > result(num);
+        closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(),
+            result.end());
+
+        MutableArrayHandle<int32_t> indices = allocateArray<int32_t,
+            dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
+        MutableArrayHandle<double> distances = allocateArray<double,
+            dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
+        for (uint32_t i = 0; i < num; ++i)
+            std::tie(indices[i], distances[i]) = result[i];
+
+        AnyType tuple;
+        return tuple << indices << distances;
+    }catch (const ArrayWithNullException &e) {
+        return Null();
+    }
 }
 
 AnyType
diff --git a/src/ports/postgres/modules/kmeans/kmeans.py_in b/src/ports/postgres/modules/kmeans/kmeans.py_in
index 628b690..30e4005 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.py_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.py_in
@@ -15,11 +15,15 @@ import plpy
 import re
 
 from utilities.control import IterationController2D
+from utilities.control import MinWarning
 from utilities.control_composite import IterationControllerComposite
 from utilities.validate_args import table_exists
 from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import table_is_empty
 from utilities.validate_args import get_expr_type
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
+from utilities.utilities import _assert
 from utilities.utilities import unique_string
 
 HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
@@ -224,7 +228,21 @@ def compute_kmeans_random_seeding(schema_madlib, rel_args, rel_state,
             m = it.evaluate("_args.k - coalesce(array_upper({0}, 1), 0)".format(state_str))
     return iterationCtrl.iteration
 # ------------------------------------------------------------------------------
+def _create_temp_view_for_expr(schema_madlib, rel_source, expr_point):
+    """
+    Create a temporary view to evaluate the expr_point.
+    """
+
+    if kmeans_validate_expr(schema_madlib, rel_source, expr_point):
+        view_name = unique_string('km_view')
+
+        plpy.execute(""" CREATE TEMP VIEW {view_name} AS
+            SELECT {expr_point} AS expr FROM {rel_source}
+            """.format(**locals()))
+        rel_source = view_name
+        expr_point = 'expr'
 
+    return rel_source,expr_point
 
 def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
                    expr_point, agg_centroid, **kwargs):
@@ -246,14 +264,9 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
         result in \c rel_state
     """
 
-    if kmeans_validate_expr(schema_madlib, rel_source, expr_point):
-        view_name = unique_string('km_view')
-
-        plpy.execute(""" CREATE TEMP VIEW {view_name} AS
-            SELECT {expr_point} AS expr FROM {rel_source}
-            """.format(**locals()))
-        rel_source = view_name
-        expr_point = 'expr'
+    rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
+                                                        rel_source,
+                                                        expr_point)
 
     fn_dist_name = plpy.execute("SELECT fn_dist_name FROM " +
                                 rel_args)[0]['fn_dist_name']
@@ -387,5 +400,66 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
                             'old_centroid': old_centroid_str}))
     return iterationCtrl.iteration
 
+def simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+    expr_point, centroids, fn_dist, **kwargs):
+
+    """
+    Calculate the simple silhouette score for every data point.
+    """
+
+    with MinWarning("error"):
+        kmeans_validate_src(schema_madlib, rel_source)
+        output_tbl_valid(output_table, 'kmeans')
+
+        _assert(type(centroids) == list and
+                type(centroids[0]) == list and
+                len(centroids) > 1,
+                'kmeans: Invalid centroids shape. Centroids have to be a 2D numeric array.')
+
+        rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
+                                                            rel_source,
+                                                            expr_point)
+
+        plpy.execute("""
+            CREATE TABLE {output_table} AS
+                SELECT {pid}, centroids[1] AS centroid_id,
+                centroids[2] AS neighbor_centroid_id,
+                (CASE
+                    WHEN distances[2] = 0 THEN 0
+                    ELSE (distances[2] - distances[1]) / distances[2]
+                END) AS silh
+                FROM
+                (SELECT {pid},
+                       (cc_out).column_ids::integer[] AS centroids,
+                       (cc_out).distances::double precision[] AS distances
+                FROM (
+                    SELECT {pid},
+                           {schema_madlib}._closest_columns(
+                            array{centroids},
+                            {expr_point},
+                            2,
+                            '{fn_dist}'::REGPROC, '{fn_dist}') AS cc_out
+                    FROM {rel_source})q1
+                )q2
+            """.format(**locals()))
+
+def simple_silhouette_points_dbl_wrapper(schema_madlib, rel_source, output_table, pid,
+    expr_point, centroids, fn_dist, **kwargs):
+
+    simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+        expr_point, centroids, fn_dist)
+
+
+def simple_silhouette_points_str_wrapper(schema_madlib, rel_source, output_table, pid,
+    expr_point, centroids_table, centroids_col, fn_dist, **kwargs):
+
+    input_tbl_valid(centroids_table, 'kmeans')
+    columns_exist_in_table(centroids_table, centroids_col)
+    centroids = plpy.execute("""
+        SELECT {centroids_col} AS centroids FROM {centroids_table}
+        """.format(**locals()))[0]['centroids']
+
+    simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+        expr_point, centroids, fn_dist, **kwargs)
 
 m4_changequote(<!`!>, <!'!>)
diff --git a/src/ports/postgres/modules/kmeans/kmeans.sql_in b/src/ports/postgres/modules/kmeans/kmeans.sql_in
index 1eae525..e354e05 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.sql_in
@@ -1906,3 +1906,56 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
     SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL)
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `CONTAINS SQL', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    pid VARCHAR,
+    expr_point VARCHAR,
+    centroids_table VARCHAR,
+    centroids_col VARCHAR,
+    fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */
+) RETURNS VOID AS $$
+    PythonFunction(kmeans, kmeans, simple_silhouette_points_str_wrapper)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    pid VARCHAR,
+    expr_point VARCHAR,
+    centroids_table VARCHAR,
+    centroids_col VARCHAR
+) RETURNS VOID
+AS $$
+    SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5, $6,
+        'MADLIB_SCHEMA.dist_norm2')
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    pid VARCHAR,
+    expr_point VARCHAR,
+    centroids DOUBLE PRECISION[],
+    fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */
+) RETURNS VOID AS $$
+    PythonFunction(kmeans, kmeans, simple_silhouette_points_dbl_wrapper)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+    rel_source VARCHAR,
+    output_table VARCHAR,
+    pid VARCHAR,
+    expr_point VARCHAR,
+    centroids DOUBLE PRECISION[]
+) RETURNS VOID
+AS $$
+    SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5,
+        'MADLIB_SCHEMA.dist_norm2')
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
index 4553b6c..b0e5024 100644
--- a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
@@ -64,6 +64,10 @@ SELECT * FROM kmeans('kmeans_2d', 'position', ARRAY[
 SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm1');
 SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm2');
 
+SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]');
+SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
+SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
+
 DROP TABLE IF EXISTS km_sample;
 
 CREATE TABLE km_sample(pid int, points double precision[]);
@@ -81,16 +85,6 @@ COPY km_sample (pid, points) FROM stdin DELIMITER '|';
 10 | {13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, NULL, 1045}
 \.
 
-
-SELECT * FROM kmeanspp('km_sample', 'points', 2,
-                       'MADLIB_SCHEMA.squared_dist_norm2',
-                       'MADLIB_SCHEMA.avg', 20, 0.001);
-
-
-SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]');
-SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
-SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
-
 -- Test kmeanspp_auto
 DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
 SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
@@ -163,6 +157,43 @@ DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
 SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[12,3,5,6,8], 'MADLIB_SCHEMA.squared_dist_norm2',
                        'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouette');
 
+-- Silhouette Tests
+DROP TABLE IF EXISTS km_sample_out, silh_out;
+
+CREATE TABLE km_sample_out AS
+SELECT * FROM kmeanspp('km_sample', 'points', 2,
+                       'MADLIB_SCHEMA.squared_dist_norm2',
+                       'MADLIB_SCHEMA.avg', 20, 0.001);
+
+-- Test simple_silhouette_points full interface
+SELECT * FROM simple_silhouette_points('km_sample', 'silh_out', 'pid', 'points',
+                                'km_sample_out', 'centroids',
+                                'MADLIB_SCHEMA.squared_dist_norm2');
+
+SELECT assert(silh > 0, 'Incorrect silhouette value')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points default distance func
+SELECT * FROM simple_silhouette_points(
+    'km_sample', 'silh_out', 'pid', 'points',
+    'km_sample_out', 'centroids');
+
+SELECT assert(count(*) = 9, 'Incorrect silhouette count')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points double precision array centroids
+SELECT * FROM simple_silhouette_points(
+    'km_sample', 'silh_out', 'pid', 'points',
+    (SELECT centroids FROM km_sample_out));
+
+SELECT assert(silh > 0, 'Incorrect silhouette value')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
 SELECT assert(
         silhouette > 0 AND objective_fn > 0,
         'Kmeans: Auto Kmeans_random failed for silhouette on unordered k vals')
@@ -206,6 +237,55 @@ SELECT assert(
         'Kmeans: Auto Kmeans_random failed for both.')
 FROM autokm_out_summary;
 
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points actual values
+SELECT * FROM simple_silhouette_points(
+    'km_sample', 'silh_out', 'pid', 'points',
+    ARRAY[[1,1,1,1,1,1,1,1,1,1,1,1,1],
+    [14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065]]::DOUBLE PRECISION[][]);
+
+SELECT assert(relative_error(1, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 1;
+SELECT assert(relative_error(0.8789, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 2;
+SELECT assert(relative_error(0.8966, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 3;
+SELECT assert(relative_error(0.7200, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 4;
+SELECT assert(relative_error(0.5560, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 5;
+SELECT assert(relative_error(0.7348, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 6;
+SELECT assert(relative_error(0.8242, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 7;
+SELECT assert(relative_error(0.8229, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 8;
+SELECT assert(relative_error(0.9655, silh) < 1e-3,
+    'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 9;
+
+SELECT assert(centroid_id = 1 AND neighbor_centroid_id = 0,
+    'Incorrect centroid ids')
+FROM silh_out;
+
+DROP TABLE IF EXISTS km_sample_out, silh_out;
 DROP TABLE IF EXISTS km_sample CASCADE;
 DROP TABLE IF EXISTS centroids CASCADE;
 DROP TABLE IF EXISTS kmeans_2d CASCADE;
diff --git a/src/ports/postgres/modules/linalg/linalg.sql_in b/src/ports/postgres/modules/linalg/linalg.sql_in
index b10cef7..3c74451 100644
--- a/src/ports/postgres/modules/linalg/linalg.sql_in
+++ b/src/ports/postgres/modules/linalg/linalg.sql_in
@@ -428,7 +428,7 @@ LANGUAGE C IMMUTABLE STRICT
 m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!NO SQL!>, <!!>);
 
 
--- Because of Jiara MPP-23166, ORCA makes the following
+-- Because of Jira MPP-23166, ORCA makes the following
 -- function extremely slow because "NO SQL" now becomes
 -- "CONTAINS SQL". This is why we disabled the optimizer
 -- in kmeans.
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 4e142aa..8f5b2ff 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -209,7 +209,7 @@ def add_postfix(quoted_string, postfix):
 
 
 NUMERIC = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
-               'real', 'double precision', 'serial', 'bigserial'])
+               'real', 'double precision', 'float', 'serial', 'bigserial'])
 INTEGER = set(['smallint', 'integer', 'bigint'])
 TEXT = set(['text', 'varchar', 'character varying', 'char', 'character'])
 BOOLEAN = set(['boolean'])