You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@livy.apache.org by GitBox <gi...@apache.org> on 2020/03/10 01:05:50 UTC

[GitHub] [incubator-livy] squito commented on a change in pull request #284: [LIVY-752][THRIFT] Fix implementation of limits on connections.

squito commented on a change in pull request #284: [LIVY-752][THRIFT] Fix implementation of limits on connections.
URL: https://github.com/apache/incubator-livy/pull/284#discussion_r390042082
 
 

 ##########
 File path: thriftserver/server/src/main/scala/org/apache/livy/thriftserver/LivyThriftSessionManager.scala
 ##########
 @@ -455,21 +443,58 @@ class LivyThriftSessionManager(val server: LivyThriftServer, val livyConf: LivyC
     }
   }
 
-  // Taken from Hive
-  private def anyViolations(username: String, ipAddress: String): Option[String] = {
-    val userAndAddress = username + ":" + ipAddress
-    if (trackConnectionsPerUser(username) && !withinLimits(username, userLimit)) {
-      Some(s"Connection limit per user reached (user: $username limit: $userLimit)")
-    } else if (trackConnectionsPerIpAddress(ipAddress) &&
-        !withinLimits(ipAddress, ipAddressLimit)) {
-      Some(s"Connection limit per ipaddress reached (ipaddress: $ipAddress limit: " +
-        s"$ipAddressLimit)")
-    } else if (trackConnectionsPerUserIpAddress(username, ipAddress) &&
-        !withinLimits(userAndAddress, userIpAddressLimit)) {
-      Some(s"Connection limit per user:ipaddress reached (user:ipaddress: $userAndAddress " +
-        s"limit: $userIpAddressLimit)")
-    } else {
-      None
+  // Visible for testing
+  @throws[HiveSQLException]
+  private[thriftserver] def incrementConnections(
+      username: String,
+      ipAddress: String,
+      forwardedAddresses: util.List[String]): Unit = {
+    val clientIpAddress: String = getOriginClientIpAddress(ipAddress, forwardedAddresses)
+    val userAndAddress = username + ":" + clientIpAddress
+    val trackUser = trackConnectionsPerUser(username)
+    val trackIpAddress = trackConnectionsPerIpAddress(clientIpAddress)
+    val trackUserIpAddress = trackConnectionsPerUserIpAddress(username, clientIpAddress)
+    var userLimitExceeded = false
+    var ipAddressLimitExceeded = false
+    var userIpAddressLimitExceeded = false
+
+    // Optimistically increment the counts while getting them to check for violations.
+    if (trackUser) {
+      val userCount = incrementConnectionsCount(username)
+      if (userCount > userLimit) userLimitExceeded = true
+    }
+    if (trackIpAddress) {
+      val ipAddressCount = incrementConnectionsCount(clientIpAddress)
+      if (ipAddressCount > ipAddressLimit) ipAddressLimitExceeded = true
+    }
+    if (trackUserIpAddress) {
+      val userIpAddressCount = incrementConnectionsCount(userAndAddress)
+      if (userIpAddressCount > userIpAddressLimit) userIpAddressLimitExceeded = true
+    }
+
+    // If any limit has been exceeded, we won't be going ahead with the connection,
+    // so decrement all counts that have been incremented.
+    if (userLimitExceeded || ipAddressLimitExceeded || userIpAddressLimitExceeded) {
+      if (trackUser) decrementConnectionsCount(username)
 
 Review comment:
   I think Marco is suggesting that you include the throw in there as well. That reduces the complexity of the conditions and lets you drop the `userLimitExceeded`-style vars and the entire `val violation` section below (though you still repeat the `decrementCounts` part). For example:
   
   ```scala
   /**
    * Reports `msg` via `error(...)` (assumed to be the enclosing class's logger) and then
    * throws it as a [[HiveSQLException]].
    *
    * The result type is `Nothing` rather than `Unit` because this method never returns
    * normally. That lets the compiler treat a call as terminating and allows it to be used
    * in expression position (e.g. as the `else` branch of an `if` that yields a value).
    *
    * @param msg the message to log and to carry in the thrown exception
    * @throws HiveSQLException always
    */
   def errorAndThrow(msg: String): Nothing = {
     error(msg)
     throw new HiveSQLException(msg)
   }
   
   
       // Optimistically increment each tracked counter. On the first limit violation, roll
       // back every counter incremented so far (including the one that just overflowed)
       // before failing, so a rejected connection leaves no residual counts behind.
       if (trackUser) {
         val userCount = incrementConnectionsCount(username)
         if (userCount > userLimit) {
           decrementConnectionsCount(username)
           errorAndThrow(s"Connection limit per user reached (user: $username limit: $userLimit)")
         }
       }
       if (trackIpAddress) {
         val ipAddressCount = incrementConnectionsCount(clientIpAddress)
         if (ipAddressCount > ipAddressLimit) {
           if (trackUser) decrementConnectionsCount(username)
           decrementConnectionsCount(clientIpAddress)
           errorAndThrow(s"Connection limit per ipaddress reached (ipaddress: " +
             s"$clientIpAddress limit: $ipAddressLimit)")
         }
       }
       if (trackUserIpAddress) {
         val userIpAddressCount = incrementConnectionsCount(userAndAddress)
         if (userIpAddressCount > userIpAddressLimit) {
           if (trackUser) decrementConnectionsCount(username)
           if (trackIpAddress) decrementConnectionsCount(clientIpAddress)
           decrementConnectionsCount(userAndAddress)
           errorAndThrow(s"Connection limit per user:ipaddress reached (user:ipaddress: " +
             s"$userAndAddress limit: $userIpAddressLimit)")
         }
       }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services