You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@age.apache.org by jg...@apache.org on 2021/01/27 19:59:42 UTC

[incubator-age] branch master updated: Add aggregate function collect

This is an automated email from the ASF dual-hosted git repository.

jgemignani pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-age.git


The following commit(s) were added to refs/heads/master by this push:
     new c2cefa1  Add aggregate function collect
c2cefa1 is described below

commit c2cefa1f1eedae35d2343eba1ba4294c5998aa08
Author: John Gemignani <jr...@gmail.com>
AuthorDate: Tue Jan 19 15:28:34 2021 -0800

    Add aggregate function collect
    
    Added the openCypher aggregate function collect.
    
    Added a fix to `fill_agtype_value` function. It was not making deep
    copies for agtype_value strings or numerics. This caused their
    pointers to potentially point to invalid data.
    
    Added regression tests.
---
 age--0.3.0.sql                      |  25 +++++++++
 regress/expected/expr.out           |  43 +++++++++++++++
 regress/sql/expr.sql                |  16 +++++-
 src/backend/utils/adt/agtype.c      | 102 ++++++++++++++++++++++++++++++++++++
 src/backend/utils/adt/agtype_util.c |  20 +++++--
 5 files changed, 202 insertions(+), 4 deletions(-)

diff --git a/age--0.3.0.sql b/age--0.3.0.sql
index 121ed29..2bee57b 100644
--- a/age--0.3.0.sql
+++ b/age--0.3.0.sql
@@ -1364,6 +1364,31 @@ CREATE AGGREGATE ag_catalog.age_percentiledisc(float8, float8)
 );
 
 --
+-- aggregate transition/final functions for collect
+--
+CREATE FUNCTION ag_catalog.age_collect_aggtransfn(internal, variadic "any")
+RETURNS internal
+LANGUAGE c
+IMMUTABLE
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
+CREATE FUNCTION ag_catalog.age_collect_aggfinalfn(internal)
+RETURNS agtype
+LANGUAGE c
+IMMUTABLE
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
+CREATE AGGREGATE ag_catalog.age_collect(variadic "any")
+(
+    stype = internal,
+    sfunc = ag_catalog.age_collect_aggtransfn,
+    finalfunc = ag_catalog.age_collect_aggfinalfn,
+    parallel = safe
+);
+
+--
 -- function for typecasting an agtype value to another agtype value
 --
 CREATE FUNCTION ag_catalog.agtype_typecast_numeric(agtype)
diff --git a/regress/expected/expr.out b/regress/expected/expr.out
index c714c78..af29780 100644
--- a/regress/expected/expr.out
+++ b/regress/expected/expr.out
@@ -4418,6 +4418,49 @@ ERROR:  percentile value NULL is not a valid numeric value
 SELECT * FROM cypher('UCSC', $$ RETURN percentileDisc(.5, NULL) $$) AS (percentileDisc agtype);
 ERROR:  percentile value NULL is not a valid numeric value
 --
+-- aggregate function collect()
+--
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.name), collect(u.age), collect(u.gpa), collect(u.zip) $$)
+AS (name agtype, age agtype, gqa agtype, zip agtype);
+                                    name                                    |                 age                  |                          gqa                           |                  zip                  
+----------------------------------------------------------------------------+--------------------------------------+--------------------------------------------------------+---------------------------------------
+ ["Jack", "Jill", "Jim", "Rick", "Ann", "Derek", "Jessica", "Dave", "Mike"] | [21, 27, 32, 24, 23, 19, 20, 24, 18] | [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric] | [94110, 95060, 96062, "95060", 90210]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.gpa), collect(u.gpa) $$)
+AS (gpa1 agtype, gpa2 agtype);
+                          gpa1                          |                          gpa2                          
+--------------------------------------------------------+--------------------------------------------------------
+ [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric] | [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.zip), collect(u.zip) $$)
+AS (zip1 agtype, zip2 agtype);
+                 zip1                  |                 zip2                  
+---------------------------------------+---------------------------------------
+ [94110, 95060, 96062, "95060", 90210] | [94110, 95060, 96062, "95060", 90210]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ RETURN collect(5) $$) AS (result agtype);
+ result 
+--------
+ [5]
+(1 row)
+
+-- should return an empty aray
+SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
+ empty 
+-------
+ []
+(1 row)
+
+-- should fail
+SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);
+ERROR:  function ag_catalog.age_collect() does not exist
+LINE 1: SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (col...
+                                               ^
+HINT:  No function matches the given name and argument types. You might need to add explicit type casts.
+--
 -- Cleanup
 --
 SELECT * FROM drop_graph('UCSC', true);
diff --git a/regress/sql/expr.sql b/regress/sql/expr.sql
index 3ff73a5..9df7ef2 100644
--- a/regress/sql/expr.sql
+++ b/regress/sql/expr.sql
@@ -1842,7 +1842,6 @@ SELECT * FROM cypher('UCSC', $$ RETURN stDevP() $$) AS (stDevP agtype);
 --
 SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileCont(u.gpa, .55), percentileDisc(u.gpa, .55), percentileCont(u.gpa, .9), percentileDisc(u.gpa, .9) $$)
 AS (percentileCont1 agtype, percentileDisc1 agtype, percentileCont2 agtype, percentileDisc2 agtype);
-
 SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileCont(u.gpa, .55) $$)
 AS (percentileCont agtype);
 SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileDisc(u.gpa, .55) $$)
@@ -1855,6 +1854,21 @@ SELECT * FROM cypher('UCSC', $$ RETURN percentileCont(.5, NULL) $$) AS (percenti
 SELECT * FROM cypher('UCSC', $$ RETURN percentileDisc(.5, NULL) $$) AS (percentileDisc agtype);
 
 --
+-- aggregate function collect()
+--
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.name), collect(u.age), collect(u.gpa), collect(u.zip) $$)
+AS (name agtype, age agtype, gqa agtype, zip agtype);
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.gpa), collect(u.gpa) $$)
+AS (gpa1 agtype, gpa2 agtype);
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.zip), collect(u.zip) $$)
+AS (zip1 agtype, zip2 agtype);
+SELECT * FROM cypher('UCSC', $$ RETURN collect(5) $$) AS (result agtype);
+-- should return an empty aray
+SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
+-- should fail
+SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);
+
+--
 -- Cleanup
 --
 SELECT * FROM drop_graph('UCSC', true);
diff --git a/src/backend/utils/adt/agtype.c b/src/backend/utils/adt/agtype.c
index 5014f60..d208c1f 100644
--- a/src/backend/utils/adt/agtype.c
+++ b/src/backend/utils/adt/agtype.c
@@ -2290,6 +2290,7 @@ static agtype *execute_map_access_operator(agtype *map, agtype *key)
 
     case AGTV_STRING:
         new_key_value.val.string = key_value->val.string;
+        new_key_value.val.string.len = key_value->val.string.len;
         break;
 
     default:
@@ -7473,3 +7474,104 @@ Datum age_percentile_disc_aggfinalfn(PG_FUNCTION_ARGS)
     else
         PG_RETURN_DATUM(val);
 }
+
+/* functions to support the aggregate function COLLECT() */
+PG_FUNCTION_INFO_V1(age_collect_aggtransfn);
+
+Datum age_collect_aggtransfn(PG_FUNCTION_ARGS)
+{
+    agtype_in_state *castate;
+    int nargs;
+    Datum *args;
+    bool *nulls;
+    Oid *types;
+    MemoryContext old_mcxt;
+
+    /* transition step for collect(); verify we are in an aggregate context */
+    Assert(AggCheckCallContext(fcinfo, NULL) == AGG_CONTEXT_AGGREGATE);
+
+    /*
+     * Switch to fn_mcxt so the data added to the array is not lost.
+     * NOTE(review): AggCheckCallContext's context is unused - confirm intent.
+     */
+    old_mcxt = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
+
+    /* if this is the first invocation (state argument is NULL), create it */
+    if (PG_ARGISNULL(0))
+    {
+        /* allocate the state and zero it */
+        castate = palloc(sizeof(agtype_in_state));
+        memset(castate, 0, sizeof(agtype_in_state));
+        /* open the result array */
+        castate->res = push_agtype_value(&castate->parse_state,
+                                         WAGT_BEGIN_ARRAY, NULL);
+    }
+    /* otherwise, reuse the state passed in as argument 0 */
+    else
+        castate = (agtype_in_state *) PG_GETARG_POINTER(0);
+
+    /*
+     * Extract the variadic args starting at argument 1; exactly one is
+     * expected. A null argument 1 sets nargs to 0, so nothing is added, and
+     * more than one argument raises an error below.
+     */
+    if (PG_ARGISNULL(1))
+        nargs = 0;
+    else
+        nargs = extract_variadic_args(fcinfo, 1, true, &args, &types, &nulls);
+
+    if (nargs == 1)
+    {
+        /* SQL NULL inputs are skipped */
+        if (nulls[0] == false)
+        {
+            /* an agtype argument may hold an agtype null; skip that too */
+            if (types[0] == AGTYPEOID)
+            {
+                agtype *agt_arg;
+                agtype_value *agtv_value;
+
+                /* read element 0 of the argument's root container */
+                agt_arg = DATUM_GET_AGTYPE_P(args[0]);
+                agtv_value = get_ith_agtype_value_from_container(&agt_arg->root,
+                                                                 0);
+                /* append the argument unless that element is AGTV_NULL */
+                if (agtv_value->type != AGTV_NULL)
+                    add_agtype(args[0], nulls[0], castate, types[0], false);
+            }
+            else
+                add_agtype(args[0], nulls[0], castate, types[0], false);
+        }
+    }
+    else if (nargs > 1)
+        ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                        errmsg("collect() invalid number of arguments")));
+
+    /* restore the previous memory context */
+    MemoryContextSwitchTo(old_mcxt);
+
+    /* return the state pointer */
+    PG_RETURN_POINTER(castate);
+}
+
+PG_FUNCTION_INFO_V1(age_collect_aggfinalfn);
+
+Datum age_collect_aggfinalfn(PG_FUNCTION_ARGS)
+{
+    agtype_in_state *castate;
+    MemoryContext old_mcxt;
+
+    /* final step for collect(); verify we are in an aggregate context */
+    Assert(AggCheckCallContext(fcinfo, NULL) == AGG_CONTEXT_AGGREGATE);
+    /* NOTE(review): argument 0 is used without a PG_ARGISNULL check - confirm */
+    castate = (agtype_in_state *) PG_GETARG_POINTER(0);
+    /* switch to fn_mcxt before pushing the closing token */
+    old_mcxt = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
+    /* close the array (WAGT_END_ARRAY) and keep the result in the state */
+    castate->res = push_agtype_value(&castate->parse_state, WAGT_END_ARRAY,
+                                     NULL);
+    /* restore the previous memory context */
+    MemoryContextSwitchTo(old_mcxt);
+    /* convert the finished agtype_value to an agtype and return it */
+    PG_RETURN_POINTER(agtype_value_to_agtype(castate->res));
+}
diff --git a/src/backend/utils/adt/agtype_util.c b/src/backend/utils/adt/agtype_util.c
index 904f400..f73a8e4 100644
--- a/src/backend/utils/adt/agtype_util.c
+++ b/src/backend/utils/adt/agtype_util.c
@@ -545,15 +545,29 @@ static void fill_agtype_value(agtype_container *container, int index,
     }
     else if (AGTE_IS_STRING(entry))
     {
+        char *string_val;
+        int string_len;
+
         result->type = AGTV_STRING;
-        result->val.string.val = base_addr + offset;
-        result->val.string.len = get_agtype_length(container, index);
+        /* get the position and length of the string */
+        string_val = base_addr + offset;
+        string_len = get_agtype_length(container, index);
+        /* we need to do a deep copy of the string value */
+        result->val.string.val = pnstrdup(string_val, string_len);
+        result->val.string.len = string_len;
         Assert(result->val.string.len >= 0);
     }
     else if (AGTE_IS_NUMERIC(entry))
     {
+        Numeric numeric;
+        Numeric numeric_copy;
+
         result->type = AGTV_NUMERIC;
-        result->val.numeric = (Numeric)(base_addr + INTALIGN(offset));
+        /* we need to do a deep copy here */
+        numeric = (Numeric)(base_addr + INTALIGN(offset));
+        numeric_copy = (Numeric) palloc(VARSIZE(numeric));
+        memcpy(numeric_copy, numeric, VARSIZE(numeric));
+        result->val.numeric = numeric_copy;
     }
     /*
      * If this is an agtype.