This is an automated email from the ASF dual-hosted git repository.

dehowef pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/age.git


The following commit(s) were added to refs/heads/master by this push:
     new 68e7524b Fix SET bug, issue #899 (#904)
68e7524b is described below

commit 68e7524becdebe36e83952049b43b0d04015a0aa
Author: John Gemignani <[email protected]>
AuthorDate: Tue May 9 07:53:21 2023 -0700

    Fix SET bug, issue #899 (#904)
---
 age--1.2.0.sql                     |  17 +++---
 regress/expected/cypher_set.out    |  20 +++++++
 regress/sql/cypher_set.sql         |  12 +++++
 src/backend/parser/cypher_clause.c |   2 +-
 src/backend/utils/adt/agtype.c     | 106 +++++++++++++++++++++++++++++++++++++
 5 files changed, 147 insertions(+), 10 deletions(-)

diff --git a/age--1.2.0.sql b/age--1.2.0.sql
index 6588eb1d..72e96e97 100644
--- a/age--1.2.0.sql
+++ b/age--1.2.0.sql
@@ -3158,15 +3158,14 @@ AS 'MODULE_PATHNAME';
 -- functions we need. Wrap the function with this to
 -- prevent that from happening
 --
-CREATE FUNCTION ag_catalog.agtype_volatile_wrapper(agt agtype)
-RETURNS agtype AS $return_value$
-BEGIN
-       RETURN agt;
-END;
-$return_value$ LANGUAGE plpgsql
+
+CREATE FUNCTION ag_catalog.agtype_volatile_wrapper("any")
+RETURNS agtype
+LANGUAGE c
 VOLATILE
 CALLED ON NULL INPUT
-PARALLEL SAFE;
+PARALLEL SAFE
+AS 'MODULE_PATHNAME';
 
 --
 -- agtype - list literal (`[expr, ...]`)
@@ -4203,8 +4202,8 @@ CALLED ON NULL INPUT
 PARALLEL SAFE
 AS 'MODULE_PATHNAME';
 
-CREATE FUNCTION ag_catalog.age_create_barbell_graph(graph_name name, 
-                                                graph_size int, 
+CREATE FUNCTION ag_catalog.age_create_barbell_graph(graph_name name,
+                                                graph_size int,
                                                 bridge_size int,
                                                 node_label name = NULL,
                                                 node_properties agtype = NULL,
diff --git a/regress/expected/cypher_set.out b/regress/expected/cypher_set.out
index adb2d28c..2caea937 100644
--- a/regress/expected/cypher_set.out
+++ b/regress/expected/cypher_set.out
@@ -806,6 +806,26 @@ $$) AS (p agtype);
  {"id": 2251799813685249, "label": "Robert", "properties": {"age": 47, "city": "London", "name": "Rob"}}::vertex
 (1 row)
 
+--
+-- Check passing mismatched types with SET
+-- Issue 899
+--
+SELECT * FROM cypher('cypher_set_1', $$
+    CREATE (x) SET x.n0 = (true OR true) RETURN x
+$$) AS (p agtype);
+                                    p                                     
+--------------------------------------------------------------------------
+ {"id": 281474976710657, "label": "", "properties": {"n0": true}}::vertex
+(1 row)
+
+SELECT * FROM cypher('cypher_set_1', $$
+    CREATE (x) SET x.n0 = (true OR false), x.n1 = (false AND false), x.n2 = (false = false) RETURN x
+$$) AS (p agtype);
+                                                 p                             
                    
+---------------------------------------------------------------------------------------------------
+ {"id": 281474976710658, "label": "", "properties": {"n0": true, "n1": false, "n2": true}}::vertex
+(1 row)
+
 --
 -- Clean up
 --
diff --git a/regress/sql/cypher_set.sql b/regress/sql/cypher_set.sql
index 205cf804..59bd0790 100644
--- a/regress/sql/cypher_set.sql
+++ b/regress/sql/cypher_set.sql
@@ -275,6 +275,18 @@ SELECT * FROM cypher('cypher_set_1', $$
     RETURN p
 $$) AS (p agtype);
 
+--
+-- Check passing mismatched types with SET
+-- Issue 899
+--
+SELECT * FROM cypher('cypher_set_1', $$
+    CREATE (x) SET x.n0 = (true OR true) RETURN x
+$$) AS (p agtype);
+
+SELECT * FROM cypher('cypher_set_1', $$
+    CREATE (x) SET x.n0 = (true OR false), x.n1 = (false AND false), x.n2 = (false = false) RETURN x
+$$) AS (p agtype);
+
 --
 -- Clean up
 --
diff --git a/src/backend/parser/cypher_clause.c b/src/backend/parser/cypher_clause.c
index 285ef318..98233b9e 100644
--- a/src/backend/parser/cypher_clause.c
+++ b/src/backend/parser/cypher_clause.c
@@ -5457,7 +5457,7 @@ static Expr *add_volatile_wrapper(Expr *node)
 {
     Oid oid;
 
-    oid = get_ag_func_oid("agtype_volatile_wrapper", 1, AGTYPEOID);
+    oid = get_ag_func_oid("agtype_volatile_wrapper", 1, ANYOID);
 
     return (Expr *)makeFuncExpr(oid, AGTYPEOID, list_make1(node), InvalidOid,
                                 InvalidOid, COERCE_EXPLICIT_CALL);
diff --git a/src/backend/utils/adt/agtype.c b/src/backend/utils/adt/agtype.c
index 8ca893c8..d00acc3f 100644
--- a/src/backend/utils/adt/agtype.c
+++ b/src/backend/utils/adt/agtype.c
@@ -10145,3 +10145,109 @@ Datum age_unnest(PG_FUNCTION_ARGS)
 
     PG_RETURN_NULL();
 }
+
+/*
+ * Volatile wrapper replacement. The previous version was PL/pgSQL
+ * and could only handle AGTYPE input and returned AGTYPE output.
+ * This version will create the appropriate AGTYPE based off of
+ * the input type: NULL returns NULL, AGTYPE is passed through, and
+ * bool, int2/4/8, float4/8, numeric, cstring and text are converted.
+ * Any other type, or more than one argument, raises an error.
+ */
+PG_FUNCTION_INFO_V1(agtype_volatile_wrapper);
+
+Datum agtype_volatile_wrapper(PG_FUNCTION_ARGS)
+{
+    int nargs = PG_NARGS();
+    Oid type = InvalidOid;
+
+    /* check for null and pass it through */
+    if (PG_ARGISNULL(0))
+    {
+        PG_RETURN_NULL();
+    }
+
+    /* check for more than one argument */
+    if (nargs > 1)
+    {
+        ereport(ERROR,
+                (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                 errmsg("agtype_volatile_wrapper: too many args")));
+    }
+
+    /* get the type of the input argument */
+    type = get_fn_expr_argtype(fcinfo->flinfo, 0);
+
+    /* if it is NOT an AGTYPE, we need to convert it to one, if possible */
+    if (type != AGTYPEOID)
+    {
+        agtype_value agtv_result;
+        Datum arg = PG_GETARG_DATUM(0);
+
+        /* check for PG types that easily translate to AGTYPE */
+        if (type == BOOLOID)
+        {
+            agtv_result.type = AGTV_BOOL;
+            agtv_result.val.boolean = DatumGetBool(arg);
+        }
+        else if (type == INT2OID || type == INT4OID || type == INT8OID)
+        {
+            agtv_result.type = AGTV_INTEGER;
+            /* widen each integer size using its matching Datum accessor */
+            if (type == INT8OID)
+            {
+                agtv_result.val.int_value = DatumGetInt64(arg);
+            }
+            else if (type == INT4OID)
+            {
+                agtv_result.val.int_value = (int64) DatumGetInt32(arg);
+            }
+            else if (type == INT2OID)
+            {
+                agtv_result.val.int_value = (int64) DatumGetInt16(arg);
+            }
+        }
+        else if (type == FLOAT4OID || type == FLOAT8OID)
+        {
+            agtv_result.type = AGTV_FLOAT;
+
+            if (type == FLOAT8OID)
+            {
+                agtv_result.val.float_value = DatumGetFloat8(arg);
+            }
+            else if (type == FLOAT4OID)
+            {
+                agtv_result.val.float_value = (float8) DatumGetFloat4(arg);
+            }
+        }
+        else if (type == NUMERICOID)
+        {
+            agtv_result.type = AGTV_NUMERIC;
+            agtv_result.val.numeric = DatumGetNumeric(arg);
+        }
+        else if (type == CSTRINGOID)
+        {
+            agtv_result.type = AGTV_STRING;
+            agtv_result.val.string.val = DatumGetCString(arg);
+            agtv_result.val.string.len = strlen(agtv_result.val.string.val);
+        }
+        else if (type == TEXTOID)
+        {
+            agtv_result.type = AGTV_STRING;
+            agtv_result.val.string.val = text_to_cstring(DatumGetTextPP(arg));
+            agtv_result.val.string.len = strlen(agtv_result.val.string.val);
+        }
+        else
+        {
+            ereport(ERROR,
+                    (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+                     errmsg("agtype_volatile_wrapper: unsupported arg type")));
+        }
+
+        /* return the built result */
+        PG_RETURN_POINTER(agtype_value_to_agtype(&agtv_result));
+    }
+
+    /* otherwise, it is already an AGTYPE; hand the datum back unchanged */
+    PG_RETURN_DATUM(PG_GETARG_DATUM(0));
+}

Reply via email to