You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@madlib.apache.org by mktal <gi...@git.apache.org> on 2016/06/08 15:50:47 UTC

[GitHub] incubator-madlib pull request #42: Prediction Metrics: New module

Github user mktal commented on a diff in the pull request:

    https://github.com/apache/incubator-madlib/pull/42#discussion_r66282751
  
    --- Diff: src/ports/postgres/modules/stats/pred_metrics.py_in ---
    @@ -0,0 +1,562 @@
    +# coding=utf-8
    +#
    +# Licensed to the Apache Software Foundation (ASF) under one
    +# or more contributor license agreements.  See the NOTICE file
    +# distributed with this work for additional information
    +# regarding copyright ownership.  The ASF licenses this file
    +# to you under the Apache License, Version 2.0 (the
    +# "License"); you may not use this file except in compliance
    +# with the License.  You may obtain a copy of the License at
    +#
    +#   http://www.apache.org/licenses/LICENSE-2.0
    +#
    +# Unless required by applicable law or agreed to in writing,
    +# software distributed under the License is distributed on an
    +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    +# KIND, either express or implied.  See the License for the
    +# specific language governing permissions and limitations
    +# under the License.
    +
    +# Prediction Metrics
    +# This module provides a set of prediction accuracy metrics. It is a support
    +# module for several machine learning algorithms that require metrics to
    +# validate their models. A typical function will take a set of "prediction" and
    +# "observation" values to calculate the desired metric, unless noted otherwise.
    +# Grouping is supported by all of these functions (except confusion matrix).
    +
    +# Please refer to the pred_metrics.sql_in file for the documentation
    +
    +import plpy
    +from utilities.validate_args import input_tbl_valid, output_tbl_valid, is_var_valid
    +from utilities.utilities import _assert
    +from utilities.utilities import split_quoted_delimited_str
    +
    +
    +def _validate_args(table_in, table_out, validate_cols):
    +    input_tbl_valid(table_in, "Prediction Metrics")
    +    output_tbl_valid(table_out, "Prediction Metrics")
    +    is_var_valid(table_in, ', '.join(validate_cols))
    +
    +
    +def _parse_grp_col_str(grp_col_str):
    +    group_set = set(split_quoted_delimited_str(grp_col_str))
    +    return list(group_set)
    +# ----------------------------------------------------------------------
    +
    +
    +def _create_output_table(table_in, table_out, agg_fun, agg_name, grp_col_str=None):
    +    """ Create an output table with optional groups
    +
    +    General template function that builds an output table with grouping while
    +    applying an aggregate function.
    +
    +    Args:
    +        @param agg_fun: str, SQL aggregate to be executed
    +        @param grp_cols: str, Comma-separated list of column names
    +    """
    +    grp_cols = _parse_grp_col_str(grp_col_str)
    +    _validate_args(table_in, table_out, grp_cols)
    +    if not grp_cols:
    +        grp_by_str = grp_out_str = ""
    +    else:
    +        grp_by_str = "GROUP BY " + grp_col_str
    +        grp_out_str = grp_col_str + ", "
    +    plpy.execute("""
    +                 CREATE TABLE {table_out} AS
    +                 SELECT
    +                    {grp_out_str}
    +                    {agg_fun} AS {agg_name}
    +                 FROM {table_in}
    +                 {grp_by_str}
    +                 """.format(**locals()))
    +
    +
    +# Mean Absolute Error.
    +def mean_abs_error(table_in, table_out, pred_col, obs_col, grp_cols=None):
    +    mean_abs_agg = "AVG(ABS({0} - {1}))".format(pred_col, obs_col)
    +    _create_output_table(table_in, table_out, mean_abs_agg, "mean_abs_error", grp_cols)
    +
    +
    +# Mean Absolute Percentage Error.
    +def mean_abs_perc_error(table_in, table_out, pred_col, obs_col, grp_cols=None):
    +    mean_abs_perc_agg = "AVG(ABS({0} - {1})/NULLIF({1}, 0))".format(pred_col, obs_col)
    +    _create_output_table(table_in, table_out, mean_abs_perc_agg, "mean_abs_perc_error", grp_cols)
    +
    +
    +# Mean Percentage Error.
    +def mean_perc_error(table_in, table_out, pred_col, obs_col, grp_cols=None):
    +    mean_perc_agg = "AVG(({0} - {1})/NULLIF({1}, 0))".format(pred_col, obs_col)
    +    _create_output_table(table_in, table_out, mean_perc_agg, "mean_perc_error", grp_cols)
    +
    +
    +# Mean Squared Error.
    +def mean_squared_error(table_in, table_out, pred_col, obs_col, grp_cols=None):
    +    mean_sq_agg = "AVG(({0} - {1})^2)".format(pred_col, obs_col)
    +    _create_output_table(table_in, table_out, mean_sq_agg, "mean_squared_error", grp_cols)
    +
    +
    +def metric_agg_help_msg(schema_madlib, message, agg_name, **kwargs):
    +    """ Build the help text for a prediction-metric function.
    +
    +    Args:
    +        @param schema_madlib: str, Schema name substituted into the text
    +        @param message: str, Empty/None selects the summary; 'usage',
    +                        'help' or '?' (case-insensitive, surrounding
    +                        whitespace ignored) selects the usage text
    +        @param agg_name: str, Metric function name substituted into the text
    +        @param kwargs: accepted but not referenced by the help templates
    +
    +    Returns:
    +        str, the selected help text; any other message yields a
    +        "No such option" hint pointing at the 'usage' option
    +    """
    +
    +    if not message:
    +        help_string = """
------------------------------------------------------------
                        SUMMARY
------------------------------------------------------------
Functionality: Evaluate prediction results using metric functions.

This module provides a set of prediction accuracy metrics. It is a support
module for several machine learning algorithms that require metrics to validate
their models. The function will take "prediction" and "observation" values to
calculate the desired metric. Grouping is supported by all of these functions.
    """
    +    elif message.lower().strip() in ['usage', 'help', '?']:
    +        help_string = """
------------------------------------------------------------
                        USAGE
------------------------------------------------------------
SELECT {schema_madlib}.{agg_name}(
    'table_in',     -- Name of the input table
    'table_out',    -- Table name to store the metric results
    'pred_col',     -- Column name containing prediction results
    'obs_col',      -- Column name containing observed (actual) values
    'grouping_cols' -- Comma-separated list of columns to use as group-by
);
    """
    +    else:
    +        help_string = "No such option. Use {schema_madlib}.{agg_name}('usage')"
    +    # {schema_madlib} / {agg_name} placeholders are filled from the locals.
    +    return help_string.format(**locals())
    +
    +
    +def _get_r2_score_sql(table_in, pred_col, obs_col, grp_col_str=None):
    +    """ Generate the SQL query to compute r2 score.
    +
    +    This function abstracts the SQL to calculate r2 score from actually building
    +    the output table. This allows reusing the query for adjusted r2 function.
    +
    +    Args:
    +        @param table_in: str, Input table name containing the data
    +        @param pred_col: str, Column name containing the predictions
    +        @param obs_col: str, Column name containing the actual observed class
    +        @param grp_col_str: str, Comma-separated list of columns to group by
    +
    +    Definition:
    +        r2 = 1 - SS_res / SS_tot
    +        where SS_res = sum (pred - obs)^2
    +              SS_tot = sum (obs - mean)^2
    +
    +    """
    +    if grp_col_str:
    +        grp_out_str = grp_col_str + ","
    +        grp_by_str = "GROUP BY " + grp_col_str
    +        partition_str = "PARTITION BY " + grp_col_str
    +    else:
    +        grp_out_str = grp_by_str = partition_str = ""
    +    return """
    +            SELECT {grp_out_str}
    +                1 - ssres/sstot AS r2_score
    +            FROM (
    +                SELECT {grp_out_str}
    +                       sum(({obs_col} - mean)^2) as sstot,
    +                       sum(({pred_col} - {obs_col})^2) AS ssres
    +                FROM(
    +                    SELECT {grp_out_str}
    +                           {pred_col}, {obs_col},
    +                           avg({obs_col}) OVER ({partition_str}) as mean
    +                    FROM {table_in}
    +                ) x {grp_by_str}
    +            ) y
    --- End diff --
    
    Why not 
    ```sql
    SELECT 
         {grp_out_str}
         1 - avg(({pred_col} - {obs_col})^2)/var_pop({obs_col}) AS r2_score
    FROM {table_in} {grp_by_str}
    ```
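    It is simpler, about 2-3 times faster in my quick experiments, and numerically more stable (it avoids accumulating large sums). It gives the same result because avg((pred - obs)^2) = SS_res / n and var_pop(obs) = SS_tot / n, so their ratio is exactly SS_res / SS_tot.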


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---