You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by je...@apache.org on 2020/04/09 20:00:21 UTC

[thrift] branch master updated: THRIFT-5166: Add support for using WebSockets as a server transport. Client: d Patch: James Lacey

This is an automated email from the ASF dual-hosted git repository.

jensg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bbdb1a  THRIFT-5166: Add support for using WebSockets as a server transport. Client: d Patch: James Lacey
6bbdb1a is described below

commit 6bbdb1a46ce6ba0ac4e27e29b2c9c9eef107186c
Author: James Lacey <ja...@gmail.com>
AuthorDate: Mon Apr 6 09:17:59 2020 -0700

    THRIFT-5166: Add support for using WebSockets as a server transport.
    Client: d
    Patch: James Lacey
    
    This closes #2087
---
 lib/d/src/thrift/transport/websocket.d | 388 +++++++++++++++++++++++++++++++++
 1 file changed, 388 insertions(+)

diff --git a/lib/d/src/thrift/transport/websocket.d b/lib/d/src/thrift/transport/websocket.d
new file mode 100644
index 0000000..25d1f7d
--- /dev/null
+++ b/lib/d/src/thrift/transport/websocket.d
@@ -0,0 +1,388 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+module thrift.transport.websocket;
+
+import std.algorithm;
+import std.algorithm.searching;
+import std.base64;
+import std.bitmanip;
+import std.conv;
+import std.digest.sha;
+import std.stdio;
+import std.string;
+import std.uni;
+import thrift.base : VERSION;
+import thrift.transport.base;
+import thrift.transport.http;
+
+/**
+ * WebSocket server transport.
+ *
+ * Performs the HTTP upgrade handshake on the first read and afterwards
+ * exchanges data as WebSocket frames over the underlying transport. Outgoing
+ * data is sent as Binary frames if binary is true and as Text frames
+ * otherwise.
+ */
+final class TServerWebSocketTransport(bool binary) : THttpTransport {
+  /**
+   * Constructs a new instance.
+   *
+   * Params:
+   *   transport = The underlying transport used for the actual I/O.
+   */
+  this(TTransport transport) {
+    super(transport);
+    transport_ = transport;
+  }
+
+  /**
+   * Reads payload data of incoming frames into buf.
+   *
+   * As long as no valid handshake has been seen, the HTTP request is parsed
+   * first and answered with either a 101 or a 400 response.
+   *
+   * Returns: The number of bytes copied into buf; 0 if the handshake failed,
+   *   no complete frame could be read or the peer closed the connection.
+   */
+  override size_t read(ubyte[] buf) {
+    // If we do not have a good handshake, the client will attempt one.
+    if (!handshakeComplete) {
+      resetHandshake();
+      super.read(buf);
+      // If we did not get everything we expected, the handshake failed
+      // and we need to send a 400 response back.
+      if (!handshakeComplete) {
+        sendBadRequest();
+        return 0;
+      }
+      // Otherwise, send back the 101 response.
+      super.flush();
+    }
+
+    // If the buffer is empty, read a new frame off the wire.
+    if (readBuffer_.empty) {
+      if (!readFrame()) {
+        return 0;
+      }
+    }
+
+    auto size = min(readBuffer_.length, buf.length);
+    buf[0..size] = readBuffer_[0..size];
+    readBuffer_ = readBuffer_[size..$];
+    return size;
+  }
+
+  /// Buffers buf; nothing is sent until flush() is called.
+  override void write(in ubyte[] buf) {
+    writeBuffer_ ~= buf;
+  }
+
+  /// Sends everything buffered by write() as a single data frame.
+  override void flush() {
+    if (writeBuffer_.empty) {
+      return;
+    }
+
+    // Properly reset the write buffer even if some of the protocol operations
+    // go wrong.
+    scope (exit) {
+      writeBuffer_.length = 0;
+      writeBuffer_.assumeSafeAppend();
+    }
+
+    writeFrameHeader(binary ? Opcode.Binary : Opcode.Text, writeBuffer_.length);
+    transport_.write(writeBuffer_);
+    transport_.flush();
+  }
+
+protected:
+  /// Builds the 101 response that completes the handshake.
+  override string getHeader(size_t dataLength) {
+    return "HTTP/1.1 101 Switching Protocols\r\n" ~
+      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
+      "Upgrade: websocket\r\n" ~
+      "Connection: Upgrade\r\n" ~
+      "Sec-WebSocket-Accept: " ~ acceptKey_ ~ "\r\n" ~
+      "\r\n";
+  }
+
+  /**
+   * Inspects a single request header line and records whether it satisfies
+   * one of the handshake requirements.
+   */
+  override void parseHeader(const(ubyte)[] header) {
+    auto split = findSplit(header, [':']);
+    if (split[1].empty) {
+      // No colon found.
+      return;
+    }
+
+    // Compare the complete field name: a prefix match would let unrelated
+    // headers such as "Upgrade-Insecure-Requests" overwrite the result.
+    auto name = strip(cast(const(char)[])split[0]);
+    auto value = strip(cast(const(char)[])split[2]);
+
+    if (sicmp(name, "upgrade") == 0) {
+      upgrade_ = sicmp(value, "websocket") == 0;
+    } else if (sicmp(name, "connection") == 0) {
+      // Connection carries a comma separated token list, for example
+      // "keep-alive, Upgrade"; one of the tokens has to be "upgrade".
+      connection_ = false;
+      foreach (token; splitter(value, ',')) {
+        if (sicmp(strip(token), "upgrade") == 0) {
+          connection_ = true;
+          break;
+        }
+      }
+    } else if (sicmp(name, "sec-websocket-key") == 0) {
+      auto hash = sha1Of(value ~ WEBSOCKET_GUID);
+      acceptKey_ = Base64.encode(hash);
+      secWebSocketKey_ = true;
+    } else if (sicmp(name, "sec-websocket-version") == 0) {
+      secWebSocketVersion_ = sicmp(value, "13") == 0;
+    }
+  }
+
+  /**
+   * Validates the request line.
+   *
+   * Returns: true for a well-formed GET request.
+   * Throws: TTransportException if the line is malformed or uses a method
+   *   other than GET.
+   */
+  override bool parseStatusLine(const(ubyte)[] status) {
+    // Method SP Request-URI SP HTTP-Version CRLF.
+    auto split = findSplit(status, [' ']);
+    if (split[1].empty) {
+      throw new TTransportException("Bad status: " ~ to!string(status),
+        TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    // Skip additional spaces after the method. Done by hand because slicing
+    // with a failed countUntil() result (-1) is a range violation.
+    auto uriVersion = split[2];
+    while (!uriVersion.empty && uriVersion[0] == ' ') {
+      uriVersion = uriVersion[1 .. $];
+    }
+    if (!canFind(uriVersion, ' ')) {
+      throw new TTransportException("Bad status: " ~ to!string(status),
+        TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    if (split[0] == "GET") {
+      // GET method ok, looking for content.
+      return true;
+    }
+
+    throw new TTransportException("Bad status (unsupported method): " ~
+      to!string(status), TTransportException.Type.CORRUPTED_DATA);
+  }
+
+private:
+  /// True once all four handshake requirements have been seen.
+  @property bool handshakeComplete() {
+    return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
+  }
+
+  /// Sends a Close frame carrying reason and closes the underlying transport.
+  void failConnection(CloseCode reason) {
+    ubyte[2] payload = nativeToBigEndian!ushort(reason);
+    // The header has to describe the status code, not the pending write buffer.
+    writeFrameHeader(Opcode.Close, payload.length);
+    transport_.write(payload[]);
+    transport_.flush();
+    transport_.close();
+  }
+
+  /// Answers a Ping by echoing its payload (held in readBuffer_) as a Pong.
+  void pong() {
+    writeFrameHeader(Opcode.Pong, readBuffer_.length);
+    transport_.write(readBuffer_);
+    transport_.flush();
+  }
+
+  /**
+   * Fills buf completely from the underlying transport, issuing as many
+   * reads as necessary.
+   *
+   * Returns: false if the transport reported 0 bytes before buf was full.
+   */
+  bool readFully(ubyte[] buf) {
+    size_t filled = 0;
+    while (filled < buf.length) {
+      auto got = transport_.read(buf[filled .. $]);
+      if (got == 0) {
+        return false;
+      }
+      filled += got;
+    }
+    return true;
+  }
+
+  /**
+   * Reads one frame off the wire and stores its unmasked payload in
+   * readBuffer_. Ping frames are answered and skipped.
+   *
+   * Returns: true if a data frame was read; false if the frame could not be
+   *   read completely, was too big, or was a Close frame.
+   * Throws: TTransportException on protocol violations.
+   */
+  bool readFrame() {
+    // Never hand out leftovers of an earlier frame (for example the payload
+    // of a Ping that has just been answered) as data.
+    readBuffer_ = null;
+
+    ubyte[8] headerBuffer;
+
+    if (!readFully(headerBuffer[0..2])) {
+      return false;
+    }
+    // Since Thrift has its own message end marker and we read frame by frame,
+    // it doesn't really matter if the frame is marked as FIN.
+    // Capture it for debugging only.
+    debug auto fin = (headerBuffer[0] & 0x80) != 0;
+
+    // RSV1, RSV2, RSV3
+    if ((headerBuffer[0] & 0x70) != 0) {
+      failConnection(CloseCode.ProtocolError);
+      throw new TTransportException("Reserved bits must be zeroes", TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    Opcode opcode;
+    try {
+      opcode = to!Opcode(headerBuffer[0] & 0x0F);
+    } catch (ConvException) {
+      failConnection(CloseCode.ProtocolError);
+      throw new TTransportException("Unknown opcode", TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    // Mask
+    if ((headerBuffer[1] & 0x80) == 0) {
+      failConnection(CloseCode.ProtocolError);
+      throw new TTransportException("Messages from the client must be masked", TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    // Read the length
+    ulong payloadLength = headerBuffer[1] & 0x7F;
+
+    // Control frames (opcodes with bit 3 set) are limited to 125 bytes of
+    // payload, so they can never use an extended length.
+    if ((opcode & 0x08) != 0 && payloadLength > 125) {
+      failConnection(CloseCode.ProtocolError);
+      throw new TTransportException("Control frame payload too long", TTransportException.Type.CORRUPTED_DATA);
+    }
+
+    if (payloadLength == 126) {
+      if (!readFully(headerBuffer[0..2])) {
+        return false;
+      }
+      payloadLength = bigEndianToNative!ushort(headerBuffer[0..2]);
+    } else if (payloadLength == 127) {
+      if (!readFully(headerBuffer[])) {
+        return false;
+      }
+      payloadLength = bigEndianToNative!ulong(headerBuffer);
+      if ((payloadLength & 0x8000000000000000) != 0) {
+        failConnection(CloseCode.ProtocolError);
+        throw new TTransportException("The most significant bit of the payload length must be zero",
+          TTransportException.Type.CORRUPTED_DATA);
+      }
+    }
+
+    // size_t is smaller than a ulong on a 32-bit system
+    static if (size_t.max < ulong.max) {
+      if (payloadLength > size_t.max) {
+        failConnection(CloseCode.MessageTooBig);
+        return false;
+      }
+    }
+
+    auto length = cast(size_t)payloadLength;
+
+    // The mask bit was verified above, so the masking key is on the wire even
+    // if the payload is empty. It has to be consumed in any case, otherwise
+    // the next frame header would be read from the wrong position.
+    ubyte[4] maskingKey;
+    if (!readFully(maskingKey[])) {
+      return false;
+    }
+
+    if (length > 0) {
+      readBuffer_ = new ubyte[](length);
+      if (!readFully(readBuffer_)) {
+        // Do not leave a partially filled buffer behind for the next read().
+        readBuffer_ = null;
+        return false;
+      }
+
+      // Unmask the data
+      foreach (i, ref b; readBuffer_) {
+        b ^= maskingKey[i % 4];
+      }
+
+      debug writef("FIN=%d, Opcode=%X, length=%d, payload=%s\n",
+          fin,
+          opcode,
+          length,
+          binary ? readBuffer_.toHexString() : cast(string)readBuffer_);
+    }
+
+    switch (opcode) {
+      case Opcode.Close:
+        debug {
+          if (length >= 2) {
+            CloseCode closeCode;
+            try {
+              closeCode = to!CloseCode(bigEndianToNative!ushort(readBuffer_[0..2]));
+            } catch (ConvException) {
+              closeCode = CloseCode.NoStatusCode;
+            }
+
+            string closeReason;
+            if (length == 2) {
+              closeReason = to!string(cast(CloseCode)closeCode);
+            } else {
+              closeReason = cast(string)readBuffer_[2..$];
+            }
+
+            writef("Connection closed: %d %s\n", closeCode, closeReason);
+          }
+        }
+        // The close payload is not application data.
+        readBuffer_ = null;
+        transport_.close();
+        return false;
+      case Opcode.Ping:
+        pong();
+        return readFrame();
+      default:
+        return true;
+    }
+  }
+
+  /// Forgets the outcome of any previous handshake attempt.
+  void resetHandshake() {
+    connection_ = false;
+    secWebSocketKey_ = false;
+    secWebSocketVersion_ = false;
+    upgrade_ = false;
+  }
+
+  /// Rejects the handshake with a 400 response and closes the transport.
+  void sendBadRequest() {
+    auto header = "HTTP/1.1 400 Bad Request\r\n" ~
+      "Server: Thrift/" ~ VERSION ~ "\r\n" ~
+      "\r\n";
+    transport_.write(cast(const(ubyte[]))header);
+    transport_.flush();
+    transport_.close();
+  }
+
+  /**
+   * Writes the header of an unmasked, final (FIN set) frame.
+   *
+   * Params:
+   *   opcode = Opcode of the frame.
+   *   payloadLength = Number of payload bytes the caller writes afterwards.
+   */
+  void writeFrameHeader(Opcode opcode, size_t payloadLength) {
+    // Two fixed bytes plus an optional 16 or 64 bit extended length.
+    size_t headerSize = 2;
+    if (payloadLength >= 65536) {
+      headerSize += 8;
+    } else if (payloadLength >= 126) {
+      headerSize += 2;
+    }
+    // The server does not mask the response
+
+    ubyte[] header = new ubyte[headerSize];
+    header[0] = opcode;
+    header[0] |= 0x80;
+    if (payloadLength < 126) {
+      header[1] = cast(ubyte)payloadLength;
+    } else if (payloadLength < 65536) {
+      header[1] = 126;
+      header[2..4] = nativeToBigEndian(cast(ushort)payloadLength);
+    } else {
+      header[1] = 127;
+      header[2..10] = nativeToBigEndian(cast(ulong)payloadLength);
+    }
+
+    transport_.write(header);
+  }
+
+  enum WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
+
+  TTransport transport_;
+
+  string acceptKey_;
+  bool connection_;
+  bool secWebSocketKey_;
+  bool secWebSocketVersion_;
+  bool upgrade_;
+  ubyte[] readBuffer_;
+  ubyte[] writeBuffer_;
+}
+
+/**
+ * Transport factory that wraps every transport handed to it in a
+ * TServerWebSocketTransport!binary.
+ */
+class TServerWebSocketTransportFactory(bool binary) : TTransportFactory {
+  /// Returns a new WebSocket server transport layered on top of trans.
+  override TTransport getTransport(TTransport trans) {
+    auto webSocket = new TServerWebSocketTransport!binary(trans);
+    return webSocket;
+  }
+}
+
+/// Factory instantiated with binary = true.
+alias TServerBinaryWebSocketTransportFactory = TServerWebSocketTransportFactory!true;
+/// Factory instantiated with binary = false.
+alias TServerTextWebSocketTransportFactory = TServerWebSocketTransportFactory!false;
+
+/// Status codes carried in the first two payload bytes of a Close frame.
+private enum CloseCode : ushort {
+  NormalClosure       = 1000,
+  GoingAway           = 1001,
+  ProtocolError       = 1002,
+  UnsupportedDataType = 1003,
+  NoStatusCode        = 1005,
+  AbnormalClosure     = 1006,
+  InvalidData         = 1007,
+  PolicyViolation     = 1008,
+  MessageTooBig       = 1009,
+  ExtensionExpected   = 1010,
+  UnexpectedError     = 1011,
+  NotSecure           = 1015,
+}
+
+/// Frame opcodes, stored in the low four bits of the first header byte.
+private enum Opcode : ubyte {
+  Continuation = 0x00,
+  Text         = 0x01,
+  Binary       = 0x02,
+  Close        = 0x08,
+  Ping         = 0x09,
+  Pong         = 0x0A,
+}