You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@age.apache.org by de...@apache.org on 2022/11/24 00:01:29 UTC

[age] branch master updated: Modify the python driver's parameterization

This is an automated email from the ASF dual-hosted git repository.

dehowef pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/age.git


The following commit(s) were added to refs/heads/master by this push:
     new ed44f91  Modify the python driver's parameterization
ed44f91 is described below

commit ed44f91cb2607605ec2e225248e9ca46e146c215
Author: Dehowe Feng <de...@gmail.com>
AuthorDate: Wed Nov 23 15:33:44 2022 -0800

    Modify the python driver's parameterization
    
    Modified the Python driver to pass parameters to the cypher() function
    via age_prepare_cypher(). Also renamed the SqlExcutionError exception
    to SqlExecutionError (spelling fix).
---
 drivers/python/age/age.py        | 53 +++++++++++++++++++++++++++++-----------
 drivers/python/age/exceptions.py |  2 +-
 2 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py
index 2c59610..7ba4a27 100644
--- a/drivers/python/age/age.py
+++ b/drivers/python/age/age.py
@@ -17,6 +17,7 @@ import re
 import psycopg2 
 from psycopg2 import errors
 from psycopg2 import extensions as ext
+from psycopg2 import sql
 from .exceptions import *
 from .builder import ResultHandler , parseAgeValue, newResultHandler
 
@@ -47,15 +48,15 @@ def setUpAge(conn:ext.connection, graphName:str):
 # Create the graph, if it does not exist
 def checkGraphCreated(conn:ext.connection, graphName:str):
     with conn.cursor() as cursor:
-        cursor.execute("SELECT count(*) FROM ag_graph WHERE name=%s", (graphName,))
+        cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName)))
         if cursor.fetchone()[0] == 0:
-            cursor.execute("SELECT create_graph(%s);", (graphName,))
+            cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName)))
             conn.commit()
 
 
 def deleteGraph(conn:ext.connection, graphName:str):
     with conn.cursor() as cursor:
-        cursor.execute("SELECT drop_graph(%s, true);", (graphName,))
+        cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName)))
         conn.commit()
     
 
@@ -76,11 +77,7 @@ def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
         columnExp.append('v agtype')
 
     stmtArr = []
-    stmtArr.append("SELECT * from cypher('")
-    stmtArr.append(graphName)
-    stmtArr.append("', $$ ")
-    stmtArr.append(cypherStmt)
-    stmtArr.append(" $$) as (")
+    stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
     stmtArr.append(','.join(columnExp))
     stmtArr.append(");")
     return "".join(stmtArr)
@@ -101,7 +98,7 @@ def execSql(conn:ext.connection, stmt:str, commit:bool=False, params:tuple=None)
         raise cause
     except Exception as cause:
         conn.rollback()
-        raise SqlExcutionError("Excution ERR[" + str(cause) +"](" + stmt +")", cause)
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)
 
 
 def querySql(conn:ext.connection, stmt:str, params:tuple=None) -> ext.cursor :
@@ -115,23 +112,51 @@ def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=Non
     if conn == None or conn.closed:
         raise _EXCEPTION_NoConnection
 
-    stmt = buildCypher(graphName, cypherStmt, cols)
+    cursor = conn.cursor()
+    #clean up the string for mogrificiation
+    cypherStmt = cypherStmt.replace("\n", "")
+    cypherStmt = cypherStmt.replace("\t", "")
+    cypher = str(cursor.mogrify(cypherStmt, params))
+    cypher = cypher[2:len(cypher)-1]
+
+    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
     
     cursor = conn.cursor()
     try:
-        cursor.execute(stmt, params)
+        cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+    except SyntaxError as cause:
+        conn.rollback()
+        raise cause
+    except Exception as cause:
+        conn.rollback()
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + preparedStmt +")", cause)
+
+    stmt = buildCypher(graphName, cypher, cols)
+
+    cursor = conn.cursor()
+    try:
+        cursor.execute(stmt)
         return cursor
     except SyntaxError as cause:
         conn.rollback()
         raise cause
     except Exception as cause:
         conn.rollback()
-        raise SqlExcutionError("Excution ERR[" + str(cause) +"](" + stmt +")", cause)
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)
 
 
 def cypher(cursor:ext.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
-    stmt = buildCypher(graphName, cypherStmt, cols)
-    cursor.execute(stmt, params)
+    #clean up the string for mogrificiation
+    cypherStmt = cypherStmt.replace("\n", "")
+    cypherStmt = cypherStmt.replace("\t", "")
+    cypher = str(cursor.mogrify(cypherStmt, params))
+    cypher = cypher[2:len(cypher)-1]
+
+    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
+    cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+
+    stmt = buildCypher(graphName, cypher, cols)
+    cursor.execute(stmt)
 
 
 # def execCypherWithReturn(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
diff --git a/drivers/python/age/exceptions.py b/drivers/python/age/exceptions.py
index c023b53..7bbb5b4 100644
--- a/drivers/python/age/exceptions.py
+++ b/drivers/python/age/exceptions.py
@@ -52,7 +52,7 @@ class NoCursor(Exception):
     def __repr__(self) :
         return 'No Cursor'
 
-class SqlExcutionError(Exception):
+class SqlExecutionError(Exception):
     def __init__(self, msg, cause):
         self.msg = msg
         self.cause = cause