You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@thrift.apache.org by jf...@apache.org on 2014/02/21 18:12:11 UTC

git commit: THRIFT-1719: SASL client support for Python (Client: py; Patch: Tyler Hobbs)

Repository: thrift
Updated Branches:
  refs/heads/master 01386c95a -> 8b3ca02a2


THRIFT-1719:SASL client support for Python
Client: py
Patch: Tyler Hobbs

Add SASL client transports that will work with the Java lib's TSaslTransport


Project: http://git-wip-us.apache.org/repos/asf/thrift/repo
Commit: http://git-wip-us.apache.org/repos/asf/thrift/commit/8b3ca02a
Tree: http://git-wip-us.apache.org/repos/asf/thrift/tree/8b3ca02a
Diff: http://git-wip-us.apache.org/repos/asf/thrift/diff/8b3ca02a

Branch: refs/heads/master
Commit: 8b3ca02a2ad3a005685f66dc85a625a6731144b7
Parents: 01386c9
Author: jfarrell <jf...@apache.org>
Authored: Fri Feb 21 12:11:14 2014 -0500
Committer: jfarrell <jf...@apache.org>
Committed: Fri Feb 21 12:11:14 2014 -0500

----------------------------------------------------------------------
 lib/py/src/transport/TTransport.py | 111 ++++++++++++++++++++++++++++++++
 lib/py/src/transport/TTwisted.py   | 107 +++++++++++++++++++++++++++++-
 2 files changed, 216 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/thrift/blob/8b3ca02a/lib/py/src/transport/TTransport.py
----------------------------------------------------------------------
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 4481371..5ab2fde 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -328,3 +328,114 @@ class TFileObjectTransport(TTransportBase):
 
   def flush(self):
     self.fileobj.flush()
+
+
+class TSaslClientTransport(TTransportBase, CReadableTransport):
+  """SASL client transport compatible with the Java lib's TSaslTransport.
+  Negotiates in open(), then frames data as [4-byte length][wrapped bytes].
+  """
+
+  START = 1
+  OK = 2
+  BAD = 3
+  ERROR = 4
+  COMPLETE = 5
+
+  def __init__(self, transport, host, service, mechanism='GSSAPI',
+      **sasl_kwargs):
+    """
+    transport: an underlying transport to use, typically just a TSocket
+    host: the name of the server, from a SASL perspective
+    service: the name of the server's service, from a SASL perspective
+    mechanism: the name of the preferred mechanism to use
+
+    All other kwargs will be passed to the puresasl.client.SASLClient
+    constructor.
+    """
+
+    from puresasl.client import SASLClient
+
+    self.transport = transport
+    self.sasl = SASLClient(host, service, mechanism, **sasl_kwargs)
+
+    self.__wbuf = StringIO()
+    self.__rbuf = StringIO()
+
+  def open(self):
+    if not self.transport.isOpen():
+      self.transport.open()
+
+    self.send_sasl_msg(self.START, self.sasl.mechanism)
+    self.send_sasl_msg(self.OK, self.sasl.process())
+
+    while True:
+      status, challenge = self.recv_sasl_msg()
+      if status == self.OK:
+        self.send_sasl_msg(self.OK, self.sasl.process(challenge))
+      elif status == self.COMPLETE:
+        if not self.sasl.complete:
+          raise TTransportException(message="The server erroneously "
+              "indicated that SASL negotiation was complete")
+        else:
+          break
+      else:
+        raise TTransportException(message="Bad SASL negotiation status: %d (%s)"
+            % (status, challenge))
+
+  def send_sasl_msg(self, status, body):
+    body = "" if body is None else body  # a None body (no SASL payload) is sent empty
+    self.transport.write(pack(">BI", status, len(body)) + body)
+    self.transport.flush()
+
+  def recv_sasl_msg(self):
+    header = self.transport.readAll(5)
+    status, length = unpack(">BI", header)
+    if length > 0:
+      payload = self.transport.readAll(length)
+    else:
+      payload = ""
+    return status, payload
+
+  def write(self, data):
+    self.__wbuf.write(data)
+
+  def flush(self):
+    data = self.__wbuf.getvalue()
+    encoded = self.sasl.wrap(data)
+    self.transport.write(''.join((pack("!i", len(encoded)), encoded)))
+    self.transport.flush()
+    self.__wbuf = StringIO()
+
+  def read(self, sz):
+    ret = self.__rbuf.read(sz)
+    if len(ret) != 0:
+      return ret
+
+    self._read_frame()
+    return self.__rbuf.read(sz)
+
+  def _read_frame(self):
+    header = self.transport.readAll(4)
+    length, = unpack('!i', header)
+    encoded = self.transport.readAll(length)
+    self.__rbuf = StringIO(self.sasl.unwrap(encoded))
+
+  def close(self):
+    self.sasl.dispose()
+    self.transport.close()
+
+  # based on TFramedTransport
+  @property
+  def cstringio_buf(self):
+    return self.__rbuf
+
+  def cstringio_refill(self, prefix, reqlen):
+    # self.__rbuf will already be empty here because fastbinary doesn't
+    # ask for a refill until the previous buffer is empty.  Therefore,
+    # we can start reading new frames immediately.
+    while len(prefix) < reqlen:
+      self._read_frame()
+      prefix += self.__rbuf.getvalue()
+    self.__rbuf = StringIO(prefix)
+    return self.__rbuf
+

http://git-wip-us.apache.org/repos/asf/thrift/blob/8b3ca02a/lib/py/src/transport/TTwisted.py
----------------------------------------------------------------------
diff --git a/lib/py/src/transport/TTwisted.py b/lib/py/src/transport/TTwisted.py
index 3ce3eb2..2b77414 100644
--- a/lib/py/src/transport/TTwisted.py
+++ b/lib/py/src/transport/TTwisted.py
@@ -17,14 +17,15 @@
 # under the License.
 #
 
+import struct
 from cStringIO import StringIO
 
 from zope.interface import implements, Interface, Attribute
-from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
+from twisted.internet.protocol import ServerFactory, ClientFactory, \
     connectionDone
 from twisted.internet import defer
+from twisted.internet.threads import deferToThread
 from twisted.protocols import basic
-from twisted.python import log
 from twisted.web import server, resource, http
 
 from thrift.transport import TTransport
@@ -101,6 +102,108 @@ class ThriftClientProtocol(basic.Int32StringReceiver):
         method(iprot, mtype, rseqid)
 
 
+class ThriftSASLClientProtocol(ThriftClientProtocol):
+    """ThriftClientProtocol that SASL-negotiates on connect, then wraps frames."""
+    START = 1
+    OK = 2
+    BAD = 3
+    ERROR = 4
+    COMPLETE = 5
+
+    MAX_LENGTH = 2 ** 31 - 1
+
+    def __init__(self, client_class, iprot_factory, oprot_factory=None,
+            host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
+        """
+        host: the name of the server, from a SASL perspective
+        service: the name of the server's service, from a SASL perspective
+        mechanism: the name of the preferred mechanism to use
+
+        All other kwargs will be passed to the puresasl.client.SASLClient
+        constructor.
+        """
+
+        from puresasl.client import SASLClient
+        self.SASLClient = SASLClient
+
+        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)
+
+        self._sasl_negotiation_deferred = None
+        self._sasl_negotiation_status = None
+        self.client = None
+
+        if host is not None:
+            self.createSASLClient(host, service, mechanism, **sasl_kwargs)
+
+    def createSASLClient(self, host, service, mechanism, **kwargs):
+        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)
+
+    def dispatch(self, msg):
+        encoded = self.sasl.wrap(msg)
+        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
+        ThriftClientProtocol.dispatch(self, len_and_encoded)
+
+    @defer.inlineCallbacks
+    def connectionMade(self):
+        self._sendSASLMessage(self.START, self.sasl.mechanism)
+        initial_message = yield deferToThread(self.sasl.process)
+        self._sendSASLMessage(self.OK, initial_message)
+
+        while True:
+            status, challenge = yield self._receiveSASLMessage()
+            if status == self.OK:
+                response = yield deferToThread(self.sasl.process, challenge)
+                self._sendSASLMessage(self.OK, response)
+            elif status == self.COMPLETE:
+                if not self.sasl.complete:
+                    msg = "The server erroneously indicated that SASL " \
+                          "negotiation was complete"
+                    raise TTransport.TTransportException(message=msg)
+                else:
+                    break
+            else:
+                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
+                raise TTransport.TTransportException(message=msg)
+
+        self._sasl_negotiation_deferred = None
+        ThriftClientProtocol.connectionMade(self)
+
+    def _sendSASLMessage(self, status, body):
+        if body is None:
+            body = ""
+        header = struct.pack(">BI", status, len(body))
+        self.transport.write(header + body)
+
+    def _receiveSASLMessage(self):
+        self._sasl_negotiation_deferred = defer.Deferred()
+        self._sasl_negotiation_status = None
+        return self._sasl_negotiation_deferred
+
+    def connectionLost(self, reason=connectionDone):
+        if self.client:
+            ThriftClientProtocol.connectionLost(self, reason)
+
+    def dataReceived(self, data):
+        if self._sasl_negotiation_deferred and self._sasl_negotiation_status is None:
+            # start of a SASL message (status, length, challenge); it may span chunks
+            # save the status, let IntNStringReceiver piece the challenge data together
+            self._sasl_negotiation_status, = struct.unpack("B", data[0])
+            ThriftClientProtocol.dataReceived(self, data[1:])
+        else:
+            # normal frame, or a later chunk of the current SASL message; let IntNStringReceiver piece it together
+            ThriftClientProtocol.dataReceived(self, data)
+
+    def stringReceived(self, frame):
+        if self._sasl_negotiation_deferred:
+            # the frame is just a SASL challenge
+            response = (self._sasl_negotiation_status, frame)
+            self._sasl_negotiation_deferred.callback(response)
+        else:
+            # there's a second 4 byte length prefix inside the frame
+            decoded_frame = self.sasl.unwrap(frame[4:])
+            ThriftClientProtocol.stringReceived(self, decoded_frame)
+
+
 class ThriftServerProtocol(basic.Int32StringReceiver):
 
     MAX_LENGTH = 2 ** 31 - 1