You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@age.apache.org by de...@apache.org on 2022/12/14 17:23:05 UTC

[age] branch PG12 updated: Modify the python driver's parameterization

This is an automated email from the ASF dual-hosted git repository.

dehowef pushed a commit to branch PG12
in repository https://gitbox.apache.org/repos/asf/age.git


The following commit(s) were added to refs/heads/PG12 by this push:
     new 74a415a  Modify the python driver's parameterization
74a415a is described below

commit 74a415ad4330e51a51f7a1a3356475152acecbc4
Author: Dehowe Feng <de...@gmail.com>
AuthorDate: Thu Dec 15 01:15:27 2022 +0800

    Modify the python driver's parameterization
    
    Modified the python driver to pass parameters to the cypher() function
    via age_prepare_cypher()
    
    Modified README for the python driver
---
 drivers/python/README.md         |   2 +-
 drivers/python/age/age.py        |  76 +++++++++++++--------
 drivers/python/age/exceptions.py | 140 +++++++++++++++++++--------------------
 drivers/python/test_age_py.py    |   8 +--
 4 files changed, 124 insertions(+), 102 deletions(-)

diff --git a/drivers/python/README.md b/drivers/python/README.md
index d01b885..5d97cf9 100644
--- a/drivers/python/README.md
+++ b/drivers/python/README.md
@@ -12,7 +12,7 @@ AGType parser and driver support for [Apache AGE](https://age.apache.org/), grap
 sudo apt-get update
 sudo apt-get install python3-dev libpq-dev
 pip install --no-binary :all: psycopg2
-pip install antlr4-python3-runtime
+pip install antlr4-python3-runtime==4.9.2
 
 ```
 ### Test
diff --git a/drivers/python/age/age.py b/drivers/python/age/age.py
index 2c59610..e7cc88d 100644
--- a/drivers/python/age/age.py
+++ b/drivers/python/age/age.py
@@ -17,6 +17,7 @@ import re
 import psycopg2 
 from psycopg2 import errors
 from psycopg2 import extensions as ext
+from psycopg2 import sql
 from .exceptions import *
 from .builder import ResultHandler , parseAgeValue, newResultHandler
 
@@ -47,17 +48,17 @@ def setUpAge(conn:ext.connection, graphName:str):
 # Create the graph, if it does not exist
 def checkGraphCreated(conn:ext.connection, graphName:str):
     with conn.cursor() as cursor:
-        cursor.execute("SELECT count(*) FROM ag_graph WHERE name=%s", (graphName,))
+        cursor.execute(sql.SQL("SELECT count(*) FROM ag_graph WHERE name={graphName}").format(graphName=sql.Literal(graphName)))
         if cursor.fetchone()[0] == 0:
-            cursor.execute("SELECT create_graph(%s);", (graphName,))
+            cursor.execute(sql.SQL("SELECT create_graph({graphName});").format(graphName=sql.Literal(graphName)))
             conn.commit()
 
 
 def deleteGraph(conn:ext.connection, graphName:str):
     with conn.cursor() as cursor:
-        cursor.execute("SELECT drop_graph(%s, true);", (graphName,))
+        cursor.execute(sql.SQL("SELECT drop_graph({graphName}, true);").format(graphName=sql.Literal(graphName)))
         conn.commit()
-    
+
 
 def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
     if graphName == None:
@@ -76,11 +77,7 @@ def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
         columnExp.append('v agtype')
 
     stmtArr = []
-    stmtArr.append("SELECT * from cypher('")
-    stmtArr.append(graphName)
-    stmtArr.append("', $$ ")
-    stmtArr.append(cypherStmt)
-    stmtArr.append(" $$) as (")
+    stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
     stmtArr.append(','.join(columnExp))
     stmtArr.append(");")
     return "".join(stmtArr)
@@ -94,44 +91,72 @@ def execSql(conn:ext.connection, stmt:str, commit:bool=False, params:tuple=None)
         cursor.execute(stmt, params)
         if commit:
             conn.commit()
-        
+
         return cursor
     except SyntaxError as cause:
         conn.rollback()
         raise cause
     except Exception as cause:
         conn.rollback()
-        raise SqlExcutionError("Excution ERR[" + str(cause) +"](" + stmt +")", cause)
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)
 
 
 def querySql(conn:ext.connection, stmt:str, params:tuple=None) -> ext.cursor :
     return execSql(conn, stmt, False, params)
 
 # Execute cypher statement and return cursor.
-# If cypher statement changes data (create, set, remove), 
-# You must commit session(ag.commit()) 
+# If cypher statement changes data (create, set, remove),
+# You must commit session(ag.commit())
 # (Otherwise the execution cannot make any effect.)
 def execCypher(conn:ext.connection, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
     if conn == None or conn.closed:
         raise _EXCEPTION_NoConnection
 
-    stmt = buildCypher(graphName, cypherStmt, cols)
-    
+    cursor = conn.cursor()
+    #clean up the string for mogrification
+    cypherStmt = cypherStmt.replace("\n", "")
+    cypherStmt = cypherStmt.replace("\t", "")
+    cypher = str(cursor.mogrify(cypherStmt, params))
+    cypher = cypher[2:len(cypher)-1]
+
+    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
+
     cursor = conn.cursor()
     try:
-        cursor.execute(stmt, params)
+        cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+    except SyntaxError as cause:
+        conn.rollback()
+        raise cause
+    except Exception as cause:
+        conn.rollback()
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + preparedStmt +")", cause)
+
+    stmt = buildCypher(graphName, cypher, cols)
+
+    cursor = conn.cursor()
+    try:
+        cursor.execute(stmt)
         return cursor
     except SyntaxError as cause:
         conn.rollback()
         raise cause
     except Exception as cause:
         conn.rollback()
-        raise SqlExcutionError("Excution ERR[" + str(cause) +"](" + stmt +")", cause)
+        raise SqlExecutionError("Execution ERR[" + str(cause) +"](" + stmt +")", cause)
 
 
 def cypher(cursor:ext.cursor, graphName:str, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
-    stmt = buildCypher(graphName, cypherStmt, cols)
-    cursor.execute(stmt, params)
+    #clean up the string for mogrification
+    cypherStmt = cypherStmt.replace("\n", "")
+    cypherStmt = cypherStmt.replace("\t", "")
+    cypher = str(cursor.mogrify(cypherStmt, params))
+    cypher = cypher[2:len(cypher)-1]
+
+    preparedStmt = "SELECT * FROM age_prepare_cypher({graphName},{cypherStmt})"
+    cursor.execute(sql.SQL(preparedStmt).format(graphName=sql.Literal(graphName),cypherStmt=sql.Literal(cypher)))
+
+    stmt = buildCypher(graphName, cypher, cols)
+    cursor.execute(stmt)
 
 
 # def execCypherWithReturn(conn:ext.connection, graphName:str, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
@@ -165,10 +190,10 @@ class Age:
 
     def commit(self):
         self.connection.commit()
-        
+
     def rollback(self):
         self.connection.rollback()
-    
+
     def execCypher(self, cypherStmt:str, cols:list=None, params:tuple=None) -> ext.cursor :
         return execCypher(self.connection, self.graphName, cypherStmt, cols=cols, params=params)
 
@@ -177,8 +202,8 @@ class Age:
 
     # def execSql(self, stmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
     #     return execSql(self.connection, stmt, commit, params)
-        
-    
+
+
     # def execCypher(self, cypherStmt:str, commit:bool=False, params:tuple=None) -> ext.cursor :
     #     return execCypher(self.connection, self.graphName, cypherStmt, commit, params)
 
@@ -186,7 +211,4 @@ class Age:
     #     return execCypherWithReturn(self.connection, self.graphName, cypherStmt, columns, params)
 
     # def queryCypher(self, cypherStmt:str, columns:list=None , params:tuple=None) -> ext.cursor :
-    #     return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)
-
-
-
+    #     return queryCypher(self.connection, self.graphName, cypherStmt, columns, params)
\ No newline at end of file
diff --git a/drivers/python/age/exceptions.py b/drivers/python/age/exceptions.py
index c023b53..c29a719 100644
--- a/drivers/python/age/exceptions.py
+++ b/drivers/python/age/exceptions.py
@@ -1,70 +1,70 @@
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from psycopg2.errors import *
-
-class AgeNotSet(Exception):
-    def __init__(self, name):
-        self.name = name
-
-    def __repr__(self) :
-        return 'AGE extension is not set.' 
-
-class GraphNotFound(Exception):
-    def __init__(self, name):
-        self.name = name
-
-    def __repr__(self) :
-        return 'Graph[' + self.name + '] does not exist.' 
-
-
-class GraphAlreadyExists(Exception):
-    def __init__(self, name):
-        self.name = name
-
-    def __repr__(self) :
-        return 'Graph[' + self.name + '] already exists.' 
-
-        
-class GraphNotSet(Exception):
-    def __repr__(self) :
-        return 'Graph name is not set.'
-
-
-class NoConnection(Exception):
-    def __repr__(self) :
-        return 'No Connection'
-
-class NoCursor(Exception):
-    def __repr__(self) :
-        return 'No Cursor'
-
-class SqlExcutionError(Exception):
-    def __init__(self, msg, cause):
-        self.msg = msg
-        self.cause = cause
-        super().__init__(msg, cause)
-    
-    def __repr__(self) :
-        return 'SqlExcution [' + self.msg + ']'  
-
-class AGTypeError(Exception):
-    def __init__(self, msg, cause):
-        self.msg = msg
-        self.cause = cause
-        super().__init__(msg, cause)
-
-
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from psycopg2.errors import *
+
+class AgeNotSet(Exception):
+    def __init__(self, name):
+        self.name = name
+
+    def __repr__(self) :
+        return 'AGE extension is not set.'
+
+class GraphNotFound(Exception):
+    def __init__(self, name):
+        self.name = name
+
+    def __repr__(self) :
+        return 'Graph[' + self.name + '] does not exist.'
+
+
+class GraphAlreadyExists(Exception):
+    def __init__(self, name):
+        self.name = name
+
+    def __repr__(self) :
+        return 'Graph[' + self.name + '] already exists.'
+
+
+class GraphNotSet(Exception):
+    def __repr__(self) :
+        return 'Graph name is not set.'
+
+
+class NoConnection(Exception):
+    def __repr__(self) :
+        return 'No Connection'
+
+class NoCursor(Exception):
+    def __repr__(self) :
+        return 'No Cursor'
+
+class SqlExecutionError(Exception):
+    def __init__(self, msg, cause):
+        self.msg = msg
+        self.cause = cause
+        super().__init__(msg, cause)
+
+    def __repr__(self) :
+        return 'SqlExecution [' + self.msg + ']'
+
+class AGTypeError(Exception):
+    def __init__(self, msg, cause):
+        self.msg = msg
+        self.cause = cause
+        super().__init__(msg, cause)
+
+
diff --git a/drivers/python/test_age_py.py b/drivers/python/test_age_py.py
index 1758989..ddb398a 100644
--- a/drivers/python/test_age_py.py
+++ b/drivers/python/test_age_py.py
@@ -18,11 +18,11 @@ import unittest
 import decimal
 import age 
 
-DSN = "host=172.17.0.2 port=5432 dbname=postgres user=postgres password=agens"
-TEST_HOST = "172.17.0.2"
+DSN = "host=127.0.0.1 port=5432 dbname=postgres user=dehowefeng password=agens"
+TEST_HOST = "127.0.0.1"
 TEST_PORT = 5432
 TEST_DB = "postgres"
-TEST_USER = "postgres"
+TEST_USER = "dehowefeng"
 TEST_PASSWORD = "agens"
 TEST_GRAPH_NAME = "test_graph"
 
@@ -281,4 +281,4 @@ class TestAgeBasic(unittest.TestCase):
             self.assertEqual(3,len(collected))
 
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()