You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@madlib.apache.org by ri...@apache.org on 2015/12/30 23:48:34 UTC

incubator-madlib git commit: Covariance: Add covariance matrix function (Pearson)

Repository: incubator-madlib
Updated Branches:
  refs/heads/master 33ed578b9 -> ba5fc1ead


Covariance: Add covariance matrix function (Pearson)

JIRA: MADLIB-941

Added new function covariance() which returns the covariance matrix. The
implementation is an update to the Pearson's correlation method, where
the scaling at the final step is avoided.


Project: http://git-wip-us.apache.org/repos/asf/incubator-madlib/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-madlib/commit/ba5fc1ea
Tree: http://git-wip-us.apache.org/repos/asf/incubator-madlib/tree/ba5fc1ea
Diff: http://git-wip-us.apache.org/repos/asf/incubator-madlib/diff/ba5fc1ea

Branch: refs/heads/master
Commit: ba5fc1ead162f71a37badcc772f410dbcd4322b2
Parents: 33ed578
Author: Rahul Iyer <ri...@pivotal.io>
Authored: Wed Dec 30 14:46:22 2015 -0800
Committer: Rahul Iyer <ri...@pivotal.io>
Committed: Wed Dec 30 14:46:22 2015 -0800

----------------------------------------------------------------------
 src/modules/stats/correlation.cpp               |   6 +-
 .../postgres/modules/stats/correlation.py_in    | 271 +++++++++----------
 .../postgres/modules/stats/correlation.sql_in   | 117 ++++++++
 .../modules/utilities/validate_args.py_in       |  31 ++-
 4 files changed, 269 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/ba5fc1ea/src/modules/stats/correlation.cpp
----------------------------------------------------------------------
diff --git a/src/modules/stats/correlation.cpp b/src/modules/stats/correlation.cpp
index b0fd858..e300829 100644
--- a/src/modules/stats/correlation.cpp
+++ b/src/modules/stats/correlation.cpp
@@ -20,7 +20,7 @@ using namespace dbal::eigen_integration;
 
 AnyType
 correlation_transition::run(AnyType& args) {
-    // args[2]
+    // args[2] is the mean of features vector
     if (args[2].isNull()) {
         throw std::runtime_error("Correlation: Mean vector is NULL.");
     }
@@ -31,7 +31,7 @@ correlation_transition::run(AnyType& args) {
     } catch (const ArrayWithNullException &e) {
         throw std::runtime_error("Correlation: Mean vector contains NULL.");
     }
-    // args[0]
+    // args[0] is the covariance matrix
     MutableNativeMatrix state;
     if (args[0].isNull()) {
         state.rebind(this->allocateArray<double>(mean.size(), mean.size()),
@@ -39,7 +39,7 @@ correlation_transition::run(AnyType& args) {
     } else {
         state.rebind(args[0].getAs<MutableArrayHandle<double> >());
     }
-    // args[1]
+    // args[1] is the current data vector
     if (args[1].isNull()) { return state; }
     MappedColumnVector x;
     try {

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/ba5fc1ea/src/ports/postgres/modules/stats/correlation.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/stats/correlation.py_in b/src/ports/postgres/modules/stats/correlation.py_in
index 6ad07ce..6bc1b5a 100644
--- a/src/ports/postgres/modules/stats/correlation.py_in
+++ b/src/ports/postgres/modules/stats/correlation.py_in
@@ -8,9 +8,13 @@
 import plpy
 from time import time
 from utilities.utilities import unique_string, add_postfix
+from utilities.validate_args import get_cols_and_types
+from utilities.control import MinWarning
+
 
 def correlation(schema_madlib, source_table, output_table,
-                target_cols, verbose, **kwargs):
+                target_cols, get_cov=False, verbose=False,
+                **kwargs):
     """
     Populates an output table with the coefficients of correlation between
     the columns in a source table
@@ -19,6 +23,8 @@ def correlation(schema_madlib, source_table, output_table,
         @param source_table   Name of input table
         @param output_table   Name of output table
         @param target_cols    Name of specific columns targetted for correlation
+        @param get_cov       If False return the correlation matrix else
+                                return the covariance matrix
         @param verbose        Flag to determine whether to output debug info
 
     Returns:
@@ -34,11 +40,11 @@ def correlation(schema_madlib, source_table, output_table,
         # column names
         _existing_target_cols = list(set(column for column in _target_cols
                                          if column in _numeric_column_names))
-        _nonexisting_target_cols = list(set(_target_cols)
-                                        - set(_existing_target_cols)
-                                        - set(_nonnumeric_column_names))
-        _nonnumeric_target_cols = list(set(_target_cols)
-                                       & set(_nonnumeric_column_names))
+        _nonexisting_target_cols = list(set(_target_cols) -
+                                        set(_existing_target_cols) -
+                                        set(_nonnumeric_column_names))
+        _nonnumeric_target_cols = list(set(_target_cols) &
+                                       set(_nonnumeric_column_names))
     else:
         # if target_cols not provided then all numeric columns are to be targeted
         _existing_target_cols = _numeric_column_names[:]
@@ -64,7 +70,7 @@ def correlation(schema_madlib, source_table, output_table,
     # ---- Output message ----
 
     return _populate_output_table(schema_madlib, source_table, output_table,
-                                  _existing_target_cols, verbose)
+                                  _existing_target_cols, get_cov, verbose)
 
 
 # -----------------------------------------------------------------------
@@ -102,7 +108,7 @@ def _validate_corr_arg(source_table, output_table):
 
 
 # -----------------------------------------------------------------------
-# Get all column names in source table
+# Get all numeric column names in source table
 # -----------------------------------------------------------------------
 def _get_numeric_columns(source_table):
     """
@@ -114,42 +120,18 @@ def _get_numeric_columns(source_table):
     Returns:
         List of column names in table
     """
-    # determine the exact table_schema and table_name
-    # in case that source_table only contains table_name
-    row = plpy.execute("""
-                        SELECT
-                            quote_ident(nspname) AS table_schema,
-                            quote_ident(relname) AS table_name
-                        FROM
-                            pg_class AS c,
-                            pg_namespace AS nsp
-                        WHERE
-                            c.oid = '{source_table}'::regclass::oid AND
-                            c.relnamespace = nsp.oid
-                        """.format(source_table=source_table))
-    table_schema = row[0]['table_schema']
-    table_name = row[0]['table_name']
 
     # retrieve the numeric columns
     numeric_types = ('smallint', 'integer', 'bigint',
                      'real', 'numeric', 'double precision')
-    all_columns = plpy.execute("""
-                                SELECT quote_ident(column_name) as column_name,
-                                       data_type
-                                FROM
-                                    information_schema.columns
-                                WHERE
-                                    quote_ident(table_schema) = '{table_schema}' AND
-                                    quote_ident(table_name) = '{table_name}'
-                                ORDER BY ordinal_position
-                                """.format(table_schema=table_schema,
-                                           table_name=table_name))
-    all_col_names = set(column['column_name'] for column in all_columns)
-    num_col_names = [column['column_name'] for column in all_columns
-                     if column['data_type'] in numeric_types]
-    nonnum_col_names = list(all_col_names - set(num_col_names))
-
-    return (num_col_names, nonnum_col_names)
+    all_col_types = get_cols_and_types(source_table)
+    all_col_names = set(all_col_types.keys())
+
+    numeric_col_names = [c for c in all_col_names
+                         if all_col_types[c] in numeric_types]
+    nonnum_col_names = list(all_col_names - set(numeric_col_names))
+
+    return (numeric_col_names, nonnum_col_names)
 
 
 # -----------------------------------------------------------------------
@@ -170,7 +152,7 @@ def _analyze_target_cols(target_cols):
 # Create and populate output table
 # -----------------------------------------------------------------------
 def _populate_output_table(schema_madlib, source_table, output_table,
-                           col_names, verbose):
+                           col_names, get_cov=True, verbose=False):
     """
     Creates a relation with the appropriate number of columns given a list of
     column names and populates with the correlation coefficients. If the table
@@ -180,122 +162,126 @@ def _populate_output_table(schema_madlib, source_table, output_table,
         @param schema_madlib  Schema of MADlib
         @param source_table   Name of source table
         @param output_table   Name of output table
-        @param _target_cols   Name of all columns to place in output table
+        @param col_names      Name of all columns to place in output table
+        @param get_cov        If False return the correlation matrix else
+                                return covariance matrix
 
     Returns:
         Tuple (output table name, number of columns, time for computation)
     """
-    old_msg_level = plpy.execute("""
-        SELECT setting FROM pg_settings
-        WHERE name='client_min_messages'
-        """)[0]['setting']
-    if verbose:
-        plpy.execute("SET client_min_messages TO info")
-    else:
-        plpy.execute("SET client_min_messages TO error")
-
-    start = time()
-    col_len = len(col_names)
-    col_names_str = ",".join(col_names)
-    temp_table = unique_string()
-
-    # actual computation
-    plpy.execute("""
-        CREATE TEMP TABLE {temp_table} AS
-        SELECT
-            tot_cnt,
-            count(*) AS non_null_cnt,
-            mean,
-            {schema_madlib}.correlation_agg(x, mean) as cor_mat
-        FROM
-        (
-            SELECT ARRAY[{col_names_str}]::float8[] AS x
-            FROM {source_table}
-        ) src1,
-        (
+    with MinWarning("info" if verbose else "error"):
+        start = time()
+        col_len = len(col_names)
+        col_names_str = ",".join(col_names)
+        temp_table = unique_string()
+        if get_cov:
+            agg_str = """
+                (CASE WHEN count(*) > 0
+                      THEN {0}.array_scalar_mult({0}.covariance_agg(x, mean),
+                                                 1.0 / count(*)::double precision)
+                      ELSE NULL
+                END) """.format(schema_madlib)
+        else:
+            agg_str = "{0}.correlation_agg(x, mean)".format(schema_madlib)
+
+        # actual computation
+        plpy.execute("""
+            CREATE TEMP TABLE {temp_table} AS
             SELECT
-                count(*) AS tot_cnt,
-                {schema_madlib}.avg(x) AS mean
+                tot_cnt,
+                count(*) AS non_null_cnt,
+                mean,
+                {agg_str} as cor_mat
             FROM
             (
                 SELECT ARRAY[{col_names_str}]::float8[] AS x
                 FROM {source_table}
-            ) src2
-        ) subq
-        WHERE NOT {schema_madlib}.array_contains_null(x)
-        GROUP BY tot_cnt, mean
-        """.format(**locals()))
-
-    # create summary table
-    summary_table = add_postfix(output_table, "_summary")
-    q_summary = """
-        CREATE TABLE {summary_table} AS
-        SELECT
-            'correlation'::varchar      AS method,
-            '{source_table}'::varchar   AS source,
-            '{output_table}'::varchar   AS output_table,
-            '{col_names_str}'::varchar  AS column_names,
-            mean                        AS mean_vector,
-            non_null_cnt                AS total_rows_processed,
-            tot_cnt - non_null_cnt      AS total_rows_skipped
-        FROM {temp_table}
-        """.format(**locals())
-
-    plpy.execute(q_summary)
-
-    # create output table
-    as_list = "deconstructed(column_position integer"
-    for k, c in enumerate(col_names):
-        if k % 10 == 0:
-            as_list += "\n                "
-        as_list += ", {c} float8".format(c=c)
-    as_list += ")"
-
-    output_plan = plpy.prepare("""
-        CREATE TABLE {output_table} AS
-        SELECT
-            *
-        FROM
-        (
+            ) src1,
+            (
+                SELECT
+                    count(*) AS tot_cnt,
+                    {schema_madlib}.avg(x) AS mean
+                FROM
+                (
+                    SELECT ARRAY[{col_names_str}]::float8[] AS x
+                    FROM {source_table}
+                ) src2
+            ) subq
+            WHERE NOT {schema_madlib}.array_contains_null(x)
+            GROUP BY tot_cnt, mean
+            """.format(**locals()))
+
+        # create summary table
+        summary_table = add_postfix(output_table, "_summary")
+        q_summary = """
+            CREATE TABLE {summary_table} AS
             SELECT
-                generate_series(1, {num_cols}) AS column_position,
-                unnest($1) AS variable
-        ) variable_subq
-        JOIN
-        (
+                'correlation'::varchar      AS method,
+                '{source_table}'::varchar   AS source,
+                '{output_table}'::varchar   AS output_table,
+                '{col_names_str}'::varchar  AS column_names,
+                mean                        AS mean_vector,
+                non_null_cnt                AS total_rows_processed,
+                tot_cnt - non_null_cnt      AS total_rows_skipped
+            FROM {temp_table}
+            """.format(**locals())
+
+        plpy.execute(q_summary)
+
+        # create output table
+        as_list = "deconstructed(column_position integer"
+        for k, c in enumerate(col_names):
+            if k % 10 == 0:
+                as_list += "\n                "
+            as_list += ", {c} float8".format(c=c)
+        as_list += ")"
+
+        output_plan = plpy.prepare("""
+            CREATE TABLE {output_table} AS
             SELECT
                 *
             FROM
-                {schema_madlib}.__deconstruct_lower_triangle(
-                    (SELECT cor_mat FROM {temp_table})
-                )
-                AS {as_list}
-        ) matrix_subq
-        USING (column_position)
-        """.format(num_cols=len(col_names), **locals()), ["varchar[]"])
+            (
+                SELECT
+                    generate_series(1, {num_cols}) AS column_position,
+                    unnest($1) AS variable
+            ) variable_subq
+            JOIN
+            (
+                SELECT
+                    *
+                FROM
+                    {schema_madlib}.__deconstruct_lower_triangle(
+                        (SELECT cor_mat FROM {temp_table})
+                    )
+                    AS {as_list}
+            ) matrix_subq
+            USING (column_position)
+            """.format(num_cols=len(col_names), **locals()), ["varchar[]"])
 
-    plpy.execute(output_plan, [col_names])
+        plpy.execute(output_plan, [col_names])
 
-    # clean up and return
-    plpy.execute("DROP TABLE {temp_table}".format(**locals()))
-    plpy.execute("SET client_min_messages TO " + old_msg_level)
-    end = time()
-    return (output_table, len(col_names), end - start)
+        # clean up and return
+        plpy.execute("DROP TABLE {temp_table}".format(**locals()))
+        end = time()
+        return (output_table, len(col_names), end - start)
 
 
 # -----------------------------------------------------------------------
 # Help messages
 # -----------------------------------------------------------------------
-def correlation_help_message(schema_madlib, message, **kwargs):
+def correlation_help_message(schema_madlib, message, cov=False, **kwargs):
     """
     Given a help string, provide usage information
     """
+    func = "covariance" if cov else "correlation"
+
     if message is not None and \
             message.lower() in ("usage", "help", "?"):
         return """
 Usage:
 -----------------------------------------------------------------------
-SELECT {schema_madlib}.correlation
+SELECT {schema_madlib}.{func}
 (
     source_table TEXT,   -- Source table name (Required)
     output_table TEXT,   -- Output table name (Required)
@@ -310,14 +296,14 @@ The columns of the table are described as follows:
 
     - column_position   : Position of the variable in the 'source_table'.
     - variable          : Provides the row-header for each variable
-    - Rest of the table is the NxN correlation matrix for all numeric columns
+    - Rest of the table is the NxN {func} matrix for all numeric columns
     in 'source_table'.
 
 The output table is arranged as a lower-traingular matrix with the upper
 triangle set to NULL and the diagonal elements set to 1.0. To obtain the
 result from the output_table in this matrix format ensure to order the
 elements using the 'column_position' column.
-        """.format(schema_madlib=schema_madlib)
+        """.format(schema_madlib=schema_madlib, func=func)
     elif message is not None and message.lower() in ('example', 'examples'):
         return """
 DROP TABLE IF EXISTS example_data;
@@ -362,13 +348,13 @@ VALUES(NULL, 100, 100, 'true', NULL);
 INSERT INTO example_data(outlook, temperature, humidity, windy, class)
 VALUES(NULL, 110, 100, 'true', NULL);
 
-SELECT madlib.correlation('example_data', 'example_data_output');
-SELECT madlib.correlation('example_data', 'example_data_output', '*');
-SELECT madlib.correlation('example_data', 'example_data_output', 'temperature, humidity');
+SELECT madlib.{func}('example_data', 'example_data_output');
+SELECT madlib.{func}('example_data', 'example_data_output', '*');
+SELECT madlib.{func}('example_data', 'example_data_output', 'temperature, humidity');
 
-To get the correlation matrix from output table:
+-- To get the {func} matrix from output table:
 SELECT * from example_data_output order by column_position;
-         """
+         """.format(func=func)
     else:
         return """
 A correlation function is the degree and direction of association of
@@ -378,9 +364,14 @@ from the other. The coefficient of correlation varies from -1 to 1:
 perfectly anti-correlated.
 -------
 For an overview on usage, run:
-SELECT {schema_madlib}.correlation('usage');
+    SELECT {schema_madlib}.correlation('usage');
+
+To obtain the covariance values instead of correlation:
+    SELECT {schema_madlib}.covariance('usage');
 -------
-For an example, run:
-SELECT {schema_madlib}.correlation('example')
-            """.format(schema_madlib=schema_madlib)
+For examples:
+    SELECT {schema_madlib}.correlation('example');
+OR
+    SELECT {schema_madlib}.covariance('example');
 
+            """.format(schema_madlib=schema_madlib)

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/ba5fc1ea/src/ports/postgres/modules/stats/correlation.sql_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/stats/correlation.sql_in b/src/ports/postgres/modules/stats/correlation.sql_in
index 4f2c255..36cb133 100644
--- a/src/ports/postgres/modules/stats/correlation.sql_in
+++ b/src/ports/postgres/modules/stats/correlation.sql_in
@@ -50,6 +50,16 @@ correlation( source_table,
            )
 </pre>
 
+The covariance function, with a similar syntax,
+can be used to compute the covariance between features.
+<pre class="syntax">
+covariance( source_table,
+             output_table,
+             target_cols,
+             verbose
+           )
+</pre>
+
 <dl class="arglist">
 
 <dt>source_table</dt>
@@ -176,6 +186,26 @@ Result:
 (2 rows)
 </pre>
 
+-# Compute the covariance of features in the data set.
+<pre class="example">
+SELECT madlib.covariance( 'example_data',
+                          'cov_output'
+                         );
+</pre>
+
+-# View the covariance matrix.
+<pre class="example">
+SELECT * FROM cov_output ORDER BY column_position;
+</pre>
+Result:
+<pre class="result">
+ column_position |  variable   |    temperature    | humidity
+-----------------+-------------+-------------------+----------
+               1 | temperature |      146.25       |
+               2 | humidity    |      82.125       | 121.1875
+(2 rows)
+</pre>
+
 @par Notes
 Current implementation ignores a row that contains NULL entirely. This means
 any correlation in such a row (with NULLs) does not contribute to the final answer.
@@ -230,6 +260,21 @@ CREATE AGGREGATE MADLIB_SCHEMA.correlation_agg(
     -- use NULL as the initial value
 );
 
+DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.covariance_agg(
+    double precision[], double precision[]);
+CREATE AGGREGATE MADLIB_SCHEMA.covariance_agg(
+    /* x */     double precision[],
+    /* mean */  double precision[]
+) (
+    SType = double precision[],
+    m4_ifdef(`__POSTGRESQL__', `', `prefunc=MADLIB_SCHEMA.correlation_merge,')
+    SFunc = MADLIB_SCHEMA.correlation_transition
+    -- use NULL as the initial value
+    -- return the last transition or merge state as the final state
+    -- this aggregate does not divide by the number of samples
+    --    (hence it's sum of (x-mean)^2 instead of expectation)
+);
+
 -----------------------------------------------------------------------
 
 DROP TYPE IF EXISTS MADLIB_SCHEMA.correlation_result CASCADE;
@@ -309,3 +354,75 @@ RETURNS text AS $$
 $$ LANGUAGE plpythonu IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 -------------------------------------------------------------------------
+
+-----------------------------------------------------------------------
+-- Main function for covariance
+-----------------------------------------------------------------------
+/* @brief Compute a covariance matrix for a table with optional target columns specified
+
+   @param source_table Name of source relation containing the data
+   @param output_table Name of output table to store the covariance matrix
+   @param target_cols  String with comma separated list of columns for which covariance is desired
+   @param verbose      Flag to determine verbosity
+
+   @usage
+   <pre> SELECT MADLIB_SCHEMA.covariance (
+         '<em>source_table</em>', '<em>output_table</em>',
+         '<em>target_cols</em>'
+     );
+     SELECT * FROM '<em>output_table</em>' ORDER BY '<em>column_position</em>';
+   </pre>
+*/
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.covariance(
+    source_table varchar, --  input table name
+    output_table varchar, -- output table name
+    target_cols  varchar, -- comma separated list of output cols (default = '*')
+    verbose      boolean  -- flag to determine verbosity
+) RETURNS MADLIB_SCHEMA.correlation_result AS $$
+    PythonFunctionBodyOnly(`stats', `correlation')
+    return correlation.correlation(schema_madlib, source_table, output_table,
+                                   target_cols, True, verbose)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-----------------------------------------------------------------------
+-- Overloaded functions
+-----------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.covariance(
+    source_table varchar, --  input table name
+    output_table varchar, -- output table name
+    target_cols  varchar  -- comma separated list of output cols (default = '*')
+)
+RETURNS MADLIB_SCHEMA.correlation_result AS $$
+    select MADLIB_SCHEMA.covariance($1, $2, $3, FALSE)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-----------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.covariance(
+    source_table varchar, --  input table name
+    output_table varchar  -- output table name
+) RETURNS MADLIB_SCHEMA.correlation_result AS $$
+    select MADLIB_SCHEMA.covariance($1, $2, NULL)
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-----------------------------------------------------------------------
+-- Help functions
+-----------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.covariance(
+    input_message       text
+) RETURNS TEXT AS $$
+    PythonFunctionBodyOnly(`stats', `correlation')
+    return correlation.correlation_help_message(schema_madlib, input_message, cov=True)
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
+
+-----------------------------------------------------------------------
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.covariance()
+RETURNS text AS $$
+    PythonFunctionBodyOnly(`stats', `correlation')
+    return correlation.correlation_help_message(schema_madlib, None, cov=True)
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
+-------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/incubator-madlib/blob/ba5fc1ea/src/ports/postgres/modules/utilities/validate_args.py_in
----------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/utilities/validate_args.py_in b/src/ports/postgres/modules/utilities/validate_args.py_in
index 68bc862..fc4d142 100644
--- a/src/ports/postgres/modules/utilities/validate_args.py_in
+++ b/src/ports/postgres/modules/utilities/validate_args.py_in
@@ -282,23 +282,28 @@ def get_cols_and_types(tbl):
     if tbl is None or tbl.lower() == 'null':
         plpy.error('Input error: Table name (NULL) is invalid')
 
-    names = tbl.split(".")
+    # determine the exact table_schema and table_name
+    # in case tbl contains only the table name (no schema qualifier)
+    row = plpy.execute("""
+                        SELECT
+                            quote_ident(nspname) AS table_schema,
+                            quote_ident(relname) AS table_name
+                        FROM
+                            pg_class AS c,
+                            pg_namespace AS nsp
+                        WHERE
+                            c.oid = '{source_table}'::regclass::oid AND
+                            c.relnamespace = nsp.oid
+                        """.format(source_table=tbl))
+    schema = row[0]['table_schema']
+    table = row[0]['table_name']
 
-    if not names or len(names) > 2:
-        raise TypeError("Input error: Invalid table name - {0}!".format(tbl))
-    elif len(names) == 1:
-        table = _unquote_name(names[0])
-        schema = get_first_schema(table)
-    elif len(names) == 2:
-        schema = _unquote_name(names[0])
-        table = _unquote_name(names[1])
     sql_string = """SELECT array_agg(quote_ident(column_name)::varchar) AS cols,
                            array_agg(data_type::varchar) AS types
                     FROM information_schema.columns
-                    WHERE table_name = '{table_name}'
-                    AND table_schema = '{schema_name}'
-                """.format(table_name=table,
-                           schema_name=schema)
+                    WHERE table_name = '{table}'
+                    AND table_schema = '{schema}'
+                """.format(table=table, schema=schema)
     result = plpy.execute(sql_string)[0]
     col_names = _string_to_array(result['cols'])
     col_types = _string_to_array(result['types'])