You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by je...@apache.org on 2019/11/23 21:41:51 UTC

[thrift] branch master updated: THRIFT-5027 Implement remaining read bytes checks Client: netstd Patch: Jens Geyer

This is an automated email from the ASF dual-hosted git repository.

jensg pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 5080645  THRIFT-5027 Implement remaining read bytes checks Client: netstd Patch: Jens Geyer
5080645 is described below

commit 5080645cb0ffe52de9d82685b2ab2d6b03aa6f3e
Author: Jens Geyer <je...@apache.org>
AuthorDate: Sat Nov 23 01:55:58 2019 +0100

    THRIFT-5027 Implement remaining read bytes checks
    Client: netstd
    Patch: Jens Geyer
    
    This closes #1946
---
 lib/netstd/Thrift/Protocol/TBinaryProtocol.cs      | 30 ++++++++++++++++++--
 lib/netstd/Thrift/Protocol/TCompactProtocol.cs     | 32 ++++++++++++++++++++--
 lib/netstd/Thrift/Protocol/TJSONProtocol.cs        | 25 +++++++++++++++++
 lib/netstd/Thrift/Protocol/TProtocol.cs            | 21 ++++++++++++++
 lib/netstd/Thrift/Protocol/TProtocolDecorator.cs   |  8 ++++++
 .../Transport/Client/TMemoryBufferTransport.cs     |  2 +-
 .../Thrift/Transport/Layered/TBufferedTransport.cs | 11 ++++++++
 .../Thrift/Transport/Layered/TFramedTransport.cs   | 10 +++++++
 .../Thrift/Transport/Layered/TLayeredTransport.cs  |  5 ++++
 lib/netstd/Thrift/Transport/TEndpointTransport.cs  | 26 +++++++++++++-----
 lib/netstd/Thrift/Transport/TTransport.cs          |  2 +-
 test/netstd/Client/Performance/PerformanceTests.cs |  3 +-
 test/netstd/Client/TestClient.cs                   |  6 ++--
 13 files changed, 162 insertions(+), 19 deletions(-)

diff --git a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
index f0772aa..a00c5c1 100644
--- a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
@@ -370,7 +370,7 @@ namespace Thrift.Protocol
                 ValueType = (TType) await ReadByteAsync(cancellationToken),
                 Count = await ReadI32Async(cancellationToken)
             };
-
+            CheckReadBytesAvailable(map);
             return map;
         }
 
@@ -394,7 +394,7 @@ namespace Thrift.Protocol
                 ElementType = (TType) await ReadByteAsync(cancellationToken),
                 Count = await ReadI32Async(cancellationToken)
             };
-
+            CheckReadBytesAvailable(list);
             return list;
         }
 
@@ -418,7 +418,7 @@ namespace Thrift.Protocol
                 ElementType = (TType) await ReadByteAsync(cancellationToken),
                 Count = await ReadI32Async(cancellationToken)
             };
-
+            CheckReadBytesAvailable(set);
             return set;
         }
 
@@ -507,6 +507,7 @@ namespace Thrift.Protocol
             }
 
             var size = await ReadI32Async(cancellationToken);
+            Transport.CheckReadBytesAvailable(size);
             var buf = new byte[size];
             await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
             return buf;
@@ -536,11 +537,34 @@ namespace Thrift.Protocol
                 return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, size);
             }
 
+            Transport.CheckReadBytesAvailable(size);
             var buf = new byte[size];
             await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
             return Encoding.UTF8.GetString(buf, 0, buf.Length);
         }
 
+        // Returns a lower bound for the number of bytes one value of the given type consumes on the wire; a type code not listed below is rejected with a TTransportException
+        public override int GetMinSerializedSize(TType type)
+        {
+            switch (type)
+            {
+                case TType.Stop: return 0;
+                case TType.Void: return 0;
+                case TType.Bool: return sizeof(byte);
+                case TType.Byte: return sizeof(byte);
+                case TType.Double: return sizeof(double);
+                case TType.I16: return sizeof(short);
+                case TType.I32: return sizeof(int);
+                case TType.I64: return sizeof(long);
+                case TType.String: return sizeof(int);  // string length prefix only (empty string)
+                case TType.Struct: return 0;  // empty struct
+                case TType.Map: return sizeof(int);  // element count
+                case TType.Set: return sizeof(int);  // element count
+                case TType.List: return sizeof(int);  // element count
+                default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+            }
+        }
+
         public class Factory : TProtocolFactory
         {
             protected bool StrictRead;
diff --git a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
index 921507c..a8a46f2 100644
--- a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
@@ -590,7 +590,9 @@ namespace Thrift.Protocol
 
             var size = (int) await ReadVarInt32Async(cancellationToken);
             var keyAndValueType = size == 0 ? (byte) 0 : (byte) await ReadByteAsync(cancellationToken);
-            return new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+            var map = new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+            CheckReadBytesAvailable(map);
+            return map;
         }
 
         public override async Task ReadMapEndAsync(CancellationToken cancellationToken)
@@ -703,6 +705,7 @@ namespace Thrift.Protocol
                 return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, length);
             }
 
+            Transport.CheckReadBytesAvailable(length);
             var buf = new byte[length];
             await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
             return Encoding.UTF8.GetString(buf, 0, length);
@@ -718,6 +721,7 @@ namespace Thrift.Protocol
             }
 
             // read data
+            Transport.CheckReadBytesAvailable(length);
             var buf = new byte[length];
             await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
             return buf;
@@ -745,7 +749,9 @@ namespace Thrift.Protocol
             }
 
             var type = GetTType(sizeAndType);
-            return new TList(type, size);
+            var list = new TList(type, size);
+            CheckReadBytesAvailable(list);
+            return list;
         }
 
         public override async Task ReadListEndAsync(CancellationToken cancellationToken)
@@ -856,6 +862,28 @@ namespace Thrift.Protocol
             return (uint) (n << 1) ^ (uint) (n >> 31);
         }
 
+        // Returns a lower bound for the number of bytes one value of the given type consumes on the wire; a type code not listed below is rejected with a TTransportException
+        public override int GetMinSerializedSize(TType type)
+        {
+            switch (type)
+            {
+                case TType.Stop:    return 0;
+                case TType.Void:    return 0;
+                case TType.Bool:   return sizeof(byte);
+                case TType.Double: return 8;  // uses fixedLongToBytes() which always writes 8 bytes
+                case TType.Byte: return sizeof(byte);
+                case TType.I16:     return sizeof(byte);  // zigzag, smallest encoding counted as one byte
+                case TType.I32:     return sizeof(byte);  // zigzag, smallest encoding counted as one byte
+                case TType.I64:     return sizeof(byte);  // zigzag, smallest encoding counted as one byte
+                case TType.String: return sizeof(byte);  // string length (empty string)
+                case TType.Struct:  return 0;             // empty struct
+                case TType.Map:     return sizeof(byte);  // element count
+                case TType.Set:    return sizeof(byte);  // element count
+                case TType.List:    return sizeof(byte);  // element count
+                default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+            }
+        }
+
         public class Factory : TProtocolFactory
         {
             public override TProtocol GetProtocol(TTransport trans)
diff --git a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
index 464bd62..7bc7130 100644
--- a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
@@ -703,6 +703,7 @@ namespace Thrift.Protocol
             map.KeyType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
             map.ValueType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
             map.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+            CheckReadBytesAvailable(map);
             await ReadJsonObjectStartAsync(cancellationToken);
             return map;
         }
@@ -719,6 +720,7 @@ namespace Thrift.Protocol
             await ReadJsonArrayStartAsync(cancellationToken);
             list.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
             list.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+            CheckReadBytesAvailable(list);
             return list;
         }
 
@@ -733,6 +735,7 @@ namespace Thrift.Protocol
             await ReadJsonArrayStartAsync(cancellationToken);
             set.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
             set.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+            CheckReadBytesAvailable(set);
             return set;
         }
 
@@ -782,6 +785,28 @@ namespace Thrift.Protocol
             return await ReadJsonBase64Async(cancellationToken);
         }
 
+        // Returns a lower bound for the number of bytes one value of the given type consumes on the wire; a type code not listed below is rejected with a TTransportException
+        public override int GetMinSerializedSize(TType type)
+        {
+            switch (type)
+            {
+                case TType.Stop: return 0;
+                case TType.Void: return 0;
+                case TType.Bool: return 1;  // written as int
+                case TType.Byte: return 1;  // numeric types are counted with one byte each
+                case TType.Double: return 1;
+                case TType.I16: return 1;
+                case TType.I32: return 1;
+                case TType.I64: return 1;
+                case TType.String: return 2;  // empty string
+                case TType.Struct: return 2;  // empty struct
+                case TType.Map: return 2;  // empty map
+                case TType.Set: return 2;  // empty set
+                case TType.List: return 2;  // empty list
+                default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+            }
+        }
+
         /// <summary>
         ///     Factory for JSON protocol objects
         /// </summary>
diff --git a/lib/netstd/Thrift/Protocol/TProtocol.cs b/lib/netstd/Thrift/Protocol/TProtocol.cs
index dca3f9e..5275c9c 100644
--- a/lib/netstd/Thrift/Protocol/TProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocol.cs
@@ -77,6 +77,27 @@ namespace Thrift.Protocol
             _isDisposed = true;
         }
 
+
+        protected void CheckReadBytesAvailable(TSet set)  // asks the transport whether Count elements of minimum size can still be read
+        {
+            Transport.CheckReadBytesAvailable((long) set.Count * GetMinSerializedSize(set.ElementType));  // long math: an int * int product could wrap around for huge counts and slip past the check
+        }
+
+        protected void CheckReadBytesAvailable(TList list)  // asks the transport whether Count elements of minimum size can still be read
+        {
+            Transport.CheckReadBytesAvailable((long) list.Count * GetMinSerializedSize(list.ElementType));  // long math: an int * int product could wrap around for huge counts and slip past the check
+        }
+
+        protected void CheckReadBytesAvailable(TMap map)  // asks the transport whether Count key/value pairs of minimum size can still be read
+        {
+            var elmSize = GetMinSerializedSize(map.KeyType) + GetMinSerializedSize(map.ValueType);  // minimum size of one key/value pair
+            Transport.CheckReadBytesAvailable((long) map.Count * elmSize);  // long math: an int * int product could wrap around for huge counts and slip past the check
+        }
+
+        // Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
+        public abstract int GetMinSerializedSize(TType type);
+
+
         public virtual async Task WriteMessageBeginAsync(TMessage message)
         {
             await WriteMessageBeginAsync(message, CancellationToken.None);
diff --git a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
index 845c827..b032e83 100644
--- a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
@@ -243,5 +243,13 @@ namespace Thrift.Protocol
         {
             return await _wrappedProtocol.ReadBinaryAsync(cancellationToken);
         }
+
+        // Returns the minimum amount of bytes needed to store the smallest possible instance of TType, as reported by the wrapped protocol.
+        public override int GetMinSerializedSize(TType type)
+        {
+            return _wrappedProtocol.GetMinSerializedSize(type);  // the decorator adds nothing of its own here; the answer comes from the wrapped protocol
+        }
+
+
     }
 }
diff --git a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
index abf8f14..290e50c 100644
--- a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
@@ -41,6 +41,7 @@ namespace Thrift.Transport.Client
         {
             Bytes = (byte[])buf.Clone();
             _bytesUsed = Bytes.Length;
+            UpdateKnownMessageSize(_bytesUsed);
         }
 
         public int Position { get; set; }
@@ -121,7 +122,6 @@ namespace Thrift.Transport.Client
 
         public override ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
         {
-            CheckReadBytesAvailable(length);
             var count = Math.Min(Length - Position, length);
             Buffer.BlockCopy(Bytes, Position, buffer, offset, count);
             Position += count;
diff --git a/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
index 10cec3c..dee52dd 100644
--- a/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
@@ -172,6 +172,17 @@ namespace Thrift.Transport
             await InnerTransport.FlushAsync(cancellationToken);
         }
 
+        public override void CheckReadBytesAvailable(long numBytes)  // checks a pending read of numBytes against ReadBuffer first, then against the inner transport
+        {
+            var buffered = ReadBuffer.Length - ReadBuffer.Position;  // bytes between the current read position and the end of ReadBuffer
+            if (buffered < numBytes)
+            {
+                numBytes -= buffered;  // only the shortfall has to be checked against the inner transport
+                InnerTransport.CheckReadBytesAvailable(numBytes);
+            }
+        }
+
+
         private void CheckNotDisposed()
         {
             if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
index 58b45f7..be1513f 100644
--- a/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
@@ -155,6 +155,16 @@ namespace Thrift.Transport
             WriteBuffer.Seek(0, SeekOrigin.End);
         }
 
+        public override void CheckReadBytesAvailable(long numBytes)  // checks a pending read of numBytes against ReadBuffer first, then against the inner transport
+        {
+            var buffered = ReadBuffer.Length - ReadBuffer.Position;  // bytes between the current read position and the end of ReadBuffer
+            if (buffered < numBytes)
+            {
+                numBytes -= buffered;  // only the shortfall has to be checked against the inner transport
+                InnerTransport.CheckReadBytesAvailable(numBytes);
+            }
+        }
+
         private void CheckNotDisposed()
         {
             if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
index 59d98ff..2137ae4 100644
--- a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
@@ -19,5 +19,10 @@ namespace Thrift.Transport
         {
             InnerTransport.UpdateKnownMessageSize(size);
         }
+
+        public override void CheckReadBytesAvailable(long numBytes)  // default for layered transports: forward the check unchanged
+        {
+            InnerTransport.CheckReadBytesAvailable(numBytes);  // the inner transport decides whether numBytes can be satisfied
+        }
     }
 }
diff --git a/lib/netstd/Thrift/Transport/TEndpointTransport.cs b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
index 810f3f4..fa2ac6b 100644
--- a/lib/netstd/Thrift/Transport/TEndpointTransport.cs
+++ b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
@@ -9,6 +9,7 @@ namespace Thrift.Transport
     abstract public class TEndpointTransport : TTransport
     {
         protected long MaxMessageSize { get => Configuration.MaxMessageSize; }
+        protected long KnownMessageSize { get; private set; }
         protected long RemainingMessageSize { get; private set; }
 
         private readonly TConfiguration _configuration;
@@ -25,22 +26,33 @@ namespace Thrift.Transport
         /// <summary>
         /// Resets RemainingMessageSize to the configured maximum 
         /// </summary>
-        protected void ResetConsumedMessageSize(long knownSize = -1)
+        protected void ResetConsumedMessageSize(long newSize = -1)  // newSize < 0 (the default) restores MaxMessageSize; otherwise newSize becomes the known and the remaining size
         {
-            if(knownSize >= 0)
-                RemainingMessageSize = Math.Min( MaxMessageSize, knownSize);
-            else
+            // full reset: both the known and the remaining size go back to MaxMessageSize
+            if (newSize < 0)
+            {
+                KnownMessageSize = MaxMessageSize;
                 RemainingMessageSize = MaxMessageSize;
+                return;
+            }
+
+            // update only: message size can shrink, but not grow
+            Debug.Assert(KnownMessageSize <= MaxMessageSize);
+            if (newSize > KnownMessageSize)  // a size above the currently known one is rejected
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+
+            KnownMessageSize = newSize;
+            RemainingMessageSize = newSize;  // remaining is set to the full new size; nothing is subtracted for consumed bytes here
         }
 
         /// <summary>
         /// Updates RemainingMessageSize to reflect then known real message size (e.g. framed transport).
-        /// Will throw if we already consumed too many bytes.
+        /// Will throw if we already consumed too many bytes or if the new size is larger than allowed.
         /// </summary>
         /// <param name="size"></param>
         public override void UpdateKnownMessageSize(long size)
         {
-            var consumed = MaxMessageSize - RemainingMessageSize;
+            var consumed = KnownMessageSize - RemainingMessageSize;  // difference between the currently known size and what is still remaining
             ResetConsumedMessageSize(size);
             CountConsumedMessageBytes(consumed);
         }
@@ -49,7 +61,7 @@ namespace Thrift.Transport
         /// Throws if there are not enough bytes in the input stream to satisfy a read of numBytes bytes of data
         /// </summary>
         /// <param name="numBytes"></param>
-        protected void CheckReadBytesAvailable(long numBytes)
+        public override void CheckReadBytesAvailable(long numBytes)
         {
             if (RemainingMessageSize < numBytes)
                 throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
diff --git a/lib/netstd/Thrift/Transport/TTransport.cs b/lib/netstd/Thrift/Transport/TTransport.cs
index 8f510dd..dedd51d 100644
--- a/lib/netstd/Thrift/Transport/TTransport.cs
+++ b/lib/netstd/Thrift/Transport/TTransport.cs
@@ -34,7 +34,7 @@ namespace Thrift.Transport
         public abstract bool IsOpen { get; }
         public abstract TConfiguration Configuration { get; }
         public abstract void UpdateKnownMessageSize(long size);
-
+        public abstract void CheckReadBytesAvailable(long numBytes);
         public void Dispose()
         {
             Dispose(true);
diff --git a/test/netstd/Client/Performance/PerformanceTests.cs b/test/netstd/Client/Performance/PerformanceTests.cs
index 05c64b2..2c79aa6 100644
--- a/test/netstd/Client/Performance/PerformanceTests.cs
+++ b/test/netstd/Client/Performance/PerformanceTests.cs
@@ -68,10 +68,9 @@ namespace Client.Tests
             foreach (var layered in Enum.GetValues(typeof(LayeredChoice)))
             {
                 Layered = (LayeredChoice)layered;
-
                 await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TBinaryProtocol>(b); });
                 await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TCompactProtocol>(b); });
-                //await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TJsonProtocol>(b); });
+                await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TJsonProtocol>(b); });
             }
         }
 
diff --git a/test/netstd/Client/TestClient.cs b/test/netstd/Client/TestClient.cs
index 0c147dc..3eab865 100644
--- a/test/netstd/Client/TestClient.cs
+++ b/test/netstd/Client/TestClient.cs
@@ -446,7 +446,7 @@ namespace ThriftTest
             Normal,          // Fairly small array of usual size (256 bytes)
             Large,           // Large writes/reads may cause range check errors
             PipeWriteLimit,  // Windows Limit: Pipe write operations across a network are limited to 65,535 bytes per write.
-            TwentyMB         // that's quite a bit of data
+            FifteenMB        // that's quite a bit of data
         };
 
         public static byte[] PrepareTestData(bool randomDist, BinaryTestSize testcase)
@@ -466,8 +466,8 @@ namespace ThriftTest
                 case BinaryTestSize.PipeWriteLimit:
                     amount = 0xFFFF + 128;
                     break;
-                case BinaryTestSize.TwentyMB:
-                    amount = 20 * 1024 * 1024;
+                case BinaryTestSize.FifteenMB:
+                    amount = 15 * 1024 * 1024;
                     break;
                 default:
                     throw new ArgumentException(nameof(testcase));