You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@age.apache.org by jg...@apache.org on 2021/01/27 19:59:42 UTC
[incubator-age] branch master updated: Add aggregate function
collect
This is an automated email from the ASF dual-hosted git repository.
jgemignani pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-age.git
The following commit(s) were added to refs/heads/master by this push:
new c2cefa1 Add aggregate function collect
c2cefa1 is described below
commit c2cefa1f1eedae35d2343eba1ba4294c5998aa08
Author: John Gemignani <jr...@gmail.com>
AuthorDate: Tue Jan 19 15:28:34 2021 -0800
Add aggregate function collect
Added the openCypher aggregate function collect.
Added a fix to `fill_agtype_value` function. It was not making deep
copies for agtype_value strings or numerics. This caused their
pointers to potentially point to invalid data.
Added regression tests.
---
age--0.3.0.sql | 25 +++++++++
regress/expected/expr.out | 43 +++++++++++++++
regress/sql/expr.sql | 16 +++++-
src/backend/utils/adt/agtype.c | 102 ++++++++++++++++++++++++++++++++++++
src/backend/utils/adt/agtype_util.c | 20 +++++--
5 files changed, 202 insertions(+), 4 deletions(-)
diff --git a/age--0.3.0.sql b/age--0.3.0.sql
index 121ed29..2bee57b 100644
--- a/age--0.3.0.sql
+++ b/age--0.3.0.sql
@@ -1364,6 +1364,31 @@ CREATE AGGREGATE ag_catalog.age_percentiledisc(float8, float8)
);
--
+-- aggregate transfer/final functions for collect
+--
+-- Transition function for age_collect: takes the running internal state
+-- plus the variadic input value and returns the updated internal state.
+-- Implemented in C in the extension module (MODULE_PATHNAME).
+CREATE FUNCTION ag_catalog.age_collect_aggtransfn(internal, variadic "any")
+RETURNS internal
+LANGUAGE c
+IMMUTABLE
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
+-- Final function for age_collect: turns the internal state into the agtype
+-- result. It is not declared STRICT, so it is also invoked when the state
+-- is NULL.
+CREATE FUNCTION ag_catalog.age_collect_aggfinalfn(internal)
+RETURNS agtype
+LANGUAGE c
+IMMUTABLE
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
+
+-- Aggregate backing the openCypher collect() function. The state is an
+-- opaque internal pointer managed by the two C functions above; no initcond
+-- is given, so the state is NULL until the transition function creates it.
+CREATE AGGREGATE ag_catalog.age_collect(variadic "any")
+(
+ stype = internal,
+ sfunc = ag_catalog.age_collect_aggtransfn,
+ finalfunc = ag_catalog.age_collect_aggfinalfn,
+ parallel = safe
+);
+
+--
-- function for typecasting an agtype value to another agtype value
--
CREATE FUNCTION ag_catalog.agtype_typecast_numeric(agtype)
diff --git a/regress/expected/expr.out b/regress/expected/expr.out
index c714c78..af29780 100644
--- a/regress/expected/expr.out
+++ b/regress/expected/expr.out
@@ -4418,6 +4418,49 @@ ERROR: percentile value NULL is not a valid numeric value
SELECT * FROM cypher('UCSC', $$ RETURN percentileDisc(.5, NULL) $$) AS (percentileDisc agtype);
ERROR: percentile value NULL is not a valid numeric value
--
+-- aggregate function collect()
+--
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.name), collect(u.age), collect(u.gpa), collect(u.zip) $$)
+AS (name agtype, age agtype, gqa agtype, zip agtype);
+ name | age | gqa | zip
+----------------------------------------------------------------------------+--------------------------------------+--------------------------------------------------------+---------------------------------------
+ ["Jack", "Jill", "Jim", "Rick", "Ann", "Derek", "Jessica", "Dave", "Mike"] | [21, 27, 32, 24, 23, 19, 20, 24, 18] | [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric] | [94110, 95060, 96062, "95060", 90210]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.gpa), collect(u.gpa) $$)
+AS (gpa1 agtype, gpa2 agtype);
+ gpa1 | gpa2
+--------------------------------------------------------+--------------------------------------------------------
+ [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric] | [3.0, 3.5, 3.75, 2.5, 3.8::numeric, 4.0, 3.9::numeric]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.zip), collect(u.zip) $$)
+AS (zip1 agtype, zip2 agtype);
+ zip1 | zip2
+---------------------------------------+---------------------------------------
+ [94110, 95060, 96062, "95060", 90210] | [94110, 95060, 96062, "95060", 90210]
+(1 row)
+
+SELECT * FROM cypher('UCSC', $$ RETURN collect(5) $$) AS (result agtype);
+ result
+--------
+ [5]
+(1 row)
+
+-- should return an empty aray
+SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
+ empty
+-------
+ []
+(1 row)
+
+-- should fail
+SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);
+ERROR: function ag_catalog.age_collect() does not exist
+LINE 1: SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (col...
+ ^
+HINT: No function matches the given name and argument types. You might need to add explicit type casts.
+--
-- Cleanup
--
SELECT * FROM drop_graph('UCSC', true);
diff --git a/regress/sql/expr.sql b/regress/sql/expr.sql
index 3ff73a5..9df7ef2 100644
--- a/regress/sql/expr.sql
+++ b/regress/sql/expr.sql
@@ -1842,7 +1842,6 @@ SELECT * FROM cypher('UCSC', $$ RETURN stDevP() $$) AS (stDevP agtype);
--
SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileCont(u.gpa, .55), percentileDisc(u.gpa, .55), percentileCont(u.gpa, .9), percentileDisc(u.gpa, .9) $$)
AS (percentileCont1 agtype, percentileDisc1 agtype, percentileCont2 agtype, percentileDisc2 agtype);
-
SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileCont(u.gpa, .55) $$)
AS (percentileCont agtype);
SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN percentileDisc(u.gpa, .55) $$)
@@ -1855,6 +1854,21 @@ SELECT * FROM cypher('UCSC', $$ RETURN percentileCont(.5, NULL) $$) AS (percenti
SELECT * FROM cypher('UCSC', $$ RETURN percentileDisc(.5, NULL) $$) AS (percentileDisc agtype);
--
+-- aggregate function collect()
+--
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.name), collect(u.age), collect(u.gpa), collect(u.zip) $$)
+AS (name agtype, age agtype, gqa agtype, zip agtype);
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.gpa), collect(u.gpa) $$)
+AS (gpa1 agtype, gpa2 agtype);
+SELECT * FROM cypher('UCSC', $$ MATCH (u) RETURN collect(u.zip), collect(u.zip) $$)
+AS (zip1 agtype, zip2 agtype);
+SELECT * FROM cypher('UCSC', $$ RETURN collect(5) $$) AS (result agtype);
+-- should return an empty aray
+SELECT * FROM cypher('UCSC', $$ RETURN collect(NULL) $$) AS (empty agtype);
+-- should fail
+SELECT * FROM cypher('UCSC', $$ RETURN collect() $$) AS (collect agtype);
+
+--
-- Cleanup
--
SELECT * FROM drop_graph('UCSC', true);
diff --git a/src/backend/utils/adt/agtype.c b/src/backend/utils/adt/agtype.c
index 5014f60..d208c1f 100644
--- a/src/backend/utils/adt/agtype.c
+++ b/src/backend/utils/adt/agtype.c
@@ -2290,6 +2290,7 @@ static agtype *execute_map_access_operator(agtype *map, agtype *key)
case AGTV_STRING:
new_key_value.val.string = key_value->val.string;
+ new_key_value.val.string.len = key_value->val.string.len;
break;
default:
@@ -7473,3 +7474,104 @@ Datum age_percentile_disc_aggfinalfn(PG_FUNCTION_ARGS)
else
PG_RETURN_DATUM(val);
}
+
+/* functions to support the aggregate function COLLECT() */
+PG_FUNCTION_INFO_V1(age_collect_aggtransfn);
+
+/*
+ * Transition function for the aggregate collect().
+ *
+ * Argument 0 is the running state (an agtype_in_state holding an open agtype
+ * array), or NULL the first time the function is called. The remaining
+ * arguments are the variadic input, of which at most one is accepted.
+ *
+ * SQL NULLs and agtype nulls are skipped; any other value is appended to the
+ * array. Returns the (possibly newly created) state.
+ *
+ * Raises ERRCODE_INVALID_PARAMETER_VALUE if more than one value is passed.
+ */
+Datum age_collect_aggtransfn(PG_FUNCTION_ARGS)
+{
+    agtype_in_state *castate;
+    int nargs;
+    Datum *args;
+    bool *nulls;
+    Oid *types;
+    MemoryContext old_mcxt;
+
+    /* verify we are in an aggregate context */
+    Assert(AggCheckCallContext(fcinfo, NULL) == AGG_CONTEXT_AGGREGATE);
+
+    /*
+     * Switch to the correct aggregate context. Otherwise, the data added to the
+     * array will be lost.
+     */
+    old_mcxt = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
+
+    if (PG_ARGISNULL(0))
+    {
+        /* first invocation: create and zero the state, then open the array */
+        castate = palloc(sizeof(agtype_in_state));
+        memset(castate, 0, sizeof(agtype_in_state));
+        castate->res = push_agtype_value(&castate->parse_state,
+                                         WAGT_BEGIN_ARRAY, NULL);
+    }
+    else
+    {
+        /* otherwise, retrieve the state built by earlier calls */
+        castate = (agtype_in_state *) PG_GETARG_POINTER(0);
+    }
+
+    /*
+     * Extract the variadic args, of which there should only be one. A NULL
+     * variadic argument is treated as "nothing to add".
+     */
+    if (PG_ARGISNULL(1))
+        nargs = 0;
+    else
+        nargs = extract_variadic_args(fcinfo, 1, true, &args, &types, &nulls);
+
+    if (nargs > 1)
+    {
+        ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                        errmsg("collect() invalid number of arguments")));
+    }
+
+    /* only add non null values */
+    if (nargs == 1 && nulls[0] == false)
+    {
+        bool is_agtype_null = false;
+
+        /* we need to check for agtype null and skip it, if found */
+        if (types[0] == AGTYPEOID)
+        {
+            agtype *agt_arg;
+            agtype_value *agtv_value;
+
+            /* get the agtype argument */
+            agt_arg = DATUM_GET_AGTYPE_P(args[0]);
+            agtv_value = get_ith_agtype_value_from_container(&agt_arg->root,
+                                                             0);
+            /*
+             * Guard against a NULL result before dereferencing it.
+             * NOTE(review): presumably NULL is returned when the container
+             * has no element 0 (e.g. an empty array) - confirm against
+             * agtype_util.c. Such a value is not an agtype null, so it is
+             * still collected.
+             */
+            is_agtype_null = (agtv_value != NULL &&
+                              agtv_value->type == AGTV_NULL);
+        }
+
+        if (!is_agtype_null)
+            add_agtype(args[0], nulls[0], castate, types[0], false);
+    }
+
+    /* restore the old context */
+    MemoryContextSwitchTo(old_mcxt);
+
+    /* return the state */
+    PG_RETURN_POINTER(castate);
+}
+
+PG_FUNCTION_INFO_V1(age_collect_aggfinalfn);
+
+/*
+ * Final function for the aggregate collect().
+ *
+ * Argument 0 is the state built by age_collect_aggtransfn. The function is
+ * not declared STRICT and the aggregate has no initial condition, so the
+ * state is NULL when the transition function was never called (no input
+ * rows). In that case an empty array is produced instead of dereferencing
+ * a NULL state.
+ *
+ * Closes the array held in the state and returns it as an agtype.
+ */
+Datum age_collect_aggfinalfn(PG_FUNCTION_ARGS)
+{
+    agtype_in_state *castate;
+    MemoryContext old_mcxt;
+
+    /* verify we are in an aggregate context */
+    Assert(AggCheckCallContext(fcinfo, NULL) == AGG_CONTEXT_AGGREGATE);
+
+    /* switch to the same context the transition function allocates in */
+    old_mcxt = MemoryContextSwitchTo(fcinfo->flinfo->fn_mcxt);
+
+    if (PG_ARGISNULL(0))
+    {
+        /* no state was ever created: start an empty array here */
+        castate = palloc(sizeof(agtype_in_state));
+        memset(castate, 0, sizeof(agtype_in_state));
+        castate->res = push_agtype_value(&castate->parse_state,
+                                         WAGT_BEGIN_ARRAY, NULL);
+    }
+    else
+    {
+        /* get the state */
+        castate = (agtype_in_state *) PG_GETARG_POINTER(0);
+    }
+
+    /* Finish/close the array */
+    castate->res = push_agtype_value(&castate->parse_state, WAGT_END_ARRAY,
+                                     NULL);
+
+    /* restore the old context */
+    MemoryContextSwitchTo(old_mcxt);
+
+    /* return the agtype array */
+    PG_RETURN_POINTER(agtype_value_to_agtype(castate->res));
+}
diff --git a/src/backend/utils/adt/agtype_util.c b/src/backend/utils/adt/agtype_util.c
index 904f400..f73a8e4 100644
--- a/src/backend/utils/adt/agtype_util.c
+++ b/src/backend/utils/adt/agtype_util.c
@@ -545,15 +545,29 @@ static void fill_agtype_value(agtype_container *container, int index,
}
else if (AGTE_IS_STRING(entry))
{
+ char *string_val;
+ int string_len;
+
result->type = AGTV_STRING;
- result->val.string.val = base_addr + offset;
- result->val.string.len = get_agtype_length(container, index);
+ /* get the position and length of the string */
+ string_val = base_addr + offset;
+ string_len = get_agtype_length(container, index);
+ /* we need to do a deep copy of the string value */
+ result->val.string.val = pnstrdup(string_val, string_len);
+ result->val.string.len = string_len;
Assert(result->val.string.len >= 0);
}
else if (AGTE_IS_NUMERIC(entry))
{
+ Numeric numeric;
+ Numeric numeric_copy;
+
result->type = AGTV_NUMERIC;
- result->val.numeric = (Numeric)(base_addr + INTALIGN(offset));
+ /* we need to do a deep copy here */
+ numeric = (Numeric)(base_addr + INTALIGN(offset));
+ numeric_copy = (Numeric) palloc(VARSIZE(numeric));
+ memcpy(numeric_copy, numeric, VARSIZE(numeric));
+ result->val.numeric = numeric_copy;
}
/*
* If this is an agtype.