You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hawq.apache.org by rv...@apache.org on 2015/09/22 21:14:15 UTC
[12/35] incubator-hawq git commit: SGA import. Now with files
previously missing because of the .gitignore issue
http://git-wip-us.apache.org/repos/asf/incubator-hawq/blob/a485be47/tools/bin/ext/pg8000/interface.py
----------------------------------------------------------------------
diff --git a/tools/bin/ext/pg8000/interface.py b/tools/bin/ext/pg8000/interface.py
new file mode 100644
index 0000000..d2f70fa
--- /dev/null
+++ b/tools/bin/ext/pg8000/interface.py
@@ -0,0 +1,542 @@
+# vim: sw=4:expandtab:foldmethod=marker
+#
+# Copyright (c) 2007-2009, Mathieu Fenniak
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * The name of the author may not be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+__author__ = "Mathieu Fenniak"
+
+import socket
+import protocol
+import threading
+from errors import *
+
##
# Iterator adapter that repeatedly calls func(obj) and yields each result
# until func returns None.
# <p>
# @param obj The object passed to func on every step.
# @param func A callable taking obj and returning the next value, or None
# when exhausted. Falsy values other than None (0, "", ()) are yielded.
class DataIterator(object):
    def __init__(self, obj, func):
        self.obj = obj
        self.func = func

    def __iter__(self):
        return self

    def next(self):
        retval = self.func(self.obj)
        # Only None ends iteration; identity test avoids invoking __eq__.
        if retval is None:
            raise StopIteration()
        return retval

    # Python 3 spells the iterator-protocol method __next__; alias it so the
    # same object works with for-loops and list() on either interpreter.
    __next__ = next
+
# Process-wide counter used to give every PreparedStatement a unique
# server-side statement/portal name. statement_number_lock guards the
# read-and-increment done in PreparedStatement.__init__.
statement_number = 0
statement_number_lock = threading.Lock()
+
##
# This class represents a prepared statement. A prepared statement is
# pre-parsed on the server, which reduces the need to parse the query every
# time it is run. The statement can have parameters in the form of $1, $2, $3,
# etc. When parameters are used, the types of the parameters need to be
# specified when creating the prepared statement.
# <p>
# As of v1.01, instances of this class are thread-safe. This means that a
# single PreparedStatement can be accessed by multiple threads without the
# internal consistency of the statement being altered. However, the
# responsibility is on the client application to ensure that one thread reading
# from a statement isn't affected by another thread starting a new query with
# the same statement.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
#
# @param statement The SQL statement to be represented, often containing
# parameters in the form of $1, $2, $3, etc.
#
# @param types Python type objects for each parameter in the SQL
# statement. For example, int, float, str.
#
# @keyparam statement_name Server-side name of the statement. Defaults to a
# unique "pg8000_statement_N"; an empty string selects the unnamed statement,
# which close() then leaves alone.
class PreparedStatement(object):

    ##
    # Determines the number of rows to read from the database server at once.
    # Reading more rows increases performance at the cost of memory. The
    # default value is 100 rows. The effect of this parameter is transparent.
    # That is, the library reads more rows when the cache is empty
    # automatically.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx. It is
    # possible that implementation changes in the future could cause this
    # parameter to be ignored.
    row_cache_size = 100

    def __init__(self, connection, statement, *types, **kwargs):
        global statement_number
        if connection is None or connection.c is None:
            raise InterfaceError("connection not provided")
        # The number names both the statement and its portal, so each
        # instance must get a distinct one even under concurrent creation.
        statement_number_lock.acquire()
        try:
            self._statement_number = statement_number
            statement_number += 1
        finally:
            statement_number_lock.release()
        self.c = connection.c
        self._portal_name = None
        self._statement_name = kwargs.get("statement_name", "pg8000_statement_%s" % self._statement_number)
        self._row_desc = None
        self._cached_rows = []
        self._ongoing_row_count = 0
        self._command_complete = True
        self._parse_row_desc = self.c.parse(self._statement_name, statement, types)
        self._lock = threading.RLock()

    ##
    # Close the server-side statement (unless it is the unnamed one) and any
    # portal left open by the last execute().
    def close(self):
        self._lock.acquire()
        try:
            if self._statement_name != "":  # don't close unnamed statement
                self.c.close_statement(self._statement_name)
            if self._portal_name is not None:
                self.c.close_portal(self._portal_name)
                self._portal_name = None
        finally:
            self._lock.release()

    row_description = property(lambda self: self._getRowDescription())
    def _getRowDescription(self):
        if self._row_desc is None:
            return None
        return self._row_desc.fields

    ##
    # Run the SQL prepared statement with the given parameters.
    # <p>
    # Any rows or row count left over from a previous execution are discarded.
    # The "stream" keyword argument, if given, is forwarded to the protocol
    # layer's bind() (presumably a file object for COPY -- verify there).
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def execute(self, *args, **kwargs):
        self._lock.acquire()
        try:
            # Always start from a clean slate. Unread rows and the running
            # row counter belong to the previous result set even when that
            # execution ran to completion; keeping them made _fill_cache
            # refuse to run and inflated row_count.
            self._cached_rows = []
            self._ongoing_row_count = 0
            if self._portal_name is not None:
                # The portal name is reused below, so release the old portal
                # first. PostgreSQL accepts Close on an already-gone portal.
                self.c.close_portal(self._portal_name)
            self._command_complete = False
            self._portal_name = "pg8000_portal_%s" % self._statement_number
            self._row_desc, cmd = self.c.bind(self._portal_name, self._statement_name, args, self._parse_row_desc, kwargs.get("stream"))
            if self._row_desc:
                # We execute our cursor right away to fill up our cache. This
                # prevents the cursor from being destroyed, apparently, by a rogue
                # Sync between Bind and Execute. Since it is quite likely that
                # data will be read from us right away anyways, this seems a safe
                # move for now.
                self._fill_cache()
            else:
                # No result set: the command already finished during bind.
                self._command_complete = True
                self._ongoing_row_count = -1
                if cmd is not None and cmd.rows is not None:
                    self._ongoing_row_count = cmd.rows
        finally:
            self._lock.release()

    # Fetch up to row_cache_size rows into the (empty) cache and record
    # whether the portal has been exhausted.
    def _fill_cache(self):
        self._lock.acquire()
        try:
            if self._cached_rows:
                raise InternalError("attempt to fill cache that isn't empty")
            end_of_data, rows = self.c.fetch_rows(self._portal_name, self.row_cache_size, self._row_desc)
            self._cached_rows = rows
            if end_of_data:
                self._command_complete = True
        finally:
            self._lock.release()

    # Return the next row as a tuple, refilling the cache as needed, or None
    # once the result set is exhausted.
    def _fetch(self):
        self._lock.acquire()
        try:
            if not self._row_desc:
                raise ProgrammingError("no result set")
            # Loop rather than fill once so that an empty batch from an
            # unfinished portal fetches again instead of popping from [].
            while not self._cached_rows:
                if self._command_complete:
                    return None
                self._fill_cache()
            row = self._cached_rows.pop(0)
            self._ongoing_row_count += 1
            return tuple(row)
        finally:
            self._lock.release()

    ##
    # Return a count of the number of rows relevant to the executed statement.
    # For a SELECT, this is the number of rows returned. For UPDATE or DELETE,
    # this is the number of rows affected. For INSERT, the number of rows
    # inserted. This property may have a value of -1 to indicate that there
    # was no row count.
    # <p>
    # During a result-set query (eg. SELECT, or INSERT ... RETURNING ...),
    # accessing this property requires reading the entire result-set into
    # memory, as reading the data to completion is the only way to determine
    # the total number of rows. Avoid using this property with result-set
    # queries, as it may cause unexpected memory usage.
    # <p>
    # Stability: Added in v1.03, stability guaranteed for v1.xx.
    row_count = property(lambda self: self._get_row_count())
    def _get_row_count(self):
        self._lock.acquire()
        try:
            if not self._command_complete:
                # A row limit of 0 asks for all remaining rows at once.
                end_of_data, rows = self.c.fetch_rows(self._portal_name, 0, self._row_desc)
                self._cached_rows += rows
                if end_of_data:
                    self._command_complete = True
                else:
                    raise InternalError("fetch_rows(0) did not hit end of data")
            return self._ongoing_row_count + len(self._cached_rows)
        finally:
            self._lock.release()

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias. This method will raise an error if two
    # columns have the same name. Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_dict(self):
        row = self._fetch()
        if row is None:
            return row
        retval = {}
        for i, field in enumerate(self._row_desc.fields):
            col_name = field['name']
            if col_name in retval:
                raise InterfaceError("cannot return dict of row when two columns have the same name (%r)" % (col_name,))
            retval[col_name] = row[i]
        return retval

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def read_tuple(self):
        return self._fetch()

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_tuple(self):
        return DataIterator(self, PreparedStatement.read_tuple)

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def iterate_dict(self):
        return DataIterator(self, PreparedStatement.read_dict)
+
##
# The Cursor class allows multiple queries to be performed concurrently with a
# single PostgreSQL connection. The Cursor object is implemented internally by
# using a {@link PreparedStatement PreparedStatement} object, so if you plan to
# use a statement multiple times, you might as well create a PreparedStatement
# and save a small amount of reparsing time.
# <p>
# As of v1.01, instances of this class are thread-safe. See {@link
# PreparedStatement PreparedStatement} for more information.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param connection An instance of {@link Connection Connection}.
class Cursor(object):
    def __init__(self, connection):
        self.connection = connection
        self._stmt = None

    # Class-body decorator: reject calls made before execute() has created a
    # statement. functools.wraps keeps the wrapped method's name and docs.
    def require_stmt(func):
        import functools

        @functools.wraps(func)
        def retval(self, *args, **kwargs):
            if self._stmt is None:
                raise ProgrammingError("attempting to use unexecuted cursor")
            return func(self, *args, **kwargs)
        return retval

    row_description = property(lambda self: self._getRowDescription())
    def _getRowDescription(self):
        if self._stmt is None:
            return None
        return self._stmt.row_description

    ##
    # Run an SQL statement using this cursor. The SQL statement can have
    # parameters in the form of $1, $2, $3, etc., which will be filled in by
    # the additional arguments passed to this function. Keyword arguments are
    # forwarded to {@link PreparedStatement#execute PreparedStatement.execute}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    # @param query The SQL statement to execute.
    def execute(self, query, *args, **kwargs):
        if self.connection.is_closed:
            raise ConnectionClosedError()
        # All cursors on a connection share the server's single unnamed
        # statement, so parse+execute must not interleave between threads.
        self.connection._unnamed_prepared_statement_lock.acquire()
        try:
            param_types = [{"type": type(x), "value": x} for x in args]
            self._stmt = PreparedStatement(self.connection, query, statement_name="", *param_types)
            self._stmt.execute(*args, **kwargs)
        finally:
            self.connection._unnamed_prepared_statement_lock.release()

    ##
    # Return a count of the number of rows currently being read. If possible,
    # please avoid using this function. It requires reading the entire result
    # set from the database to determine the number of rows being returned.
    # <p>
    # Stability: Added in v1.03, stability guaranteed for v1.xx.
    # Implementation currently requires caching entire result set into memory,
    # avoid using this property.
    row_count = property(lambda self: self._get_row_count())

    @require_stmt
    def _get_row_count(self):
        return self._stmt.row_count

    ##
    # Read a row from the database server, and return it in a dictionary
    # indexed by column name/alias. This method will raise an error if two
    # columns have the same name. Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    @require_stmt
    def read_dict(self):
        return self._stmt.read_dict()

    ##
    # Read a row from the database server, and return it as a tuple of values.
    # Returns None after the last row.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    @require_stmt
    def read_tuple(self):
        return self._stmt.read_tuple()

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a tuple for each row, in the same manner as {@link
    # #PreparedStatement.read_tuple read_tuple}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    @require_stmt
    def iterate_tuple(self):
        return self._stmt.iterate_tuple()

    ##
    # Return an iterator for the output of this statement. The iterator will
    # return a dict for each row, in the same manner as {@link
    # #PreparedStatement.read_dict read_dict}.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    @require_stmt
    def iterate_dict(self):
        return self._stmt.iterate_dict()

    ##
    # Close the statement behind this cursor, if any. Safe to call repeatedly.
    def close(self):
        if self._stmt is not None:
            self._stmt.close()
            self._stmt = None

    ##
    # Return the fileno of the underlying socket for this cursor's connection.
    # <p>
    # Stability: Added in v1.07, stability guaranteed for v1.xx.
    def fileno(self):
        return self.connection.fileno()

    ##
    # Poll the underlying socket for this cursor and sync if there is data waiting
    # to be read. This has the effect of flushing asynchronous messages from the
    # backend. Returns True if messages were read, False otherwise.
    # <p>
    # Stability: Added in v1.07, stability guaranteed for v1.xx.
    def isready(self):
        return self.connection.isready()
+
+
##
# This class represents a connection to a PostgreSQL database.
# <p>
# The database connection is derived from the {@link #Cursor Cursor} class,
# which provides a default cursor for running queries. It also provides
# transaction control via the 'begin', 'commit', and 'rollback' methods.
# Without beginning a transaction explicitly, all statements will autocommit to
# the database.
# <p>
# As of v1.01, instances of this class are thread-safe. See {@link
# PreparedStatement PreparedStatement} for more information.
# <p>
# Stability: Added in v1.00, stability guaranteed for v1.xx.
#
# @param user The username to connect to the PostgreSQL server with. This
# parameter is required.
#
# @keyparam host The hostname of the PostgreSQL server to connect with.
# Providing this parameter is necessary for TCP/IP connections. One of either
# host, or unix_sock, must be provided.
#
# @keyparam unix_sock The path to the UNIX socket to access the database
# through, for example, '/tmp/.s.PGSQL.5432'. One of either unix_sock or host
# must be provided. The port parameter will have no effect if unix_sock is
# provided.
#
# @keyparam port The TCP/IP port of the PostgreSQL server instance. This
# parameter defaults to 5432, the registered and common port of PostgreSQL
# TCP/IP servers.
#
# @keyparam database The name of the database instance to connect with. This
# parameter is optional, if omitted the PostgreSQL server will assume the
# database name is the same as the username.
#
# @keyparam password The user password to connect to the server with. This
# parameter is optional. If omitted, and the database server requests password
# based authentication, the connection will fail. On the other hand, if this
# parameter is provided and the database does not request password
# authentication, then the password will not be used.
#
# @keyparam socket_timeout Socket connect timeout measured in seconds.
# Defaults to 60 seconds.
#
# @keyparam ssl Use SSL encryption for TCP/IP socket. Defaults to False.
#
# @keyparam options Optional value forwarded to the protocol layer's
# authenticate(); presumably sent as the startup "options" parameter (e.g.
# Greenplum utility mode) -- confirm in protocol.Connection.
class Connection(Cursor):
    def __init__(self, user, host=None, unix_sock=None, port=5432, database=None, password=None, socket_timeout=60, ssl=False, options=None):
        self._row_desc = None
        try:
            self.c = protocol.Connection(unix_sock=unix_sock, host=host, port=port, socket_timeout=socket_timeout, ssl=ssl)
            self.c.authenticate(user, password=password, database=database, options=options)
        except socket.error as e:
            # Report low-level socket failures as InterfaceError.
            # NOTE(review): if authenticate() fails, the socket opened by
            # protocol.Connection is not explicitly closed here.
            raise InterfaceError("communication error", e)
        # The connection is its own default cursor.
        Cursor.__init__(self, self)
        # Transaction-control statements are prepared once and reused.
        self._begin = PreparedStatement(self, "BEGIN TRANSACTION")
        self._commit = PreparedStatement(self, "COMMIT TRANSACTION")
        self._rollback = PreparedStatement(self, "ROLLBACK TRANSACTION")
        # Serializes use of the server's unnamed prepared statement by
        # Cursor.execute across all cursors of this connection.
        self._unnamed_prepared_statement_lock = threading.RLock()

    ##
    # An event handler that is fired when NOTIFY occurs for a notification that
    # has been LISTEN'd for. The value of this property is a
    # util.MulticastDelegate. A callback can be added by using
    # connection.NotificationReceived += SomeMethod. The method will be called
    # with a single argument, an object that has properties: backend_pid,
    # condition, and additional_info. Callbacks can be removed with the -=
    # operator.
    # <p>
    # Stability: Added in v1.03, stability guaranteed for v1.xx.
    NotificationReceived = property(
        lambda self: getattr(self.c, "NotificationReceived"),
        lambda self, value: setattr(self.c, "NotificationReceived", value)
    )

    ##
    # An event handler that is fired when the database server issues a notice.
    # The value of this property is a util.MulticastDelegate. A callback can
    # be added by using connection.NoticeReceived += SomeMethod. The
    # method will be called with a single argument, an object that has
    # properties: severity, code, msg, and possibly others (detail, hint,
    # position, where, file, line, and routine). Callbacks can be removed with
    # the -= operator.
    # <p>
    # Stability: Added in v1.03, stability guaranteed for v1.xx.
    NoticeReceived = property(
        lambda self: getattr(self.c, "NoticeReceived"),
        lambda self, value: setattr(self.c, "NoticeReceived", value)
    )

    ##
    # An event handler that is fired when a runtime configuration option is
    # changed on the server. The value of this property is a
    # util.MulticastDelegate. A callback can be added by using
    # connection.ParameterStatusReceived += SomeMethod. Callbacks can be
    # removed with the -= operator. The method will be called with a single
    # argument, an object that has properties "key" and "value".
    # <p>
    # Stability: Added in v1.03, stability guaranteed for v1.xx.
    ParameterStatusReceived = property(
        lambda self: getattr(self.c, "ParameterStatusReceived"),
        lambda self, value: setattr(self.c, "ParameterStatusReceived", value)
    )

    ##
    # Begins a new transaction.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def begin(self):
        if self.is_closed:
            raise ConnectionClosedError()
        self._begin.execute()

    ##
    # Commits the running transaction.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def commit(self):
        if self.is_closed:
            raise ConnectionClosedError()
        self._commit.execute()

    ##
    # Rolls back the running transaction.
    # <p>
    # Stability: Added in v1.00, stability guaranteed for v1.xx.
    def rollback(self):
        if self.is_closed:
            raise ConnectionClosedError()
        self._rollback.execute()

    ##
    # Closes an open connection. Raises ConnectionClosedError if it is
    # already closed.
    def close(self):
        if self.is_closed:
            raise ConnectionClosedError()
        self.c.close()
        self.c = None

    is_closed = property(lambda self: self.c is None)

    ##
    # Delegates to the protocol layer's _cache_record_attnames(); presumably
    # refreshes cached attribute names of record (composite) types -- verify
    # in protocol.Connection.
    def recache_record_types(self):
        self.c._cache_record_attnames()

    ##
    # Return the fileno of the underlying socket for this connection.
    # Raises ConnectionClosedError once the connection has been closed.
    # <p>
    # Stability: Added in v1.07, stability guaranteed for v1.xx.
    def fileno(self):
        if self.is_closed:
            raise ConnectionClosedError()
        return self.c.fileno()

    ##
    # Poll the underlying socket for this connection and sync if there is data
    # waiting to be read. This has the effect of flushing asynchronous
    # messages from the backend. Returns True if messages were read, False
    # otherwise. Raises ConnectionClosedError once the connection is closed.
    # <p>
    # Stability: Added in v1.07, stability guaranteed for v1.xx.
    def isready(self):
        if self.is_closed:
            raise ConnectionClosedError()
        return self.c.isready()
+
http://git-wip-us.apache.org/repos/asf/incubator-hawq/blob/a485be47/tools/bin/ext/pg8000/protocol.py
----------------------------------------------------------------------
diff --git a/tools/bin/ext/pg8000/protocol.py b/tools/bin/ext/pg8000/protocol.py
new file mode 100644
index 0000000..377074b
--- /dev/null
+++ b/tools/bin/ext/pg8000/protocol.py
@@ -0,0 +1,1340 @@
+# vim: sw=4:expandtab:foldmethod=marker
+#
+# Copyright (c) 2007-2009, Mathieu Fenniak
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# * The name of the author may not be used to endorse or promote products
+# derived from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+__author__ = "Mathieu Fenniak"
+
+import socket
+import select
+import threading
+import struct
+import hashlib
+from cStringIO import StringIO
+
+from errors import *
+from util import MulticastDelegate
+import types
+
##
# An SSLRequest message, sent instead of a {@link StartupMessage
# StartupMessage} to open an SSL-encrypted session. If the server accepts, the
# StartupMessage is sent afterwards over the negotiated SSL channel.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class SSLRequest(object):
    def __init__(self):
        # The message has no variable content.
        pass

    # Wire layout:
    #   Int32(8)        - total length, counting itself.
    #   Int32(80877103) - the SSL request code.
    def serialize(self):
        length = 8
        ssl_code = 80877103
        return struct.pack("!ii", length, ssl_code)
+
+
##
# A StartupMessage message. Opens a DB session, naming the user to
# authenticate as and (optionally) the database to connect to.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class StartupMessage(object):
    # options is sent verbatim as the "options" startup parameter (noted in
    # the original source as being used for Greenplum utility mode).
    def __init__(self, user, database=None, options=None):
        self.user = user
        self.database = database
        self.options = options

    # Wire layout:
    #   Int32         - total length, counting itself.
    #   Int32(196608) - protocol version 3.0.
    #   Key/value string pairs (user, then database/options when non-empty),
    #   each NUL-terminated, followed by a final zero byte.
    def serialize(self):
        params = [("user", self.user)]
        if self.database:
            params.append(("database", self.database))
        if self.options:
            params.append(("options", self.options))
        body = struct.pack("!i", 196608)
        for key, value in params:
            body += key + "\x00" + value + "\x00"
        body += "\x00"
        return struct.pack("!i", len(body) + 4) + body
+
+
##
# Parse message. Asks the server to create a prepared statement in the
# current DB session.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param ps Name of the prepared statement to create ("" = unnamed).
# @param qs Query string.
# @param type_oids A sized iterable of PostgreSQL type OIDs, one per query
# parameter. An OID of -1 (used elsewhere for NULL) is sent as 705.
class Parse(object):
    def __init__(self, ps, qs, type_oids):
        self.ps = ps
        self.qs = qs
        self.type_oids = type_oids

    def __repr__(self):
        return "<Parse ps=%r qs=%r>" % (self.ps, self.qs)

    # Wire layout:
    #   Byte1('P'), Int32 length (counting itself),
    #   String statement name, String query,
    #   Int16 parameter type count, then Int32 OID per parameter.
    def serialize(self):
        body = self.ps + "\x00" + self.qs + "\x00"
        body += struct.pack("!h", len(self.type_oids))
        for oid in self.type_oids:
            # Parse does not accept -1 for NULL parameters the way other
            # messages do, so substitute 705, PostgreSQL's "unknown" type.
            body += struct.pack("!i", 705 if oid == -1 else oid)
        return "P" + struct.pack("!i", len(body) + 4) + body
+
+
##
# Bind message. Creates a portal from a prepared statement and its
# parameter values, ready for execution.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param portal Name of the destination portal.
# @param ps Name of the source prepared statement.
# @param in_fc Format codes for input parameters (0 = Text, 1 = Binary):
# empty means all text, one entry applies to every parameter, otherwise one
# entry per parameter.
# @param params The parameter values, converted with types.pg_value.
# @param out_fc Format codes for result columns (0 = Text, 1 = Binary).
# @param kwargs Extra arguments passed through to types.pg_value.
class Bind(object):
    def __init__(self, portal, ps, in_fc, params, out_fc, **kwargs):
        self.portal = portal
        self.ps = ps
        self.in_fc = in_fc
        self.params = []
        for idx, value in enumerate(params):
            code_count = len(self.in_fc)
            if code_count == 0:
                fc = 0
            elif code_count == 1:
                fc = self.in_fc[0]
            else:
                fc = self.in_fc[idx]
            # A None result from pg_value is serialized as SQL NULL.
            self.params.append(types.pg_value(value, fc, **kwargs))
        self.out_fc = out_fc

    def __repr__(self):
        return "<Bind p=%r s=%r>" % (self.portal, self.ps)

    # Wire layout:
    #   Byte1('B'), Int32 length (counting itself),
    #   String portal, String statement,
    #   Int16 count + Int16 per input format code,
    #   Int16 count + (Int32 length, bytes) per value; length -1 = NULL,
    #   Int16 count + Int16 per result-column format code.
    def serialize(self):
        buf = StringIO()
        write = buf.write
        write(self.portal + "\x00")
        write(self.ps + "\x00")
        write(struct.pack("!h", len(self.in_fc)))
        for code in self.in_fc:
            write(struct.pack("!h", code))
        write(struct.pack("!h", len(self.params)))
        for value in self.params:
            if value is None:
                write(struct.pack("!i", -1))
                continue
            write(struct.pack("!i", len(value)))
            write(value)
        write(struct.pack("!h", len(self.out_fc)))
        for code in self.out_fc:
            write(struct.pack("!h", code))
        body = buf.getvalue()
        return "B" + struct.pack("!i", len(body) + 4) + body
+
+
##
# A Close message, releasing a prepared statement or a portal on the server.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal (exactly one char).
# @param name The name of the item to close.
class Close(object):
    def __init__(self, typ, name):
        if len(typ) == 1:
            self.typ = typ
            self.name = name
        else:
            raise InternalError("Close typ must be 1 char")

    # Wire layout:
    #   Byte1('C'), Int32 length (counting itself),
    #   Byte1 kind ('S' or 'P'), String name.
    def serialize(self):
        body = "".join((self.typ, self.name, "\x00"))
        header = struct.pack("!i", len(body) + 4)
        return "C" + header + body
+
+
##
# Close message preset for a portal (kind 'P').
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ClosePortal(Close):
    def __init__(self, name):
        super(ClosePortal, self).__init__("P", name)
+
+
##
# Close message preset for a prepared statement (kind 'S').
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ClosePreparedStatement(Close):
    def __init__(self, name):
        super(ClosePreparedStatement, self).__init__("S", name)
+
+
##
# A Describe message, requesting information about a prepared statement or a
# portal.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
#
# @param typ 'S' for prepared statement, 'P' for portal (exactly one char).
# @param name The name of the item to describe.
class Describe(object):
    def __init__(self, typ, name):
        if len(typ) == 1:
            self.typ = typ
            self.name = name
        else:
            raise InternalError("Describe typ must be 1 char")

    # Wire layout:
    #   Byte1('D'), Int32 length (counting itself),
    #   Byte1 kind ('S' or 'P'), String name of the item to describe.
    def serialize(self):
        body = "".join((self.typ, self.name, "\x00"))
        header = struct.pack("!i", len(body) + 4)
        return "D" + header + body
+
+
##
# A specialized Describe message for a portal (kind 'P').
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class DescribePortal(Describe):
    def __init__(self, name):
        Describe.__init__(self, "P", name)

    def __repr__(self):
        # Wrap in a 1-tuple: "(self.name)" is just self.name, which breaks
        # the % formatting if name itself is a tuple.
        return "<DescribePortal %r>" % (self.name,)
+
+
##
# A specialized Describe message for a prepared statement (kind 'S').
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class DescribePreparedStatement(Describe):
    def __init__(self, name):
        Describe.__init__(self, "S", name)

    def __repr__(self):
        # Wrap in a 1-tuple: "(self.name)" is just self.name, which breaks
        # the % formatting if name itself is a tuple.
        return "<DescribePreparedStatement %r>" % (self.name,)
+
+
##
# A Flush message: asks the backend to send any output it has buffered.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Flush(object):
    def __repr__(self):
        return "<Flush>"

    # Byte1('H') followed by Int32(4), the length counting only itself.
    def serialize(self):
        tag = 'H'
        length = '\x00\x00\x00\x04'
        return tag + length
+
##
# A Sync message: makes the backend close the current transaction (when not
# inside a BEGIN/COMMIT block) and reply with ReadyForQuery.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Sync(object):
    def __repr__(self):
        return "<Sync>"

    # Byte1('S') followed by Int32(4), the length counting only itself.
    def serialize(self):
        tag = 'S'
        length = '\x00\x00\x00\x04'
        return tag + length
+
+
##
# Carries a password (plain or already hashed) to the server.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class PasswordMessage(object):
    def __init__(self, pwd):
        self.pwd = pwd

    # Wire layout: Byte1('p'), Int32 length (counting itself), String password.
    def serialize(self):
        body = self.pwd + "\x00"
        header = struct.pack("!i", len(body) + 4)
        return "p" + header + body
+
+
##
# Requests that the backend execute a portal and retrieve any number of rows.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
# @param row_count Maximum number of rows to return; zero asks for all rows.
#                  A portal whose query returns no rows produces none,
#                  whatever the row_count.
class Execute(object):
    def __init__(self, portal, row_count):
        self.portal, self.row_count = portal, row_count

    # Wire format:
    #   Byte1('E') - Identifies the message as an execute message.
    #   Int32      - Message length, including self.
    #   String     - NUL-terminated name of the portal to execute.
    #   Int32      - Maximum number of rows to return (0 = no limit).
    def serialize(self):
        body = self.portal + "\x00" + struct.pack("!i", self.row_count)
        return "E" + struct.pack("!i", 4 + len(body)) + body
+
+
##
# Informs the backend that the connection is being closed.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class Terminate(object):
    # Wire format: Byte1('X') followed by Int32(4), the message length
    # including the length field itself.  No payload.
    def serialize(self):
        return "X" + "\x00\x00\x00\x04"
+
##
# Base class of all Authentication[*] messages.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationRequest(object):
    def __init__(self, data):
        # The base message keeps no state; subclasses that need the payload
        # following the authentication code override __init__.
        pass

    # Byte1('R') - Identifies the message as an authentication request.
    # Int32      - Message length, including self (MD5 requests are longer
    #              because a 4-byte salt follows the code).
    # Int32      - An authentication code that represents different
    #              authentication messages:
    #                0 = AuthenticationOk
    #                5 = MD5 pwd
    #                2 = Kerberos v5 (not supported by pg8000)
    #                3 = Cleartext pwd (not supported by pg8000)
    #                4 = crypt() pwd (not supported by pg8000)
    #                6 = SCM credential (not supported by pg8000)
    #                7 = GSSAPI (not supported by pg8000)
    #                8 = GSSAPI data (not supported by pg8000)
    #                9 = SSPI (not supported by pg8000)
    # Some authentication messages have additional data following the
    # authentication code. That data is documented in the appropriate class.
    #
    # Returns an instance of the subclass registered for the code in
    # authentication_codes; raises NotSupportedError for any other code.
    @staticmethod
    def createFromData(data):
        ident = struct.unpack("!i", data[:4])[0]
        klass = authentication_codes.get(ident)
        if klass is None:
            raise NotSupportedError("authentication method %r not supported" % (ident,))
        return klass(data[4:])

    # Subclasses perform (or acknowledge) the authentication exchange and
    # return a true value on success.
    def ok(self, conn, user, **kwargs):
        raise InternalError("ok method should be overridden on AuthenticationRequest instance")
+
##
# Sent by the backend once authentication has succeeded, either immediately
# or after a password exchange such as AuthenticationMD5Password.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationOk(AuthenticationRequest):
    def ok(self, conn, user, **kwargs):
        # Nothing left to negotiate; report success to the caller.
        return True
+
+
##
# A message representing the backend requesting an MD5 hashed password
# response. The response will be sent as md5(md5(pwd + login) + salt).
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class AuthenticationMD5Password(AuthenticationRequest):
    # Additional message data:
    #   Byte4 - Hash salt.
    def __init__(self, data):
        # unpack("4c") insists on exactly four bytes of salt.
        self.salt = "".join(struct.unpack("4c", data))

    # Sends the hashed password and returns the result of the backend's
    # follow-up Authentication message's ok().  Raises InterfaceError if
    # no password was supplied or the server rejects it (SQLSTATE 28000).
    def ok(self, conn, user, password=None, **kwargs):
        if password is None:
            raise InterfaceError("server requesting MD5 password authentication, but no password was provided")
        inner = hashlib.md5(password + user).hexdigest()
        pwd = "md5" + hashlib.md5(inner + self.salt).hexdigest()
        conn._send(PasswordMessage(pwd))
        conn._flush()

        reader = MessageReader(conn)

        # The backend answers the password with another Authentication
        # message; its ok() result becomes this method's result.
        def _on_auth_reply(msg, msg_reader):
            msg_reader.return_value(msg.ok(conn, user))

        reader.add_message(AuthenticationRequest, _on_auth_reply, reader)
        reader.add_message(ErrorResponse, self._ok_error)
        return reader.handle_messages()

    def _ok_error(self, msg):
        if msg.code == "28000":
            raise InterfaceError("md5 password authentication failed")
        else:
            raise msg.createException()
+
# Maps the Int32 authentication code of an 'R' message to its handler class.
# AuthenticationRequest.createFromData rejects any code not listed here.
authentication_codes = dict([
    (0, AuthenticationOk),
    (5, AuthenticationMD5Password),
])
+
+
##
# ParameterStatus message sent from the backend to inform the frontend of
# runtime configuration parameter changes (Connection tracks
# client_encoding and integer_datetimes from these).
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ParameterStatus(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value

    # Byte1('S') - Identifies ParameterStatus.
    # Int32      - Message length, including self.
    # String     - Runtime parameter name.
    # String     - Runtime parameter value.
    @staticmethod
    def createFromData(data):
        # Both strings are NUL-terminated: split at the first NUL and drop
        # the trailing one from the value.
        nul = data.find("\x00")
        return ParameterStatus(data[:nul], data[nul + 1:-1])
+
+
##
# BackendKeyData message sent from the backend. Contains a connection's
# process ID and a secret key, which can be used to cancel the connection's
# current activity (such as a long running query). Cancellation is not
# supported by pg8000 yet.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class BackendKeyData(object):
    def __init__(self, process_id, secret_key):
        self.process_id = process_id
        self.secret_key = secret_key

    # Byte1('K') - Identifier.
    # Int32(12)  - Message length, including self.
    # Int32      - Process ID.
    # Int32      - Secret key.
    @staticmethod
    def createFromData(data):
        pid, key = struct.unpack("!ii", data)
        return BackendKeyData(process_id=pid, secret_key=key)
+
+
##
# Message representing a query with no data (no result rows described).
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class NoData(object):
    # Byte1('n') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return NoData()
+
+
##
# Message representing a successful Parse.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ParseComplete(object):
    # Byte1('1') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return ParseComplete()
+
+
##
# Message representing a successful Bind.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class BindComplete(object):
    # Byte1('2') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return BindComplete()
+
+
##
# Message representing a successful Close.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class CloseComplete(object):
    # Byte1('3') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return CloseComplete()
+
+
##
# Message indicating that an Execute delivered its row limit but more rows
# remain in the portal.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class PortalSuspended(object):
    # Byte1('s') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return PortalSuspended()
+
+
##
# Message representing the backend is ready to process a new query.
# <p>
# Stability: This is an internal class. No stability guarantee is made.
class ReadyForQuery(object):
    def __init__(self, status):
        self._status = status

    ##
    # Transaction status indicator: I = Idle, T = Idle in Transaction,
    # E = idle in failed transaction.
    status = property(lambda self: self._status)

    def __repr__(self):
        names = {"I": "Idle", "T": "Idle in Transaction", "E": "Idle in Failed Transaction"}
        # Fall back to the raw indicator so repr() never raises on an
        # unexpected status byte (debug output should not crash).
        label = names.get(self._status, "unknown status %r" % (self._status,))
        return "<ReadyForQuery %s>" % label

    # Byte1('Z') - Identifier.
    # Int32(5)   - Message length, including self.
    # Byte1      - Status indicator.
    @staticmethod
    def createFromData(data):
        return ReadyForQuery(data)
+
+
##
# Represents a notice sent from the server. This is not the same as a
# notification: a notice is just additional information about a query, such
# as a notice that a primary key has automatically been created for a table.
# <p>
# A NoticeResponse instance has one attribute per field sent by the server,
# named via responseKeys; a field code missing from responseKeys keeps its
# one-letter code as the attribute name:
# <ul>
# <li>severity -- "ERROR", "FATAL", "PANIC", "WARNING", "NOTICE", "DEBUG",
# "INFO", or "LOG". Always present.</li>
# <li>code -- the SQLSTATE code for the error. See Appendix A of the
# PostgreSQL documentation for specific error codes. Always present.</li>
# <li>msg -- human-readable error message. Always present.</li>
# <li>detail -- Optional additional information.</li>
# <li>hint -- Optional suggestion about what to do about the issue.</li>
# <li>position -- Optional index into the query string.</li>
# <li>where -- Optional context.</li>
# <li>file -- Source-code file.</li>
# <li>line -- Source-code line.</li>
# <li>routine -- Source-code routine.</li>
# </ul>
# <p>
# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class NoticeResponse(object):
    responseKeys = {
        "S": "severity",  # always present
        "C": "code",      # always present
        "M": "msg",       # always present
        "D": "detail",
        "H": "hint",
        "P": "position",
        "p": "_position",
        "q": "_query",
        "W": "where",
        "F": "file",
        "L": "line",
        "R": "routine",
    }

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        return "<NoticeResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Splits a sequence of NUL-terminated "<code><value>" fields into a dict
    # keyed by attribute name.  Empty segments (the final terminator) are
    # skipped; a repeated field code keeps the last value.
    @staticmethod
    def dataIntoDict(data):
        names = NoticeResponse.responseKeys
        return dict(
            (names.get(field[0], field[0]), field[1:])
            for field in data.split("\x00")
            if field
        )

    # Byte1('N') - Identifier
    # Int32      - Message length
    # Any number of these, followed by a zero byte:
    #   Byte1  - code identifying the field type (see responseKeys)
    #   String - field value
    @staticmethod
    def createFromData(data):
        return NoticeResponse(**NoticeResponse.dataIntoDict(data))
+
+
##
# A message sent in case of a server-side error. Contains the same properties
# that {@link NoticeResponse NoticeResponse} contains.
# <p>
# Stability: Added in pg8000 v1.03. Required properties severity, code, and
# msg are guaranteed for v1.xx. Other properties should be checked with
# hasattr before accessing.
class ErrorResponse(object):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __repr__(self):
        return "<ErrorResponse %s %s %r>" % (self.severity, self.code, self.msg)

    # Builds the ProgrammingError that callers raise for this error.
    def createException(self):
        return ProgrammingError(self.severity, self.code, self.msg)

    # Same field layout as NoticeResponse, with identifier Byte1('E').
    @staticmethod
    def createFromData(data):
        fields = NoticeResponse.dataIntoDict(data)
        return ErrorResponse(**fields)
+
+
##
# A message sent if this connection receives a NOTIFY that it was LISTENing for.
# <p>
# Stability: Added in pg8000 v1.03. When limited to accessing properties from
# a notification event dispatch, stability is guaranteed for v1.xx.
class NotificationResponse(object):
    def __init__(self, backend_pid, condition, additional_info):
        self._backend_pid = backend_pid
        self._condition = condition
        self._additional_info = additional_info

    ##
    # An integer representing the process ID of the backend that triggered
    # the NOTIFY.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def backend_pid(self):
        return self._backend_pid

    ##
    # The name of the notification fired.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def condition(self):
        return self._condition

    ##
    # Currently unspecified by the PostgreSQL documentation as of v8.3.1.
    # <p>
    # Stability: Added in pg8000 v1.03, stability guaranteed for v1.xx.
    @property
    def additional_info(self):
        return self._additional_info

    def __repr__(self):
        return "<NotificationResponse %s %s %r>" % (self.backend_pid, self.condition, self.additional_info)

    # Payload: Int32 backend PID, then two NUL-terminated strings
    # (condition name and additional info).
    @staticmethod
    def createFromData(data):
        (pid,) = struct.unpack("!i", data[:4])
        rest = data[4:]
        end = rest.find("\x00")
        condition, rest = rest[:end], rest[end + 1:]
        info = rest[:rest.find("\x00")]
        return NotificationResponse(pid, condition, info)
+
+
##
# Backend reply to Describe(statement) listing the type OIDs of the
# statement's parameters.
class ParameterDescription(object):
    def __init__(self, type_oids):
        self.type_oids = type_oids

    # Byte1('t') - Identifier.
    # Int32      - Message length, including self.
    # Int16      - Number of parameters.
    # Int32[n]   - Type OID of each parameter.
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        fmt = "!" + count * "i"
        return ParameterDescription(struct.unpack(fmt, data[2:]))
+
+
##
# Describes the columns of a result set; each field is a dict with the keys
# name, table_oid, column_attrnum, type_oid, type_size, type_modifier and
# format.
class RowDescription(object):
    def __init__(self, fields):
        self.fields = fields

    # Byte1('T') - Identifier.
    # Int32      - Message length, including self.
    # Int16      - Number of fields.
    # Per field: NUL-terminated name, then Int32 table OID, Int16 column
    # attribute number, Int32 type OID, Int16 type size, Int32 type
    # modifier, Int16 format code (18 bytes in total).
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        rest = data[2:]
        keys = ("table_oid", "column_attrnum", "type_oid", "type_size", "type_modifier", "format")
        fields = []
        for _ in range(count):
            end = rest.find("\x00")
            field = {"name": rest[:end]}
            rest = rest[end + 1:]
            field.update(zip(keys, struct.unpack("!ihihih", rest[:18])))
            rest = rest[18:]
            fields.append(field)
        return RowDescription(fields)
+
##
# Sent when a command finishes.  For row-counting tags the command word,
# row count and (for INSERT) OID are split out; other tags are kept whole
# in `command`.
class CommandComplete(object):
    def __init__(self, command, rows=None, oid=None):
        self.command = command
        self.rows = rows
        self.oid = oid

    # Byte1('C') - Identifier.
    # Int32      - Message length, including self.
    # String     - Command tag, e.g. "INSERT oid rows", "UPDATE rows".
    @staticmethod
    def createFromData(data):
        tag = data[:-1]  # drop the terminating NUL
        values = tag.split(" ")
        command = values[0]
        if command not in ("INSERT", "DELETE", "UPDATE", "MOVE", "FETCH", "COPY"):
            return CommandComplete(tag)
        args = {"command": command}
        # Only parse numbers that are actually present: a tag such as a
        # bare "COPY" (possible from older servers) carries no row count,
        # and int("COPY") would raise ValueError.
        if len(values) > 1:
            args["rows"] = int(values[-1])
        if command == "INSERT" and len(values) > 2:
            args["oid"] = int(values[1])
        return CommandComplete(**args)
+
+
##
# One result row; `fields` holds each column's raw value, or None for SQL
# NULL.
class DataRow(object):
    def __init__(self, fields):
        self.fields = fields

    # Byte1('D') - Identifier.
    # Int32      - Message length, including self.
    # Int16      - Number of column values.
    # Per column: Int32 length (-1 = NULL), then that many bytes of value.
    @staticmethod
    def createFromData(data):
        (count,) = struct.unpack("!h", data[:2])
        # Walk the payload with an offset instead of re-slicing the remaining
        # buffer for every column; re-slicing copied the rest of the row each
        # time, which is quadratic in the row width on this hot path.
        pos = 2
        fields = []
        for _ in range(count):
            (val_len,) = struct.unpack_from("!i", data, pos)
            pos += 4
            if val_len == -1:
                fields.append(None)
            else:
                fields.append(data[pos:pos + val_len])
                pos += val_len
        return DataRow(fields)
+
+
##
# A chunk of COPY data, both received from (COPY OUT) and sent to
# (COPY IN) the backend.
class CopyData(object):
    # Byte1('d') - Identifier.
    # Int32      - Message length, including self.
    # Byte[n]    - The data.
    def __init__(self, data):
        self.data = data

    @staticmethod
    def createFromData(data):
        return CopyData(data)

    def serialize(self):
        length = struct.pack('!i', 4 + len(self.data))
        return 'd' + length + self.data
+
+
##
# Marks the end of a COPY data stream, in either direction.
class CopyDone(object):
    # Byte1('c') followed by Int32(4); the message has no payload.
    @staticmethod
    def createFromData(data):
        return CopyDone()

    def serialize(self):
        return 'c' + '\x00\x00\x00\x04'
+
##
# Backend announcement that a COPY OUT stream (CopyData messages ending in
# CopyDone) follows.
class CopyOutResponse(object):
    # Byte1('H') - Identifier.
    # Int32      - Message length, including self.
    # Int8       - Overall COPY format: 0 = textual, 1 = binary.
    # Int16      - Number of columns.
    # Int16[n]   - Format code of each column (0 = text, 1 = binary).
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    @staticmethod
    def createFromData(data):
        head, per_column = data[:3], data[3:]
        is_binary, num_cols = struct.unpack('!bh', head)
        formats = struct.unpack('!' + num_cols * 'h', per_column)
        return CopyOutResponse(is_binary, formats)
+
+
##
# Backend request for a COPY IN stream: the frontend should now send
# CopyData messages followed by CopyDone.
class CopyInResponse(object):
    # Byte1('G') - Identifier; the rest of the layout matches
    # CopyOutResponse (Int32 length, Int8 overall format, Int16 column
    # count, Int16[n] column format codes).
    def __init__(self, is_binary, column_formats):
        self.is_binary = is_binary
        self.column_formats = column_formats

    @staticmethod
    def createFromData(data):
        head, per_column = data[:3], data[3:]
        is_binary, num_cols = struct.unpack('!bh', head)
        formats = struct.unpack('!' + num_cols * 'h', per_column)
        return CopyInResponse(is_binary, formats)
+
##
# Adapts an SSL object exposing read()/write() (as created in
# Connection.__init__) to the socket methods Connection relies on.
class SSLWrapper(object):
    def __init__(self, sslobj):
        self.sslobj = sslobj

    def send(self, data):
        self.sslobj.write(data)

    # Connection._flush() sends through sendall(), which the raw SSL object
    # lacks; without this method every write after the SSL handshake failed
    # with AttributeError.  Keep writing until the whole buffer is out, in
    # case write() accepts only part of it.
    def sendall(self, data):
        while data:
            written = self.sslobj.write(data)
            if not written:
                raise socket.error("SSL write sent no data")
            data = data[written:]

    def recv(self, num):
        return self.sslobj.read(num)
+
+
##
# Reads backend messages from a Connection and dispatches them to the
# handlers registered with add_message() until a handler ends the loop.
class MessageReader(object):
    def __init__(self, connection):
        self._conn = connection
        self._msgs = []

        # If true, raise exception from an ErrorResponse after messages are
        # processed. This can be used to leave the connection in a usable
        # state after an error response, rather than having unconsumed
        # messages that won't be understood in another context.
        self.delay_raising_exception = False

        # If true, messages that no handler and no built-in case below
        # understands are dropped instead of raising InternalError.
        self.ignore_unhandled_messages = False

    # Registers handler(msg, *args, **kwargs) for messages that are
    # instances of msg_class.  A handler ends handle_messages() by returning
    # a true value (which becomes the result) or by calling return_value().
    def add_message(self, msg_class, handler, *args, **kwargs):
        self._msgs.append((msg_class, handler, args, kwargs))

    def clear_messages(self):
        self._msgs = []

    # Makes handle_messages() return `value` once the current handler
    # finishes; used for results that are not true values.
    def return_value(self, value):
        self._retval = value

    def handle_messages(self):
        exc = None
        while True:
            msg = self._conn._read_message()
            msg_handled = False
            # Matching handlers run in registration order until one ends
            # the loop.
            for (msg_class, handler, args, kwargs) in self._msgs:
                if isinstance(msg, msg_class):
                    msg_handled = True
                    retval = handler(msg, *args, **kwargs)
                    if retval:
                        # The handler returned a true value, meaning that the
                        # message loop should be aborted.
                        if exc is not None:
                            raise exc
                        return retval
                    elif hasattr(self, "_retval"):
                        # The handler told us to return -- used for non-true
                        # return values
                        if exc is not None:
                            raise exc
                        return self._retval
            if msg_handled:
                continue
            elif isinstance(msg, ErrorResponse):
                exc = msg.createException()
                if not self.delay_raising_exception:
                    raise exc
            elif isinstance(msg, NoticeResponse):
                self._conn.handleNoticeResponse(msg)
            elif isinstance(msg, ParameterStatus):
                self._conn.handleParameterStatus(msg)
            elif isinstance(msg, NotificationResponse):
                self._conn.handleNotificationResponse(msg)
            elif not self.ignore_unhandled_messages:
                raise InternalError("Unexpected response msg %r" % (msg,))
+
##
# Decorator for Connection methods that talk to the backend.  Holds the
# connection's socket lock for the whole call and, if the call fails for any
# reason, calls self._sync() before re-raising so that later calls do not
# read leftover messages from the failed exchange.
def sync_on_error(fn):
    import functools

    @functools.wraps(fn)
    def _fn(self, *args, **kwargs):
        # Acquire outside the try block: the finally clause must only
        # release a lock this call actually holds.
        self._sock_lock.acquire()
        try:
            return fn(self, *args, **kwargs)
        except:
            # Deliberately bare: any interruption mid-exchange (including
            # KeyboardInterrupt) can leave unread messages on the socket.
            self._sync()
            raise
        finally:
            self._sock_lock.release()
    return _fn
+
##
# A single backend connection speaking the PostgreSQL frontend/backend
# protocol with the message classes defined in this module.  All socket
# traffic goes through self._sock_lock; public protocol methods acquire it
# themselves, either directly or via @sync_on_error.
class Connection(object):
    def __init__(self, unix_sock=None, host=None, port=5432, socket_timeout=60, ssl=False, records=False):
        self._client_encoding = "ascii"
        self._integer_datetimes = False
        self._record_field_names = {}
        self._sock_buf = ""
        self._sock_buf_pos = 0
        self._send_sock_buf = []
        self._block_size = 8192
        self.user_wants_records = records
        # _send() and _flush() assert that this lock is held, so it must
        # exist before the SSL negotiation below uses them.  It used to be
        # created afterwards, which made ssl=True fail with AttributeError.
        self._sock_lock = threading.Lock()

        if unix_sock is None and host is not None:
            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        elif unix_sock is not None:
            if not hasattr(socket, "AF_UNIX"):
                raise InterfaceError("attempt to connect to unix socket on unsupported platform")
            self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            raise ProgrammingError("one of host or unix_sock must be provided")

        raw_sock = self._sock
        try:
            # At this point either unix_sock is given, or host is.
            if unix_sock is None:
                self._sock.connect((host, port))
            else:
                self._sock.connect(unix_sock)
            if ssl:
                self._sock_lock.acquire()
                try:
                    self._send(SSLRequest())
                    self._flush()
                    resp = self._sock.recv(1)
                finally:
                    self._sock_lock.release()
                if resp == 'S':
                    self._sock = SSLWrapper(socket.ssl(self._sock))
                else:
                    raise InterfaceError("server refuses SSL")
            else:
                # settimeout causes ssl failure, on windows. Python bug 1462352.
                self._sock.settimeout(socket_timeout)
        except:
            # Don't leak the socket if connecting or SSL negotiation fails.
            raw_sock.close()
            raise

        self._state = "noauth"
        self._backend_key_data = None

        self.NoticeReceived = MulticastDelegate()
        self.ParameterStatusReceived = MulticastDelegate()
        self.NotificationReceived = MulticastDelegate()

        self.ParameterStatusReceived += self._onParameterStatusReceived

    def verifyState(self, state):
        if self._state != state:
            raise InternalError("connection state must be %s, is %s" % (state, self._state))

    # Queues a serialized message; nothing is written until _flush().
    def _send(self, msg):
        assert self._sock_lock.locked()
        data = msg.serialize()
        self._send_sock_buf.append(data)

    def _flush(self):
        assert self._sock_lock.locked()
        self._sock.sendall("".join(self._send_sock_buf))
        del self._send_sock_buf[:]

    # Returns exactly byte_count bytes, refilling the receive buffer from
    # the socket as needed.
    def _read_bytes(self, byte_count):
        retval = []
        bytes_read = 0
        while bytes_read < byte_count:
            if self._sock_buf_pos == len(self._sock_buf):
                self._sock_buf = self._sock.recv(1024)
                self._sock_buf_pos = 0
                if not self._sock_buf:
                    # An empty recv() result means the peer closed the
                    # connection; without this check the loop would spin
                    # forever.
                    raise InterfaceError("connection closed by server while reading")
            rpos = min(len(self._sock_buf), self._sock_buf_pos + (byte_count - bytes_read))
            addt_data = self._sock_buf[self._sock_buf_pos:rpos]
            bytes_read += (rpos - self._sock_buf_pos)
            assert bytes_read <= byte_count
            self._sock_buf_pos = rpos
            retval.append(addt_data)
        return "".join(retval)

    # Reads one backend message (Byte1 type, Int32 length, payload) and
    # parses it with the class registered in message_types.
    def _read_message(self):
        assert self._sock_lock.locked()
        header = self._read_bytes(5)
        message_code = header[0]
        data_len = struct.unpack("!i", header[1:])[0] - 4
        payload = self._read_bytes(data_len)
        assert len(payload) == data_len
        try:
            msg_class = message_types[message_code]
        except KeyError:
            raise InternalError("unrecognized backend message type %r" % (message_code,))
        return msg_class.createFromData(payload)

    def authenticate(self, user, **kwargs):
        self.verifyState("noauth")
        self._sock_lock.acquire()
        try:
            self._send(StartupMessage(user, database=kwargs.get("database", None), options=kwargs.get("options", None)))
            self._flush()
            msg = self._read_message()
            if isinstance(msg, ErrorResponse):
                raise msg.createException()
            if not isinstance(msg, AuthenticationRequest):
                raise InternalError("StartupMessage was responded to with non-AuthenticationRequest msg %r" % (msg,))
            if not msg.ok(self, user, **kwargs):
                raise InterfaceError("authentication method %s failed" % msg.__class__.__name__)

            self._state = "auth"

            reader = MessageReader(self)
            reader.add_message(ReadyForQuery, self._ready_for_query)
            reader.add_message(BackendKeyData, self._receive_backend_key_data)
            reader.handle_messages()
        finally:
            self._sock_lock.release()

        self._cache_record_attnames()

    def _ready_for_query(self, msg):
        self._state = "ready"
        return True

    def _receive_backend_key_data(self, msg):
        self._backend_key_data = msg

    # When records were requested, loads the attribute names of every
    # composite (record_recv) type into self._record_field_names, keyed by
    # type OID.
    def _cache_record_attnames(self):
        if not self.user_wants_records:
            return

        parse_retval = self.parse("",
            """SELECT
                pg_type.oid, attname
            FROM
                pg_type
                INNER JOIN pg_attribute ON (attrelid = pg_type.typrelid)
            WHERE typreceive = 'record_recv'::regproc
            ORDER BY pg_type.oid, attnum""",
            [])
        row_desc, cmd = self.bind("tmp", "", (), parse_retval, None)
        eod, rows = self.fetch_rows("tmp", 0, row_desc)

        self._record_field_names = {}
        typoid, attnames = None, []
        # Rows arrive ordered by type OID, so consecutive rows share a type.
        for new_typoid, attname in rows:
            if new_typoid != typoid and typoid is not None:
                self._record_field_names[typoid] = attnames
                attnames = []
            typoid = new_typoid
            attnames.append(attname)
        # Skip the final store when the query returned no rows; it used to
        # record a bogus None key.
        if typoid is not None:
            self._record_field_names[typoid] = attnames

    # Parses qs as prepared statement `statement` and describes it.  Returns
    # (row_desc, param_fc), where row_desc is None if the statement returns
    # no rows.
    @sync_on_error
    def parse(self, statement, qs, param_types):
        self.verifyState("ready")

        type_info = [types.pg_type_info(x) for x in param_types]
        param_types, param_fc = [x[0] for x in type_info], [x[1] for x in type_info]  # zip(*type_info) -- fails on empty arr
        self._send(Parse(statement, qs, param_types))
        self._send(DescribePreparedStatement(statement))
        self._send(Flush())
        self._flush()

        reader = MessageReader(self)

        # ParseComplete is good.
        reader.add_message(ParseComplete, lambda msg: 0)

        # Well, we don't really care -- we're going to send whatever we
        # want and let the database deal with it. But thanks anyways!
        reader.add_message(ParameterDescription, lambda msg: 0)

        # We're not waiting for a row description. Return something
        # distinctive to let bind know that there is no output.
        reader.add_message(NoData, lambda msg: (None, param_fc))

        # Common row description response
        reader.add_message(RowDescription, lambda msg: (msg, param_fc))

        return reader.handle_messages()

    # Binds params to `statement` as `portal`.  Returns (row_desc, None) for
    # row-returning statements; otherwise the portal is executed immediately
    # and (None, CommandComplete) is returned.
    @sync_on_error
    def bind(self, portal, statement, params, parse_data, copy_stream):
        self.verifyState("ready")

        row_desc, param_fc = parse_data
        if row_desc is None:
            # no data coming out
            output_fc = ()
        else:
            # We've got row_desc that allows us to identify what we're going to
            # get back from this statement.
            output_fc = [types.py_type_info(f, self._record_field_names) for f in row_desc.fields]
        self._send(Bind(portal, statement, param_fc, params, output_fc, client_encoding=self._client_encoding, integer_datetimes=self._integer_datetimes))
        # We need to describe the portal after bind, since the return
        # format codes will be different (hopefully, always what we
        # requested).
        self._send(DescribePortal(portal))
        self._send(Flush())
        self._flush()

        # Read responses from server...
        reader = MessageReader(self)

        # BindComplete is good -- just ignore
        reader.add_message(BindComplete, lambda msg: 0)

        # NoData in this case means we're not executing a query. As a
        # result, we won't be fetching rows, so we'll never execute the
        # portal we just created... unless we execute it right away, which
        # we'll do.
        reader.add_message(NoData, self._bind_nodata, portal, reader, copy_stream)

        # Return the new row desc, since it will have the format types we
        # asked the server for
        reader.add_message(RowDescription, lambda msg: (msg, None))

        return reader.handle_messages()

    # Streams fileobj to the backend in _block_size chunks, then ends the
    # COPY with CopyDone and Sync.
    def _copy_in_response(self, copyin, fileobj, old_reader):
        if fileobj is None:
            # NOTE(review): the backend is already in COPY IN mode here and
            # expects CopyData/CopyDone/CopyFail; confirm that the Sync sent
            # by sync_on_error's _sync() really brings it back.
            raise CopyQueryWithoutStreamError()
        while True:
            data = fileobj.read(self._block_size)
            if not data:
                break
            self._send(CopyData(data))
            self._flush()
        self._send(CopyDone())
        self._send(Sync())
        self._flush()

    # Writes every CopyData payload to fileobj until CopyDone arrives.
    def _copy_out_response(self, copyout, fileobj, old_reader):
        if fileobj is None:
            raise CopyQueryWithoutStreamError()
        reader = MessageReader(self)
        reader.add_message(CopyData, self._copy_data, fileobj)
        reader.add_message(CopyDone, lambda msg: 1)
        reader.handle_messages()

    def _copy_data(self, copydata, fileobj):
        fileobj.write(copydata.data)

    def _bind_nodata(self, msg, portal, old_reader, copy_stream):
        # Bind message returned NoData, causing us to execute the command.
        self._send(Execute(portal, 0))
        self._send(Sync())
        self._flush()

        output = {}
        reader = MessageReader(self)
        reader.add_message(CopyOutResponse, self._copy_out_response, copy_stream, reader)
        reader.add_message(CopyInResponse, self._copy_in_response, copy_stream, reader)
        reader.add_message(CommandComplete, lambda msg, out: out.setdefault('msg', msg) and False, output)
        reader.add_message(ReadyForQuery, lambda msg: 1)
        reader.delay_raising_exception = True
        reader.handle_messages()

        old_reader.return_value((None, output['msg']))

    # Executes `portal` for up to row_count rows (0 = all).  Returns
    # (end_of_data, rows).
    @sync_on_error
    def fetch_rows(self, portal, row_count, row_desc):
        self.verifyState("ready")

        self._send(Execute(portal, row_count))
        self._send(Flush())
        self._flush()
        rows = []

        reader = MessageReader(self)
        reader.add_message(DataRow, self._fetch_datarow, rows, row_desc)
        reader.add_message(PortalSuspended, lambda msg: 1)
        reader.add_message(CommandComplete, self._fetch_commandcomplete, portal)
        retval = reader.handle_messages()

        # retval = 2 when command complete, indicating that we've hit the
        # end of the available data for this command
        return (retval == 2), rows

    def _fetch_datarow(self, msg, rows, row_desc):
        rows.append([
            types.py_value(
                value,
                row_desc.fields[i],
                client_encoding=self._client_encoding,
                integer_datetimes=self._integer_datetimes,
                record_field_names=self._record_field_names
            )
            for i, value in enumerate(msg.fields)
        ])

    def _fetch_commandcomplete(self, msg, portal):
        self._send(ClosePortal(portal))
        self._send(Sync())
        self._flush()

        reader = MessageReader(self)
        reader.add_message(ReadyForQuery, self._fetch_commandcomplete_rfq)
        reader.add_message(CloseComplete, lambda msg: False)
        reader.handle_messages()

        return 2  # signal end-of-data

    def _fetch_commandcomplete_rfq(self, msg):
        self._state = "ready"
        return True

    # Send a Sync message, then read and discard all messages until we
    # receive a ReadyForQuery message.
    def _sync(self):
        # it is assumed _sync is called from sync_on_error, which holds
        # a _sock_lock throughout the call
        self._send(Sync())
        self._flush()
        reader = MessageReader(self)
        reader.ignore_unhandled_messages = True
        reader.add_message(ReadyForQuery, lambda msg: True)
        reader.handle_messages()

    def close_statement(self, statement):
        if self._state == "closed":
            return
        self.verifyState("ready")
        self._sock_lock.acquire()
        try:
            self._send(ClosePreparedStatement(statement))
            self._send(Sync())
            self._flush()

            reader = MessageReader(self)
            reader.add_message(CloseComplete, lambda msg: 0)
            reader.add_message(ReadyForQuery, lambda msg: 1)
            reader.handle_messages()
        finally:
            self._sock_lock.release()

    def close_portal(self, portal):
        if self._state == "closed":
            return
        self.verifyState("ready")
        self._sock_lock.acquire()
        try:
            self._send(ClosePortal(portal))
            self._send(Sync())
            self._flush()

            reader = MessageReader(self)
            reader.add_message(CloseComplete, lambda msg: 0)
            reader.add_message(ReadyForQuery, lambda msg: 1)
            reader.handle_messages()
        finally:
            self._sock_lock.release()

    def close(self):
        self._sock_lock.acquire()
        try:
            try:
                self._send(Terminate())
                self._flush()
            finally:
                # Release the socket and mark the connection closed even if
                # Terminate could not be delivered (e.g. the server is gone).
                self._sock.close()
                self._state = "closed"
        finally:
            self._sock_lock.release()

    def _onParameterStatusReceived(self, msg):
        if msg.key == "client_encoding":
            self._client_encoding = msg.value
        elif msg.key == "integer_datetimes":
            self._integer_datetimes = (msg.value == "on")

    def handleNoticeResponse(self, msg):
        self.NoticeReceived(msg)

    def handleParameterStatus(self, msg):
        self.ParameterStatusReceived(msg)

    def handleNotificationResponse(self, msg):
        self.NotificationReceived(msg)

    def fileno(self):
        # This should be safe to do without a lock
        return self._sock.fileno()

    # If the socket is readable, runs a Sync round-trip, which dispatches
    # pending notices/notifications, and returns True; otherwise False.
    def isready(self):
        self._sock_lock.acquire()
        try:
            rlst, _wlst, _xlst = select.select([self], [], [], 0)
            if not rlst:
                return False

            self._sync()
            return True
        finally:
            self._sock_lock.release()
+
# Maps the one-byte identifier of each backend message to the class whose
# createFromData() parses its payload (see Connection._read_message).
message_types = {
    # Start-up and authentication.
    "R": AuthenticationRequest,
    "K": BackendKeyData,
    "S": ParameterStatus,
    # Status, errors and asynchronous messages.
    "Z": ReadyForQuery,
    "E": ErrorResponse,
    "N": NoticeResponse,
    "A": NotificationResponse,
    # Extended-query responses.
    "1": ParseComplete,
    "2": BindComplete,
    "3": CloseComplete,
    "t": ParameterDescription,
    "T": RowDescription,
    "n": NoData,
    "D": DataRow,
    "s": PortalSuspended,
    "C": CommandComplete,
    # COPY sub-protocol.
    "G": CopyInResponse,
    "H": CopyOutResponse,
    "d": CopyData,
    "c": CopyDone,
}
+
+