You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ok...@apache.org on 2019/09/19 17:24:14 UTC
[madlib] 01/02: Kmeans: Add simple silhouette score for every point
This is an automated email from the ASF dual-hosted git repository.
okislal pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/madlib.git
commit 7a130ec548f5e4f9e362e42de47cf04739cce64c
Author: Orhan Kislal <ok...@apache.org>
AuthorDate: Fri Sep 13 12:38:27 2019 -0400
Kmeans: Add simple silhouette score for every point
JIRA: MADLIB-1382
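This commit adds simple_silhouette_points, a function that calculates the
simple silhouette score for every input data point and writes the scores to
an output table. To support it, closest_columns now returns NULL for a point
whose input array contains a NULL value instead of failing.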
Closes #441
Co-authored-by: Domino Valdano <dv...@pivotal.io>
---
src/modules/linalg/metric.cpp | 55 +++++++-----
src/ports/postgres/modules/kmeans/kmeans.py_in | 90 +++++++++++++++++--
src/ports/postgres/modules/kmeans/kmeans.sql_in | 53 +++++++++++
.../postgres/modules/kmeans/test/kmeans.sql_in | 100 ++++++++++++++++++---
src/ports/postgres/modules/linalg/linalg.sql_in | 2 +-
.../postgres/modules/utilities/utilities.py_in | 2 +-
6 files changed, 259 insertions(+), 43 deletions(-)
diff --git a/src/modules/linalg/metric.cpp b/src/modules/linalg/metric.cpp
index e3835ee..4809762 100644
--- a/src/modules/linalg/metric.cpp
+++ b/src/modules/linalg/metric.cpp
@@ -365,31 +365,40 @@ closest_column::run(AnyType& args) {
* This function calls a user-supplied function, for which it does not do
* garbage collection. It is therefore meant to be called only constantly many
* times before control is returned to the backend.
- */
+ */
AnyType
closest_columns::run(AnyType& args) {
- MappedMatrix M = args[0].getAs<MappedMatrix>();
- MappedColumnVector x = args[1].getAs<MappedColumnVector>();
- uint32_t num = args[2].getAs<uint32_t>();
- FunctionHandle dist = args[3].getAs<FunctionHandle>()
- .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall);
- string dist_fname = args[4].getAs<char *>();
-
- std::string fname = dist_fn_name(dist_fname);
-
- std::vector<std::tuple<Index, double> > result(num);
- closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(),
- result.end());
-
- MutableArrayHandle<int32_t> indices = allocateArray<int32_t,
- dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
- MutableArrayHandle<double> distances = allocateArray<double,
- dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
- for (uint32_t i = 0; i < num; ++i)
- std::tie(indices[i], distances[i]) = result[i];
-
- AnyType tuple;
- return tuple << indices << distances;
+
+ /* If the input has a null value, we want to return nothing for that
+ * particular data point (because we cannot calculate the distance)
+ * instead of failing.
+ */
+ try{
+ MappedMatrix M = args[0].getAs<MappedMatrix>();
+ MappedColumnVector x = args[1].getAs<MappedColumnVector>();
+ uint32_t num = args[2].getAs<uint32_t>();
+ FunctionHandle dist = args[3].getAs<FunctionHandle>()
+ .unsetFunctionCallOptions(FunctionHandle::GarbageCollectionAfterCall);
+ string dist_fname = args[4].getAs<char *>();
+
+ std::string fname = dist_fn_name(dist_fname);
+
+ std::vector<std::tuple<Index, double> > result(num);
+ closestColumnsAndDistancesShortcut(M, x, dist, fname, result.begin(),
+ result.end());
+
+ MutableArrayHandle<int32_t> indices = allocateArray<int32_t,
+ dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
+ MutableArrayHandle<double> distances = allocateArray<double,
+ dbal::FunctionContext, dbal::DoNotZero, dbal::ThrowBadAlloc>(num);
+ for (uint32_t i = 0; i < num; ++i)
+ std::tie(indices[i], distances[i]) = result[i];
+
+ AnyType tuple;
+ return tuple << indices << distances;
+ }catch (const ArrayWithNullException &e) {
+ return Null();
+ }
}
AnyType
diff --git a/src/ports/postgres/modules/kmeans/kmeans.py_in b/src/ports/postgres/modules/kmeans/kmeans.py_in
index 628b690..30e4005 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.py_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.py_in
@@ -15,11 +15,15 @@ import plpy
import re
from utilities.control import IterationController2D
+from utilities.control import MinWarning
from utilities.control_composite import IterationControllerComposite
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_expr_type
+from utilities.validate_args import input_tbl_valid
+from utilities.validate_args import output_tbl_valid
+from utilities.utilities import _assert
from utilities.utilities import unique_string
HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
@@ -224,7 +228,21 @@ def compute_kmeans_random_seeding(schema_madlib, rel_args, rel_state,
m = it.evaluate("_args.k - coalesce(array_upper({0}, 1), 0)".format(state_str))
return iterationCtrl.iteration
# ------------------------------------------------------------------------------
+def _create_temp_view_for_expr(schema_madlib, rel_source, expr_point):
+ """
+ Create a temporary view to evaluate the expr_point.
+ """
+
+ if kmeans_validate_expr(schema_madlib, rel_source, expr_point):
+ view_name = unique_string('km_view')
+
+ plpy.execute(""" CREATE TEMP VIEW {view_name} AS
+ SELECT {expr_point} AS expr FROM {rel_source}
+ """.format(**locals()))
+ rel_source = view_name
+ expr_point = 'expr'
+ return rel_source,expr_point
def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
expr_point, agg_centroid, **kwargs):
@@ -246,14 +264,9 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
result in \c rel_state
"""
- if kmeans_validate_expr(schema_madlib, rel_source, expr_point):
- view_name = unique_string('km_view')
-
- plpy.execute(""" CREATE TEMP VIEW {view_name} AS
- SELECT {expr_point} AS expr FROM {rel_source}
- """.format(**locals()))
- rel_source = view_name
- expr_point = 'expr'
+ rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
+ rel_source,
+ expr_point)
fn_dist_name = plpy.execute("SELECT fn_dist_name FROM " +
rel_args)[0]['fn_dist_name']
@@ -387,5 +400,66 @@ def compute_kmeans(schema_madlib, rel_args, rel_state, rel_source,
'old_centroid': old_centroid_str}))
return iterationCtrl.iteration
+def simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+ expr_point, centroids, fn_dist, **kwargs):
+
+ """
+ Calculate the simple silhouette score for every data point.
+ """
+
+ with MinWarning("error"):
+ kmeans_validate_src(schema_madlib, rel_source)
+ output_tbl_valid(output_table, 'kmeans')
+
+ _assert(type(centroids) == list and
+ type(centroids[0]) == list and
+ len(centroids) > 1,
+ 'kmeans: Invalid centroids shape. Centroids have to be a 2D numeric array.')
+
+ rel_source, expr_point = _create_temp_view_for_expr(schema_madlib,
+ rel_source,
+ expr_point)
+
+ plpy.execute("""
+ CREATE TABLE {output_table} AS
+ SELECT {pid}, centroids[1] AS centroid_id,
+ centroids[2] AS neighbor_centroid_id,
+ (CASE
+ WHEN distances[2] = 0 THEN 0
+ ELSE (distances[2] - distances[1]) / distances[2]
+ END) AS silh
+ FROM
+ (SELECT {pid},
+ (cc_out).column_ids::integer[] AS centroids,
+ (cc_out).distances::double precision[] AS distances
+ FROM (
+ SELECT {pid},
+ {schema_madlib}._closest_columns(
+ array{centroids},
+ {expr_point},
+ 2,
+ '{fn_dist}'::REGPROC, '{fn_dist}') AS cc_out
+ FROM {rel_source})q1
+ )q2
+ """.format(**locals()))
+
+def simple_silhouette_points_dbl_wrapper(schema_madlib, rel_source, output_table, pid,
+ expr_point, centroids, fn_dist, **kwargs):
+
+ simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+ expr_point, centroids, fn_dist)
+
+
+def simple_silhouette_points_str_wrapper(schema_madlib, rel_source, output_table, pid,
+ expr_point, centroids_table, centroids_col, fn_dist, **kwargs):
+
+ input_tbl_valid(centroids_table, 'kmeans')
+ columns_exist_in_table(centroids_table, centroids_col)
+ centroids = plpy.execute("""
+ SELECT {centroids_col} AS centroids FROM {centroids_table}
+ """.format(**locals()))[0]['centroids']
+
+ simple_silhouette_points(schema_madlib, rel_source, output_table, pid,
+ expr_point, centroids, fn_dist, **kwargs)
m4_changequote(<!`!>, <!'!>)
diff --git a/src/ports/postgres/modules/kmeans/kmeans.sql_in b/src/ports/postgres/modules/kmeans/kmeans.sql_in
index 1eae525..e354e05 100644
--- a/src/ports/postgres/modules/kmeans/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/kmeans.sql_in
@@ -1906,3 +1906,56 @@ CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.kmeans_random_auto(
SELECT MADLIB_SCHEMA.kmeans_random_auto($1, $2, $3, $4, NULL, NULL, NULL, NULL, NULL)
$$ LANGUAGE sql VOLATILE
m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `CONTAINS SQL', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ pid VARCHAR,
+ expr_point VARCHAR,
+ centroids_table VARCHAR,
+ centroids_col VARCHAR,
+ fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */
+) RETURNS VOID AS $$
+ PythonFunction(kmeans, kmeans, simple_silhouette_points_str_wrapper)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ pid VARCHAR,
+ expr_point VARCHAR,
+ centroids_table VARCHAR,
+ centroids_col VARCHAR
+) RETURNS VOID
+AS $$
+ SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5, $6,
+ 'MADLIB_SCHEMA.dist_norm2')
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ pid VARCHAR,
+ expr_point VARCHAR,
+ centroids DOUBLE PRECISION[],
+ fn_dist VARCHAR /*+ DEFAULT 'dist_norm2' */
+) RETURNS VOID AS $$
+ PythonFunction(kmeans, kmeans, simple_silhouette_points_dbl_wrapper)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.simple_silhouette_points(
+ rel_source VARCHAR,
+ output_table VARCHAR,
+ pid VARCHAR,
+ expr_point VARCHAR,
+ centroids DOUBLE PRECISION[]
+) RETURNS VOID
+AS $$
+ SELECT MADLIB_SCHEMA.simple_silhouette_points($1, $2, $3, $4, $5,
+ 'MADLIB_SCHEMA.dist_norm2')
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`\_\_HAS_FUNCTION_PROPERTIES\_\_', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
index 4553b6c..b0e5024 100644
--- a/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
+++ b/src/ports/postgres/modules/kmeans/test/kmeans.sql_in
@@ -64,6 +64,10 @@ SELECT * FROM kmeans('kmeans_2d', 'position', ARRAY[
SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm1');
SELECT * FROM kmeans('kmeans_2d', 'position', 'centroids', 'position', 'MADLIB_SCHEMA.dist_norm2');
+SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]');
+SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
+SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
+
DROP TABLE IF EXISTS km_sample;
CREATE TABLE km_sample(pid int, points double precision[]);
@@ -81,16 +85,6 @@ COPY km_sample (pid, points) FROM stdin DELIMITER '|';
10 | {13.86, 1.35, 2.27, 16, 98, 2.98, 3.15, 0.22, 1.8500, 7.2199, 1.01, NULL, 1045}
\.
-
-SELECT * FROM kmeanspp('km_sample', 'points', 2,
- 'MADLIB_SCHEMA.squared_dist_norm2',
- 'MADLIB_SCHEMA.avg', 20, 0.001);
-
-
-SELECT * FROM kmeans('kmeans_2d', 'array[x,y]', 'centroids', 'array[x,y]');
-SELECT * FROM kmeanspp('kmeans_2d', 'array[x,y]', 10);
-SELECT * FROM kmeans_random('kmeans_2d', 'arRAy [ x,y]', 10);
-
-- Test kmeanspp_auto
DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
SELECT * FROM kmeanspp_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[2,3,4,5,6,7,8], 'MADLIB_SCHEMA.squared_dist_norm2',
@@ -163,6 +157,43 @@ DROP TABLE IF EXISTS autokm_out,autokm_out_summary;
SELECT * FROM kmeans_random_auto('kmeans_2d', 'autokm_out', 'array[x,y]', ARRAY[12,3,5,6,8], 'MADLIB_SCHEMA.squared_dist_norm2',
'MADLIB_SCHEMA.avg', 20, 0.001, 'silhouette');
+-- Silhouette Tests
+DROP TABLE IF EXISTS km_sample_out, silh_out;
+
+CREATE TABLE km_sample_out AS
+SELECT * FROM kmeanspp('km_sample', 'points', 2,
+ 'MADLIB_SCHEMA.squared_dist_norm2',
+ 'MADLIB_SCHEMA.avg', 20, 0.001);
+
+-- Test simple_silhouette_points full interface
+SELECT * FROM simple_silhouette_points('km_sample', 'silh_out', 'pid', 'points',
+ 'km_sample_out', 'centroids',
+ 'MADLIB_SCHEMA.squared_dist_norm2');
+
+SELECT assert(silh > 0, 'Incorrect silhouette value')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points default distance func
+SELECT * FROM simple_silhouette_points(
+ 'km_sample', 'silh_out', 'pid', 'points',
+ 'km_sample_out', 'centroids');
+
+SELECT assert(count(*) = 9, 'Incorrect silhouette count')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points double precision array centroids
+SELECT * FROM simple_silhouette_points(
+ 'km_sample', 'silh_out', 'pid', 'points',
+ (SELECT centroids FROM km_sample_out));
+
+SELECT assert(silh > 0, 'Incorrect silhouette value')
+FROM silh_out
+WHERE silh IS NOT NULL;
+
SELECT assert(
silhouette > 0 AND objective_fn > 0,
'Kmeans: Auto Kmeans_random failed for silhouette on unordered k vals')
@@ -206,6 +237,55 @@ SELECT assert(
'Kmeans: Auto Kmeans_random failed for both.')
FROM autokm_out_summary;
+DROP TABLE IF EXISTS silh_out;
+-- Test simple_silhouette_points actual values
+SELECT * FROM simple_silhouette_points(
+ 'km_sample', 'silh_out', 'pid', 'points',
+ ARRAY[[1,1,1,1,1,1,1,1,1,1,1,1,1],
+ [14.23, 1.71, 2.43, 15.6, 127, 2.8, 3.0600, 0.2800, 2.29, 5.64, 1.04, 3.92, 1065]]::DOUBLE PRECISION[][]);
+
+SELECT assert(relative_error(1, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 1;
+SELECT assert(relative_error(0.8789, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 2;
+SELECT assert(relative_error(0.8966, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 3;
+SELECT assert(relative_error(0.7200, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 4;
+SELECT assert(relative_error(0.5560, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 5;
+SELECT assert(relative_error(0.7348, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 6;
+SELECT assert(relative_error(0.8242, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 7;
+SELECT assert(relative_error(0.8229, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 8;
+SELECT assert(relative_error(0.9655, silh) < 1e-3,
+ 'Incorrect silhouette value')
+FROM silh_out
+WHERE pid = 9;
+
+SELECT assert(centroid_id = 1 AND neighbor_centroid_id = 0,
+ 'Incorrect centroid ids')
+FROM silh_out;
+
+DROP TABLE IF EXISTS km_sample_out, silh_out;
DROP TABLE IF EXISTS km_sample CASCADE;
DROP TABLE IF EXISTS centroids CASCADE;
DROP TABLE IF EXISTS kmeans_2d CASCADE;
diff --git a/src/ports/postgres/modules/linalg/linalg.sql_in b/src/ports/postgres/modules/linalg/linalg.sql_in
index b10cef7..3c74451 100644
--- a/src/ports/postgres/modules/linalg/linalg.sql_in
+++ b/src/ports/postgres/modules/linalg/linalg.sql_in
@@ -428,7 +428,7 @@ LANGUAGE C IMMUTABLE STRICT
m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!NO SQL!>, <!!>);
--- Because of Jiara MPP-23166, ORCA makes the following
+-- Because of Jira MPP-23166, ORCA makes the following
-- function extremely slow because "NO SQL" now becomes
-- "CONTAINS SQL". This is why we disabled the optimizer
-- in kmeans.
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 4e142aa..8f5b2ff 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -209,7 +209,7 @@ def add_postfix(quoted_string, postfix):
NUMERIC = set(['smallint', 'integer', 'bigint', 'decimal', 'numeric',
- 'real', 'double precision', 'serial', 'bigserial'])
+ 'real', 'double precision', 'float', 'serial', 'bigserial'])
INTEGER = set(['smallint', 'integer', 'bigint'])
TEXT = set(['text', 'varchar', 'character varying', 'char', 'character'])
BOOLEAN = set(['boolean'])