You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by se...@apache.org on 2017/08/25 18:53:08 UTC

hive git commit: HIVE-17330 : refactor TezSessionPoolManager to separate its multiple functions (Sergey Shelukhin, reviewed by Gunther Hagleitner)

Repository: hive
Updated Branches:
  refs/heads/master 6b1038288 -> 733bc5f02


HIVE-17330 : refactor TezSessionPoolManager to separate its multiple functions (Sergey Shelukhin, reviewed by Gunther Hagleitner)


Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/733bc5f0
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/733bc5f0
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/733bc5f0

Branch: refs/heads/master
Commit: 733bc5f0280f0f8145b953d67bb50e34af713e04
Parents: 6b10382
Author: sergey <se...@apache.org>
Authored: Fri Aug 25 11:52:09 2017 -0700
Committer: sergey <se...@apache.org>
Committed: Fri Aug 25 11:53:05 2017 -0700

----------------------------------------------------------------------
 .../ql/exec/tez/RestrictedConfigChecker.java    |  79 +++
 .../ql/exec/tez/SessionExpirationTracker.java   | 235 +++++++++
 .../hadoop/hive/ql/exec/tez/TezSessionPool.java | 165 ++++++
 .../hive/ql/exec/tez/TezSessionPoolManager.java | 516 +++----------------
 .../hive/ql/exec/tez/TezSessionPoolSession.java | 158 ++++++
 .../hive/ql/exec/tez/SampleTezSessionState.java |   3 +-
 .../hive/ql/exec/tez/TestTezSessionPool.java    |   2 +-
 7 files changed, 704 insertions(+), 454 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/RestrictedConfigChecker.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/RestrictedConfigChecker.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/RestrictedConfigChecker.java
new file mode 100644
index 0000000..f6b1c1d
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/RestrictedConfigChecker.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Checks that a query configuration does not change any setting listed in
+ * {@link ConfVars#HIVE_SERVER2_TEZ_SESSION_RESTRICTED_CONFIGS}. Restricted settings must have
+ * the same value as in the configuration this checker was created with.
+ */
+class RestrictedConfigChecker {
+  private final static Logger LOG = LoggerFactory.getLogger(RestrictedConfigChecker.class);
+  /** Restricted settings that were resolved to a Hive {@link ConfVars} entry. */
+  private final List<ConfVars> restrictedHiveConf = new ArrayList<>();
+  /** Restricted settings not recognized as Hive settings; compared as raw strings. */
+  private final List<String> restrictedNonHiveConf = new ArrayList<>();
+  /** Server configuration; source of the reference values for restricted settings. */
+  private final HiveConf initConf;
+
+  /**
+   * @param initConf server configuration; provides both the list of restricted settings and the
+   *     values that query configurations must match.
+   */
+  RestrictedConfigChecker(HiveConf initConf) {
+    this.initConf = initConf;
+    String[] restrictedConfigs = HiveConf.getTrimmedStringsVar(initConf,
+        ConfVars.HIVE_SERVER2_TEZ_SESSION_RESTRICTED_CONFIGS);
+    if (restrictedConfigs == null || restrictedConfigs.length == 0) return;
+    HashMap<String, ConfVars> confVars = HiveConf.getOrCreateReverseMap();
+    for (String confName : restrictedConfigs) {
+      if (confName == null || confName.isEmpty()) continue;
+      // Lower-case with a fixed locale; a default locale such as Turkish would map 'I' to a
+      // dotless 'ı' and make the reverse-map lookup below silently miss.
+      confName = confName.toLowerCase(java.util.Locale.ROOT);
+      ConfVars cv = confVars.get(confName);
+      if (cv != null) {
+        restrictedHiveConf.add(cv);
+      } else {
+        LOG.warn("A restricted config " + confName + " is not recognized as a Hive setting.");
+        restrictedNonHiveConf.add(confName);
+      }
+    }
+  }
+
+  /**
+   * Verifies that every restricted setting in {@code conf} equals the server's value.
+   *
+   * @param conf the query configuration to check.
+   * @throws HiveException if any restricted setting differs from the server configuration.
+   */
+  public void validate(HiveConf conf) throws HiveException {
+    for (ConfVars var : restrictedHiveConf) {
+      String userValue = HiveConf.getVarWithoutType(conf, var),
+          serverValue = HiveConf.getVarWithoutType(initConf, var);
+      // Note: with some trickery, we could add logic for each type in ConfVars; for now the
+      // potential spurious mismatches (e.g. 0 and 0.0 for float) should be easy to work around.
+      validateRestrictedConfigValues(var.varname, userValue, serverValue);
+    }
+    for (String var : restrictedNonHiveConf) {
+      String userValue = conf.get(var), serverValue = initConf.get(var);
+      validateRestrictedConfigValues(var, userValue, serverValue);
+    }
+  }
+
+  /**
+   * Throws if the two values differ; two nulls are considered equal.
+   */
+  private void validateRestrictedConfigValues(
+      String var, String userValue, String serverValue) throws HiveException {
+    if ((userValue == null) != (serverValue == null)
+        || (userValue != null && !userValue.equals(serverValue))) {
+      // Hidden configs are presumably sensitive (e.g. passwords), so neither the server value
+      // nor the value the query tried to set may end up in the exception message / logs.
+      boolean isHidden = initConf.isHiddenConfig(var);
+      String logServerValue = isHidden ? "(hidden)" : serverValue;
+      String logUserValue = isHidden ? "(hidden)" : userValue;
+      throw new HiveException(var + " is restricted from being set; server is configured"
+          + " to use " + logServerValue + ", but the query configuration specifies "
+          + logUserValue);
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SessionExpirationTracker.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SessionExpirationTracker.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SessionExpirationTracker.java
new file mode 100644
index 0000000..8bee77e
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/SessionExpirationTracker.java
@@ -0,0 +1,235 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.util.Comparator;
+import java.util.Random;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.PriorityBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tracks the lifetime of pool sessions and restarts them once they expire.
+ *
+ * <p>Two background threads are created, and they are started by {@link #start()}:
+ * <ul>
+ * <li>The expiration thread waits for the earliest expiration time in the expiration queue and
+ * calls {@code tryExpire(false)} on the due sessions.</li>
+ * <li>The restart thread restarts sessions handed over by
+ * {@link #closeAndRestartExpiredSession} with {@code isAsync == true}.</li>
+ * </ul>
+ * The actual close/reopen work is delegated to a {@link RestartImpl}.
+ */
+class SessionExpirationTracker {
+  private static final Logger LOG = LoggerFactory.getLogger(SessionExpirationTracker.class);
+  private static final Random rdm = new Random();
+
+  /** Priority queue sorted by expiration time of live sessions that could be expired. */
+  private final PriorityBlockingQueue<TezSessionPoolSession> expirationQueue;
+  /** The background restart queue that is populated when expiration is triggered by a foreground
+   * thread (i.e. getting or returning a session), to avoid delaying it. */
+  private final BlockingQueue<TezSessionPoolSession> restartQueue;
+  private final Thread expirationThread;
+  private final Thread restartThread;
+  /** Base session lifetime, in milliseconds. */
+  private final long sessionLifetimeMs;
+  /** Bound of the random extra lifetime added to each session, in milliseconds. */
+  private final long sessionLifetimeJitterMs;
+  private final RestartImpl sessionRestartImpl;
+
+  /** Callback that performs the actual close-and-reopen of an expired pool session. */
+  interface RestartImpl {
+    void closeAndReopenPoolSession(TezSessionPoolSession session) throws Exception;
+  }
+
+  /**
+   * Creates a tracker from the session lifetime settings in {@code conf}.
+   *
+   * @param conf configuration providing the session lifetime and lifetime jitter.
+   * @param restartImpl callback used to close and reopen expired sessions.
+   * @return the tracker, or {@code null} when the configured session lifetime is 0
+   *     (session expiration disabled).
+   */
+  public static SessionExpirationTracker create(HiveConf conf, RestartImpl restartImpl) {
+    long sessionLifetimeMs = conf.getTimeVar(
+        ConfVars.HIVE_SERVER2_TEZ_SESSION_LIFETIME, TimeUnit.MILLISECONDS);
+    if (sessionLifetimeMs == 0) return null;
+    return new SessionExpirationTracker(sessionLifetimeMs, conf.getTimeVar(
+        ConfVars.HIVE_SERVER2_TEZ_SESSION_LIFETIME_JITTER, TimeUnit.MILLISECONDS), restartImpl);
+  }
+
+  private SessionExpirationTracker(
+      long sessionLifetimeMs, long sessionLifetimeJitterMs, RestartImpl restartImpl) {
+    this.sessionRestartImpl = restartImpl;
+    this.sessionLifetimeMs = sessionLifetimeMs;
+    this.sessionLifetimeJitterMs = sessionLifetimeJitterMs;
+
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Session expiration is enabled; session lifetime is "
+          + sessionLifetimeMs + " + [0, " + sessionLifetimeJitterMs + ") ms");
+    }
+    // Captured here so that the background threads run with the creator's SessionState.
+    final SessionState initSessionState = SessionState.get();
+    expirationQueue = new PriorityBlockingQueue<>(11, new Comparator<TezSessionPoolSession>() {
+      @Override
+      public int compare(TezSessionPoolSession o1, TezSessionPoolSession o2) {
+        assert o1.getExpirationNs() != null && o2.getExpirationNs() != null;
+        return o1.getExpirationNs().compareTo(o2.getExpirationNs());
+      }
+    });
+    restartQueue = new LinkedBlockingQueue<>();
+
+    // Threads are only created here; start() starts them.
+    expirationThread = new Thread(new Runnable() {
+      @Override
+      public void run() {
+        try {
+          SessionState.setCurrentSessionState(initSessionState);
+          runExpirationThread();
+        } catch (Exception e) {
+          LOG.warn("Exception in TezSessionPool-expiration thread. Thread will shut down", e);
+        } finally {
+          LOG.info("TezSessionPool-expiration thread exiting");
+        }
+      }
+    }, "TezSessionPool-expiration");
+    restartThread = new Thread(new Runnable() {
+      @Override
+      public void run() {
+        try {
+          SessionState.setCurrentSessionState(initSessionState);
+          runRestartThread();
+        } catch (Exception e) {
+          LOG.warn("Exception in TezSessionPool-cleanup thread. Thread will shut down", e);
+        } finally {
+          LOG.info("TezSessionPool-cleanup thread exiting");
+        }
+      }
+    }, "TezSessionPool-cleanup");
+  }
+
+
+  /** Logic for the thread that restarts the sessions expired during foreground operations. */
+  private void runRestartThread() {
+    try {
+      while (true) {
+        TezSessionPoolSession next = restartQueue.take();
+        LOG.info("Restarting the expired session [" + next + "]");
+        try {
+          sessionRestartImpl.closeAndReopenPoolSession(next);
+        } catch (InterruptedException ie) {
+          throw ie;
+        } catch (Exception e) {
+          // A single failed restart must not kill the thread; keep serving the queue.
+          LOG.error("Failed to close or restart a session, ignoring", e);
+        }
+      }
+    } catch (InterruptedException e) {
+      LOG.info("Restart thread is exiting due to an interruption");
+    }
+  }
+
+  /** Logic for the thread that tracks session expiration and restarts sessions in background. */
+  private void runExpirationThread() {
+    try {
+      while (true) {
+        TezSessionPoolSession nextToExpire = null;
+        while (true) {
+          // Expire sessions in expiration order until one reports it cannot be expired yet.
+          nextToExpire = expirationQueue.take();
+          if (LOG.isDebugEnabled()) {
+            LOG.debug("Seeing if we can expire [" + nextToExpire + "]");
+          }
+          try {
+            if (!nextToExpire.tryExpire(false)) break;
+          } catch (Exception e) {
+            // Reopen happens even when close fails, so there's not much to do here.
+            LOG.error("Failed to expire session " + nextToExpire + "; ignoring", e);
+            nextToExpire = null;
+            break; // Not strictly necessary; do the whole queue check again.
+          }
+          LOG.info("Tez session [" + nextToExpire + "] has expired");
+        }
+        if (nextToExpire != null && LOG.isDebugEnabled()) {
+          LOG.debug("[" + nextToExpire + "] is not ready to expire; adding it back");
+        }
+
+        // See addToExpirationQueue for why we re-check the queue.
+        synchronized (expirationQueue) {
+          // Add back the non-expired session. No need to notify, we are the only ones waiting.
+          if (nextToExpire != null) {
+            expirationQueue.add(nextToExpire);
+          }
+          nextToExpire = expirationQueue.peek();
+          if (nextToExpire != null) {
+            // Add some margin to the wait to avoid rechecking close to the boundary.
+            long timeToWaitMs = (nextToExpire.getExpirationNs() - System.nanoTime()) / 1000000L;
+            timeToWaitMs = Math.max(1, timeToWaitMs + 10);
+            if (LOG.isDebugEnabled()) {
+              LOG.debug("Waiting for ~" + timeToWaitMs + "ms to expire [" + nextToExpire + "]");
+            }
+            expirationQueue.wait(timeToWaitMs);
+          } else if (LOG.isDebugEnabled()) {
+            // Don't wait if empty - go to take() above, that will wait for us.
+            LOG.debug("Expiration queue is empty");
+          }
+        }
+      }
+    } catch (InterruptedException e) {
+      LOG.info("Expiration thread is exiting due to an interruption");
+    }
+  }
+
+
+  /** Starts the expiration and restart background threads. */
+  public void start() {
+    expirationThread.start();
+    restartThread.start();
+  }
+
+
+  /** Interrupts both background threads; each exits once it observes the interruption. */
+  public void stop() {
+    // Both threads are final and always created by the constructor, so no null checks needed.
+    expirationThread.interrupt();
+    restartThread.interrupt();
+  }
+
+
+  /**
+   * Assigns {@code session} an expiration time and adds it to the expiration queue.
+   *
+   * <p>The expiration time is now + session lifetime + a random jitter in
+   * [0, sessionLifetimeJitterMs).
+   */
+  public void addToExpirationQueue(TezSessionPoolSession session) {
+    long jitterModMs = (long)(sessionLifetimeJitterMs * rdm.nextFloat());
+    session.setExpirationNs(System.nanoTime() + (sessionLifetimeMs + jitterModMs) * 1000000L);
+    if (LOG.isDebugEnabled()) {
+      // Log the session being added (not the tracker itself).
+      LOG.debug("Adding a pool session [" + session + "] to expiration queue");
+    }
+    // Expiration queue is synchronized and notified upon when adding elements. Without jitter, we
+    // wouldn't need this, and could simply look at the first element and sleep for the wait time.
+    // However, when many things are added at once, it may happen that we will see the one that
+    // expires later first, and will sleep past the earlier expiration times. When we wake up we
+    // may kill many sessions at once. To avoid this, we will add to queue under lock and recheck
+    // time before we wait. We don't have to worry about removals; at worst we'd wake up in vain.
+    // Example: expirations of 1:03:00, 1:00:00, 1:02:00 are added (in this order due to jitter).
+    // If the expiration thread sees that 1:03 first, it will sleep for 1:03, then wake up and
+    // kill all 3 sessions at once because they all have expired, removing any effect from jitter.
+    // Instead, expiration thread rechecks the first queue item and waits on the queue. If nothing
+    // is added to the queue, the item examined is still the earliest to be expired. If someone
+    // adds to the queue while it is waiting, it will notify the thread and it would wake up and
+    // recheck the queue.
+    synchronized (expirationQueue) {
+      expirationQueue.add(session);
+      expirationQueue.notifyAll();
+    }
+  }
+
+
+  /** Removes {@code session} from the expiration queue, if present. */
+  public void removeFromExpirationQueue(TezSessionPoolSession session) {
+    expirationQueue.remove(session);
+  }
+
+  /**
+   * Closes and reopens an expired session.
+   *
+   * @param isAsync if true, the session is queued for the background restart thread and this
+   *     call returns immediately; otherwise the restart runs on the calling thread.
+   * @throws Exception only in the synchronous case, propagated from the restart callback.
+   */
+  public void closeAndRestartExpiredSession(
+      TezSessionPoolSession session, boolean isAsync) throws Exception {
+    if (isAsync) {
+      restartQueue.add(session);
+    } else {
+      sessionRestartImpl.closeAndReopenPoolSession(session);
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPool.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPool.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPool.java
new file mode 100644
index 0000000..4f58565
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPool.java
@@ -0,0 +1,165 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.io.IOException;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.BlockingDeque;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
+import org.apache.hadoop.hive.ql.session.SessionState;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Distinct from TezSessionPoolManager in that it implements a session pool, and nothing else.
+ *
+ * <p>Sessions are first registered via {@link #addInitialSession} and opened by
+ * {@link #startInitialSessions}, after which usable ones are placed into a bounded deque that
+ * {@link #getSession} / {@link #returnSession} draw from and return to.
+ */
+class TezSessionPool {
+  private static final Logger LOG = LoggerFactory.getLogger(TezSessionPool.class);
+
+  /** A queue for initial sessions that have not been started yet. */
+  private final Queue<TezSessionPoolSession> initialSessions =
+      new ConcurrentLinkedQueue<TezSessionPoolSession>();
+
+  private final HiveConf initConf;
+  /** Open sessions available for use; capacity is the total number of pool sessions. */
+  private final BlockingDeque<TezSessionPoolSession> defaultQueuePool;
+
+  TezSessionPool(HiveConf initConf, int numSessionsTotal) {
+    this.initConf = initConf;
+    assert numSessionsTotal > 0;
+    defaultQueuePool = new LinkedBlockingDeque<TezSessionPoolSession>(numSessionsTotal);
+  }
+
+  /**
+   * Opens all sessions registered via {@link #addInitialSession}, using up to
+   * HIVE_SERVER2_TEZ_SESSION_MAX_INIT_THREADS threads (the calling thread included).
+   *
+   * @throws Exception the first error encountered while starting a session; after an error the
+   *     remaining sessions are, on a best-effort basis, not started.
+   */
+  void startInitialSessions() throws Exception {
+    if (initialSessions.isEmpty()) return;
+    int threadCount = Math.min(initialSessions.size(),
+        HiveConf.getIntVar(initConf, ConfVars.HIVE_SERVER2_TEZ_SESSION_MAX_INIT_THREADS));
+    Preconditions.checkArgument(threadCount > 0);
+    if (threadCount == 1) {
+      while (true) {
+        TezSessionPoolSession session = initialSessions.poll();
+        if (session == null) break;
+        startInitialSession(session);
+      }
+    } else {
+      final SessionState parentSessionState = SessionState.get();
+      // The runnable has no mutable state, so each thread can run the same thing.
+      final AtomicReference<Exception> firstError = new AtomicReference<>(null);
+      Runnable runnable = new Runnable() {
+        @Override
+        public void run() {
+          if (parentSessionState != null) {
+            SessionState.setCurrentSessionState(parentSessionState);
+          }
+          while (true) {
+            TezSessionPoolSession session = initialSessions.poll();
+            if (session == null) break;
+            if (firstError.get() != null) break; // Best-effort.
+            try {
+              startInitialSession(session);
+            } catch (Exception e) {
+              if (!firstError.compareAndSet(null, e)) {
+                LOG.error("Failed to start session; ignoring due to previous error", e);
+              }
+              break;
+            }
+          }
+        }
+      };
+      // The calling thread acts as one of the workers, hence threadCount - 1 extra threads.
+      Thread[] threads = new Thread[threadCount - 1];
+      for (int i = 0; i < threads.length; ++i) {
+        threads[i] = new Thread(runnable, "Tez session init " + i);
+        threads[i].start();
+      }
+      runnable.run();
+      for (int i = 0; i < threads.length; ++i) {
+        threads[i].join();
+      }
+      Exception ex = firstError.get();
+      if (ex != null) {
+        throw ex;
+      }
+    }
+  }
+
+  /** Registers a session to be opened by {@link #startInitialSessions}. */
+  void addInitialSession(TezSessionPoolSession session) {
+    initialSessions.add(session);
+  }
+
+  /**
+   * Blocks until a pooled session is available for which {@code tryUse()} succeeds.
+   * Sessions for which {@code tryUse()} fails are not put back into the pool here.
+   */
+  TezSessionState getSession() throws Exception {
+    while (true) {
+      TezSessionPoolSession result = defaultQueuePool.take();
+      if (result.tryUse()) return result;
+      LOG.info("Couldn't use a session [" + result + "]; attempting another one");
+    }
+  }
+
+  /**
+   * Detaches the Tez session from the current SessionState (if any) and, if
+   * {@code returnAfterUse()} allows it, puts the session back at the head of the pool.
+   */
+  void returnSession(TezSessionPoolSession session) throws Exception {
+    // TODO: should this be in pool, or pool manager? Probably common to all the use cases.
+    SessionState sessionState = SessionState.get();
+    if (sessionState != null) {
+      sessionState.setTezSession(null);
+    }
+    if (session.returnAfterUse()) {
+      defaultQueuePool.putFirst(session);
+    }
+  }
+
+  /**
+   * Closes {@code oldSession}, removes it from the pool, and opens {@code newSession} with the
+   * old session's configuration, additional files, scratch dir and queue name. The new session
+   * is opened and added to the pool even if closing the old one fails.
+   */
+  void replaceSession(
+      TezSessionPoolSession oldSession, TezSessionPoolSession newSession) throws Exception {
+    // Retain the stuff from the old session.
+    // Re-setting the queue config is an old hack that we may remove in future.
+    Path scratchDir = oldSession.getTezScratchDir();
+    Set<String> additionalFiles = oldSession.getAdditionalFilesNotFromConf();
+    HiveConf conf = oldSession.getConf();
+    String queueName = oldSession.getQueueName();
+    try {
+      oldSession.close(false);
+      boolean wasRemoved = defaultQueuePool.remove(oldSession);
+      if (!wasRemoved) {
+        // Use a placeholder; without one SLF4J silently drops the session from the message.
+        LOG.error("Old session was closed but it was not in the pool: {}", oldSession);
+      }
+    } finally {
+      // There's some bogus code that can modify the queue name. Force-set it for pool sessions.
+      // TODO: this might only be applicable to TezSessionPoolManager; try moving it there?
+      conf.set(TezConfiguration.TEZ_QUEUE_NAME, queueName);
+      newSession.open(conf, additionalFiles, scratchDir);
+      // NOTE(review): put() blocks while the deque is full; if close() threw before the old
+      // session was removed and the pool is at capacity, this could block indefinitely.
+      defaultQueuePool.put(newSession);
+    }
+  }
+
+  /**
+   * Opens one initial session on its own queue and adds it to the pool if
+   * {@code returnAfterUse()} allows it.
+   *
+   * @throws IOException if the session cannot be marked as in use.
+   */
+  private void startInitialSession(TezSessionPoolSession sessionState) throws Exception {
+    HiveConf newConf = new HiveConf(initConf);
+    boolean isUsable = sessionState.tryUse();
+    if (!isUsable) throw new IOException(sessionState + " is not usable at pool startup");
+    newConf.set(TezConfiguration.TEZ_QUEUE_NAME, sessionState.getQueueName());
+    sessionState.open(newConf);
+    if (sessionState.returnAfterUse()) {
+      defaultQueuePool.put(sessionState);
+    }
+  }
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolManager.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolManager.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolManager.java
index 7c94002..1f4705c 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolManager.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolManager.java
@@ -21,45 +21,23 @@ package org.apache.hadoop.hive.ql.exec.tez;
 import org.apache.hadoop.conf.Configuration;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-
-import java.io.IOException;
-import java.net.URISyntaxException;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.BlockingDeque;
-import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.LinkedBlockingDeque;
-import java.util.concurrent.PriorityBlockingQueue;
-import java.util.concurrent.LinkedBlockingQueue;
+
 import java.util.concurrent.Semaphore;
-import java.util.concurrent.TimeUnit;
 import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Comparator;
-import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
-import java.util.Queue;
 import java.util.Random;
 import java.util.Set;
 
-import javax.security.auth.login.LoginException;
-
 import org.apache.tez.dag.api.TezConfiguration;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.hive.conf.HiveConf;
 import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
+import org.apache.hadoop.hive.ql.exec.tez.TezSessionPoolSession.OpenSessionTracker;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.session.SessionState;
-import org.apache.hadoop.hive.ql.session.SessionState.LogHelper;
 import org.apache.hadoop.hive.shims.Utils;
 import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.tez.dag.api.TezException;
 
 /**
  * This class is for managing multiple tez sessions particularly when
@@ -68,7 +46,8 @@ import org.apache.tez.dag.api.TezException;
  * In case the user specifies a queue explicitly, a new session is created
  * on that queue and assigned to the session state.
  */
-public class TezSessionPoolManager {
+public class TezSessionPoolManager
+  implements SessionExpirationTracker.RestartImpl, OpenSessionTracker {
 
   private enum CustomQueueAllowed {
     TRUE,
@@ -77,127 +56,55 @@ public class TezSessionPoolManager {
   }
 
   private static final Logger LOG = LoggerFactory.getLogger(TezSessionPoolManager.class);
-  private static final Random rdm = new Random();
-
-  private volatile SessionState initSessionState;
-  private BlockingDeque<TezSessionPoolSession> defaultQueuePool;
-
-  /** Priority queue sorted by expiration time of live sessions that could be expired. */
-  private PriorityBlockingQueue<TezSessionPoolSession> expirationQueue;
-  /** The background restart queue that is populated when expiration is triggered by a foreground
-   * thread (i.e. getting or returning a session), to avoid delaying it. */
-  private BlockingQueue<TezSessionPoolSession> restartQueue;
-  private Thread expirationThread;
-  private Thread restartThread;
+  static final Random rdm = new Random();
 
   private Semaphore llapQueue;
   private HiveConf initConf = null;
   // Config settings.
   private int numConcurrentLlapQueries = -1;
-  private long sessionLifetimeMs = 0;
-  private long sessionLifetimeJitterMs = 0;
   private CustomQueueAllowed customQueueAllowed = CustomQueueAllowed.TRUE;
-  private List<ConfVars> restrictedHiveConf = new ArrayList<>();
-  private List<String> restrictedNonHiveConf = new ArrayList<>();
 
-  /** A queue for initial sessions that have not been started yet. */
-  private Queue<TezSessionPoolSession> initialSessions =
-      new ConcurrentLinkedQueue<TezSessionPoolSession>();
+  private TezSessionPool defaultSessionPool;
+  private SessionExpirationTracker expirationTracker;
+  private RestrictedConfigChecker restrictedConfig;
+
   /**
    * Indicates whether we should try to use defaultSessionPool.
    * We assume that setupPool is either called before any activity, or not called at all.
    */
   private volatile boolean hasInitialSessions = false;
 
-  private static TezSessionPoolManager sessionPool = null;
+  private static TezSessionPoolManager instance = null;
 
-  private static final List<TezSessionPoolSession> openSessions
-    = new LinkedList<TezSessionPoolSession>();
+  /** This is used to close non-default sessions, and also all sessions when stopping. */
+  private final List<TezSessionPoolSession> openSessions = new LinkedList<>();
 
-  public static TezSessionPoolManager getInstance()
-      throws Exception {
-    if (sessionPool == null) {
-      sessionPool = new TezSessionPoolManager();
+  /** Note: this is not thread-safe. */
+  public static TezSessionPoolManager getInstance() throws Exception {
+    TezSessionPoolManager local = instance;
+    if (local == null) {
+      instance = local = new TezSessionPoolManager();
     }
 
-    return sessionPool;
+    return local;
   }
 
   protected TezSessionPoolManager() {
   }
 
-  private void startInitialSession(TezSessionPoolSession sessionState) throws Exception {
-    HiveConf newConf = new HiveConf(initConf); // TODO Why is this configuration management not happening inside TezSessionPool.
-    // Makes no senses for it to be mixed up like this.
-    boolean isUsable = sessionState.tryUse();
-    if (!isUsable) throw new IOException(sessionState + " is not usable at pool startup");
-    newConf.set(TezConfiguration.TEZ_QUEUE_NAME, sessionState.getQueueName());
-    sessionState.open(newConf);
-    if (sessionState.returnAfterUse()) {
-      defaultQueuePool.put(sessionState);
-    }
-  }
-
   public void startPool() throws Exception {
-    if (initialSessions.isEmpty()) return;
-    // Hive SessionState available at this point.
-    initSessionState = SessionState.get();
-    int threadCount = Math.min(initialSessions.size(),
-        HiveConf.getIntVar(initConf, ConfVars.HIVE_SERVER2_TEZ_SESSION_MAX_INIT_THREADS));
-    Preconditions.checkArgument(threadCount > 0);
-    if (threadCount == 1) {
-      while (true) {
-        TezSessionPoolSession session = initialSessions.poll();
-        if (session == null) break;
-        startInitialSession(session);
-      }
-    } else {
-      // TODO What is this doing now ?
-      final SessionState parentSessionState = SessionState.get();
-      // The runnable has no mutable state, so each thread can run the same thing.
-      final AtomicReference<Exception> firstError = new AtomicReference<>(null);
-      Runnable runnable = new Runnable() {
-        public void run() {
-          if (parentSessionState != null) {
-            SessionState.setCurrentSessionState(parentSessionState);
-          }
-          while (true) {
-            TezSessionPoolSession session = initialSessions.poll();
-            if (session == null) break;
-            try {
-              startInitialSession(session);
-            } catch (Exception e) {
-              if (!firstError.compareAndSet(null, e)) {
-                LOG.error("Failed to start session; ignoring due to previous error", e);
-                // TODO Why even continue after this. We're already in a state where things are messed up ?
-              }
-            }
-          }
-        }
-      };
-      Thread[] threads = new Thread[threadCount - 1];
-      for (int i = 0; i < threads.length; ++i) {
-        threads[i] = new Thread(runnable, "Tez session init " + i);
-        threads[i].start();
-      }
-      runnable.run();
-      for (int i = 0; i < threads.length; ++i) {
-        threads[i].join();
-      }
-      Exception ex = firstError.get();
-      if (ex != null) {
-        throw ex;
-      }
+    if (defaultSessionPool != null) {
+      defaultSessionPool.startInitialSessions();
     }
-    if (expirationThread != null) {
-      expirationThread.start();
-      restartThread.start();
+    if (expirationTracker != null) {
+      expirationTracker.start();
     }
   }
 
   public void setupPool(HiveConf conf) throws InterruptedException {
     String[] defaultQueueList = HiveConf.getTrimmedStringsVar(
         conf, HiveConf.ConfVars.HIVE_SERVER2_TEZ_DEFAULT_QUEUES);
+    this.initConf = conf;
     int emptyNames = 0; // We don't create sessions for empty entries.
     for (String queueName : defaultQueueList) {
       if (queueName.isEmpty()) {
@@ -207,13 +114,12 @@ public class TezSessionPoolManager {
     int numSessions = conf.getIntVar(ConfVars.HIVE_SERVER2_TEZ_SESSIONS_PER_DEFAULT_QUEUE);
     int numSessionsTotal = numSessions * (defaultQueueList.length - emptyNames);
     if (numSessionsTotal > 0) {
-      defaultQueuePool = new LinkedBlockingDeque<TezSessionPoolSession>(numSessionsTotal);
+      defaultSessionPool = new TezSessionPool(initConf, numSessionsTotal);
     }
 
     numConcurrentLlapQueries = conf.getIntVar(ConfVars.HIVE_SERVER2_LLAP_CONCURRENT_QUERIES);
     llapQueue = new Semaphore(numConcurrentLlapQueries, true);
 
-    this.initConf = conf;
     String queueAllowedStr = HiveConf.getVar(initConf,
         ConfVars.HIVE_SERVER2_TEZ_SESSION_CUSTOM_QUEUE_ALLOWED);
     try {
@@ -222,73 +128,18 @@ public class TezSessionPoolManager {
       throw new RuntimeException("Invalid value '" + queueAllowedStr + "' for " +
           ConfVars.HIVE_SERVER2_TEZ_SESSION_CUSTOM_QUEUE_ALLOWED.varname);
     }
-    String[] restrictedConfigs = HiveConf.getTrimmedStringsVar(initConf,
-        ConfVars.HIVE_SERVER2_TEZ_SESSION_RESTRICTED_CONFIGS);
-    if (restrictedConfigs != null && restrictedConfigs.length > 0) {
-      HashMap<String, ConfVars> confVars = HiveConf.getOrCreateReverseMap();
-      for (String confName : restrictedConfigs) {
-        if (confName == null || confName.isEmpty()) continue;
-        confName = confName.toLowerCase();
-        ConfVars cv = confVars.get(confName);
-        if (cv != null) {
-          restrictedHiveConf.add(cv);
-        } else {
-          LOG.warn("A restricted config " + confName + " is not recognized as a Hive setting.");
-          restrictedNonHiveConf.add(confName);
-        }
-      }
-    }
 
-    sessionLifetimeMs = conf.getTimeVar(
-        ConfVars.HIVE_SERVER2_TEZ_SESSION_LIFETIME, TimeUnit.MILLISECONDS);
-    if (sessionLifetimeMs != 0) {
-      sessionLifetimeJitterMs = conf.getTimeVar(
-          ConfVars.HIVE_SERVER2_TEZ_SESSION_LIFETIME_JITTER, TimeUnit.MILLISECONDS);
-      if (LOG.isDebugEnabled()) {
-        LOG.debug("Session expiration is enabled; session lifetime is "
-            + sessionLifetimeMs + " + [0, " + sessionLifetimeJitterMs + ") ms");
-      }
-      expirationQueue = new PriorityBlockingQueue<>(11, new Comparator<TezSessionPoolSession>() {
-        @Override
-        public int compare(TezSessionPoolSession o1, TezSessionPoolSession o2) {
-          assert o1.expirationNs != null && o2.expirationNs != null;
-          return o1.expirationNs.compareTo(o2.expirationNs);
-        }
-      });
-      restartQueue = new LinkedBlockingQueue<>();
-    }
-    this.hasInitialSessions = numSessionsTotal > 0;
-    // From this point on, session creation will wait for the default pool (if # of sessions > 0).
+    restrictedConfig = new RestrictedConfigChecker(conf);
+    // Only creates the expiration tracker if expiration is configured.
+    expirationTracker = SessionExpirationTracker.create(conf, this);
 
-    if (sessionLifetimeMs != 0) {
-      expirationThread = new Thread(new Runnable() {
-        @Override
-        public void run() {
-          try {
-            SessionState.setCurrentSessionState(initSessionState);
-            runExpirationThread();
-          } catch (Exception e) {
-            LOG.warn("Exception in TezSessionPool-expiration thread. Thread will shut down", e);
-          } finally {
-            LOG.info("TezSessionPool-expiration thread exiting");
-          }
-        }
-      }, "TezSessionPool-expiration");
-      restartThread = new Thread(new Runnable() {
-        @Override
-        public void run() {
-          try {
-            SessionState.setCurrentSessionState(initSessionState);
-            runRestartThread();
-          } catch (Exception e) {
-            LOG.warn("Exception in TezSessionPool-cleanup thread. Thread will shut down", e);
-          } finally {
-            LOG.info("TezSessionPool-cleanup thread exiting");
-          }
-        }
-      }, "TezSessionPool-cleanup");
+    // From this point on, session creation will wait for the default pool (if # of sessions > 0).
+    this.hasInitialSessions = numSessionsTotal > 0;
+    if (!hasInitialSessions) {
+      return;
     }
 
+
     /*
      * In a single-threaded init case, with this the ordering of sessions in the queue will be
      * (with 2 sessions 3 queues) s1q1, s1q2, s1q3, s2q1, s2q2, s2q3 there by ensuring uniform
@@ -301,7 +152,7 @@ public class TezSessionPoolManager {
         if (queueName.isEmpty()) {
           continue;
         }
-        initialSessions.add(createAndInitSession(queueName, true));
+        defaultSessionPool.addInitialSession(createAndInitSession(queueName, true));
       }
     }
   }
@@ -322,8 +173,10 @@ public class TezSessionPoolManager {
     return sessionState;
   }
 
-  private TezSessionState getSession(HiveConf conf, boolean doOpen)
-      throws Exception {
+  private TezSessionState getSession(HiveConf conf, boolean doOpen) throws Exception {
+    // NOTE: this can be called outside of HS2, without calling setupPool. Basically it should be
+    //       able to handle not being initialized. Perhaps we should get rid of the instance and
+    //       move the setupPool code to ctor. For now, at least hasInitialSessions will be false.
     String queueName = conf.get(TezConfiguration.TEZ_QUEUE_NAME);
     boolean hasQueue = (queueName != null) && !queueName.isEmpty();
     if (hasQueue) {
@@ -340,16 +193,8 @@ public class TezSessionPoolManager {
     }
 
     // Check the restricted configs that the users cannot set.
-    for (ConfVars var : restrictedHiveConf) {
-      String userValue = HiveConf.getVarWithoutType(conf, var),
-          serverValue = HiveConf.getVarWithoutType(initConf, var);
-      // Note: with some trickery, we could add logic for each type in ConfVars; for now the
-      // potential spurious mismatches (e.g. 0 and 0.0 for float) should be easy to work around.
-      validateRestrictedConfigValues(var.varname, userValue, serverValue);
-    }
-    for (String var : restrictedNonHiveConf) {
-      String userValue = conf.get(var), serverValue = initConf.get(var);
-      validateRestrictedConfigValues(var, userValue, serverValue);
+    if (restrictedConfig != null) {
+      restrictedConfig.validate(conf);
     }
 
     // Propagate this value from HS2; don't allow users to set it.
@@ -371,26 +216,12 @@ public class TezSessionPoolManager {
      */
     if (nonDefaultUser || !hasInitialSessions || hasQueue) {
       LOG.info("QueueName: {} nonDefaultUser: {} defaultQueuePool: {} hasInitialSessions: {}",
-              queueName, nonDefaultUser, defaultQueuePool, hasInitialSessions);
+              queueName, nonDefaultUser, defaultSessionPool, hasInitialSessions);
       return getNewSessionState(conf, queueName, doOpen);
     }
 
     LOG.info("Choosing a session from the defaultQueuePool");
-    while (true) {
-      TezSessionPoolSession result = defaultQueuePool.take();
-      if (result.tryUse()) return result;
-      LOG.info("Couldn't use a session [" + result + "]; attempting another one");
-    }
-  }
-
-  private void validateRestrictedConfigValues(
-      String var, String userValue, String serverValue) throws HiveException {
-    if ((userValue == null) != (serverValue == null)
-        || (userValue != null && !userValue.equals(serverValue))) {
-      String logValue = initConf.isHiddenConfig(var) ? "(hidden)" : serverValue;
-      throw new HiveException(var + " is restricted from being set; server is configured"
-          + " to use " + logValue + ", but the query configuration specifies " + userValue);
-    }
+    return defaultSessionPool.getSession();
   }
 
   /**
@@ -431,15 +262,7 @@ public class TezSessionPoolManager {
           tezSessionState instanceof TezSessionPoolSession) {
         LOG.info("The session " + tezSessionState.getSessionId()
             + " belongs to the pool. Put it back in");
-        SessionState sessionState = SessionState.get();
-        if (sessionState != null) {
-          sessionState.setTezSession(null);
-        }
-        TezSessionPoolSession poolSession =
-            (TezSessionPoolSession) tezSessionState;
-        if (poolSession.returnAfterUse()) {
-          defaultQueuePool.putFirst(poolSession);
-        }
+        defaultSessionPool.returnSession((TezSessionPoolSession)tezSessionState);
       }
       // non default session nothing changes. The user can continue to use the existing
       // session in the SessionState
@@ -460,7 +283,7 @@ public class TezSessionPoolManager {
   }
 
   public void stop() throws Exception {
-    if ((sessionPool == null) || !this.hasInitialSessions) {
+    if ((instance == null) || !this.hasInitialSessions) {
       return;
     }
 
@@ -474,12 +297,8 @@ public class TezSessionPoolManager {
         sessionState.close(false);
       }
     }
-
-    if (expirationThread != null) {
-      expirationThread.interrupt();
-    }
-    if (restartThread != null) {
-      restartThread.interrupt();
+    if (expirationTracker != null) {
+      expirationTracker.stop();
     }
   }
 
@@ -498,7 +317,7 @@ public class TezSessionPoolManager {
   }
 
   protected TezSessionPoolSession createSession(String sessionId) {
-    return new TezSessionPoolSession(sessionId, this);
+    return new TezSessionPoolSession(sessionId, this, expirationTracker); // tracker may be null
   }
 
   /*
@@ -597,242 +416,37 @@ public class TezSessionPoolManager {
   }
 
   /** Closes a running (expired) pool session and reopens it. */
-  private void closeAndReopenPoolSession(TezSessionPoolSession oldSession) throws Exception {
+  @Override
+  public void closeAndReopenPoolSession(TezSessionPoolSession oldSession) throws Exception {
     String queueName = oldSession.getQueueName();
     if (queueName == null) {
       LOG.warn("Pool session has a null queue: " + oldSession);
     }
-    HiveConf conf = oldSession.getConf();
-    Path scratchDir = oldSession.getTezScratchDir();
-    boolean isDefault = oldSession.isDefault();
-    Set<String> additionalFiles = oldSession.getAdditionalFilesNotFromConf();
-    try {
-      oldSession.close(false);
-      defaultQueuePool.remove(oldSession);  // Make sure it's removed.
-    } finally {
-      TezSessionPoolSession newSession = createAndInitSession(queueName, isDefault);
-      // There's some bogus code that can modify the queue name. Force-set it for pool sessions.
-      conf.set(TezConfiguration.TEZ_QUEUE_NAME, queueName);
-      newSession.open(conf, additionalFiles, scratchDir);
-      defaultQueuePool.put(newSession);
-    }
-  }
-
-  /** Logic for the thread that restarts the sessions expired during foreground operations. */
-  private void runRestartThread() {
-    try {
-      while (true) {
-        TezSessionPoolSession next = restartQueue.take();
-        LOG.info("Restarting the expired session [" + next + "]");
-        try {
-          closeAndReopenPoolSession(next);
-        } catch (InterruptedException ie) {
-          throw ie;
-        } catch (Exception e) {
-          LOG.error("Failed to close or restart a session, ignoring", e);
-        }
-      }
-    } catch (InterruptedException e) {
-      LOG.info("Restart thread is exiting due to an interruption");
-    }
+    TezSessionPoolSession newSession = createAndInitSession(queueName, oldSession.isDefault());
+    defaultSessionPool.replaceSession(oldSession, newSession); // NOTE(review): null w/o default pool
   }
 
-  /** Logic for the thread that tracks session expiration and restarts sessions in background. */
-  private void runExpirationThread() {
-    try {
-      while (true) {
-        TezSessionPoolSession nextToExpire = null;
-        while (true) {
-          // Restart the sessions until one of them refuses to restart.
-          nextToExpire = expirationQueue.take();
-          if (LOG.isDebugEnabled()) {
-            LOG.debug("Seeing if we can expire [" + nextToExpire + "]");
-          }
-          try {
-            if (!nextToExpire.tryExpire(false)) break;
-          } catch (Exception e) {
-            // Reopen happens even when close fails, so there's not much to do here.
-            LOG.error("Failed to expire session " + nextToExpire + "; ignoring", e);
-            nextToExpire = null;
-            break; // Not strictly necessary; do the whole queue check again.
-          }
-          LOG.info("Tez session [" + nextToExpire + "] has expired");
-        }
-        if (nextToExpire != null && LOG.isDebugEnabled()) {
-          LOG.debug("[" + nextToExpire + "] is not ready to expire; adding it back");
-        }
-
-        // See addToExpirationQueue for why we re-check the queue.
-        synchronized (expirationQueue) {
-          // Add back the non-expired session. No need to notify, we are the only ones waiting.
-          if (nextToExpire != null) {
-            expirationQueue.add(nextToExpire);
-          }
-          nextToExpire = expirationQueue.peek();
-          if (nextToExpire != null) {
-            // Add some margin to the wait to avoid rechecking close to the boundary.
-            long timeToWaitMs = 10 + (nextToExpire.expirationNs - System.nanoTime()) / 1000000L;
-            timeToWaitMs = Math.max(1, timeToWaitMs);
-            if (LOG.isDebugEnabled()) {
-              LOG.debug("Waiting for ~" + timeToWaitMs + "ms to expire [" + nextToExpire + "]");
-            }
-            expirationQueue.wait(timeToWaitMs);
-          } else if (LOG.isDebugEnabled()) {
-            // Don't wait if empty - go to take() above, that will wait for us.
-            LOG.debug("Expiration queue is empty");
-          }
-        }
-      }
-    } catch (InterruptedException e) {
-      LOG.info("Expiration thread is exiting due to an interruption");
+  /** Called by TezSessionPoolSession once it has been opened; records it in openSessions. */
+  @Override
+  public void registerOpenSession(TezSessionPoolSession session) {
+    synchronized (openSessions) {
+      openSessions.add(session);
     }
   }
 
-  /**
-   * TezSession that keeps track of expiration and use.
-   * It has 3 states - not in use, in use, and expired. When in the pool, it is not in use;
-   * use and expiration may compete to take the session out of the pool and change it to the
-   * corresponding states. When someone tries to get a session, they check for expiration time;
-   * if it's time, the expiration is triggered; in that case, or if it was already triggered, the
-   * caller gets a different session. When the session is in use when it expires, the expiration
-   * thread ignores it and lets the return to the pool take care of the expiration.
-   */
-  @VisibleForTesting
-  static class TezSessionPoolSession extends TezSessionState {
-    private static final int STATE_NONE = 0, STATE_IN_USE = 1, STATE_EXPIRED = 2;
-
-    private final AtomicInteger sessionState = new AtomicInteger(STATE_NONE);
-    private Long expirationNs;
-    private final TezSessionPoolManager parent; // Static class allows us to be used in tests.
-
-    public TezSessionPoolSession(String sessionId, TezSessionPoolManager parent) {
-      super(sessionId);
-      this.parent = parent;
-    }
-
-    @Override
-    public void close(boolean keepTmpDir) throws Exception {
-      try {
-        super.close(keepTmpDir);
-      } finally {
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Closed a pool session [" + this + "]");
-        }
-        synchronized (openSessions) {
-          openSessions.remove(this);
-        }
-        if (parent.expirationQueue != null) {
-          parent.expirationQueue.remove(this);
-        }
-      }
-    }
-
-    @Override
-    protected void openInternal(HiveConf conf, Collection<String> additionalFiles,
-        boolean isAsync, LogHelper console, Path scratchDir)
-            throws IOException, LoginException, URISyntaxException, TezException {
-      super.openInternal(conf, additionalFiles, isAsync, console, scratchDir);
-      synchronized (openSessions) {
-        openSessions.add(this);
-      }
-      if (parent.expirationQueue != null) {
-        long jitterModMs = (long)(parent.sessionLifetimeJitterMs * rdm.nextFloat());
-        expirationNs = System.nanoTime() + (parent.sessionLifetimeMs + jitterModMs) * 1000000L;
-        if (LOG.isDebugEnabled()) {
-          LOG.debug("Adding a pool session [" + this + "] to expiration queue");
-        }
-        parent.addToExpirationQueue(this);
-      }
+  /** Called by TezSessionPoolSession when closed, including when the close itself failed. */
+  @Override
+  public void unregisterOpenSession(TezSessionPoolSession session) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("Closed a pool session [" + session + "]");
     }
-
-    @Override
-    public String toString() {
-      if (expirationNs == null) return super.toString();
-      long expiresInMs = (expirationNs - System.nanoTime()) / 1000000L;
-      return super.toString() + ", expires in " + expiresInMs + "ms";
-    }
-
-    /**
-     * Tries to use this session. When the session is in use, it will not expire.
-     * @return true if the session can be used; false if it has already expired.
-     */
-    public boolean tryUse() throws Exception {
-      while (true) {
-        int oldValue = sessionState.get();
-        if (oldValue == STATE_IN_USE) throw new AssertionError(this + " is already in use");
-        if (oldValue == STATE_EXPIRED) return false;
-        int finalState = shouldExpire() ? STATE_EXPIRED : STATE_IN_USE;
-        if (sessionState.compareAndSet(STATE_NONE, finalState)) {
-          if (finalState == STATE_IN_USE) return true;
-          closeAndRestartExpiredSession(true); // Restart asynchronously, don't block the caller.
-          return false;
-        }
-      }
-    }
-
-    /**
-     * Notifies the session that it's no longer in use. If the session has expired while in use,
-     * this method will take care of the expiration.
-     * @return True if the session was returned, false if it was restarted.
-     */
-    public boolean returnAfterUse() throws Exception {
-      int finalState = shouldExpire() ? STATE_EXPIRED : STATE_NONE;
-      if (!sessionState.compareAndSet(STATE_IN_USE, finalState)) {
-        throw new AssertionError("Unexpected state change; currently " + sessionState.get());
-      }
-      if (finalState == STATE_NONE) return true;
-      closeAndRestartExpiredSession(true);
-      return false;
-    }
-
-    /**
-     * Tries to expire and restart the session.
-     * @param isAsync Whether the restart should happen asynchronously.
-     * @return True if the session was, or will be restarted.
-     */
-    public boolean tryExpire(boolean isAsync) throws Exception {
-      if (expirationNs == null) return true;
-      if (!shouldExpire()) return false;
-      // Try to expire the session if it's not in use; if in use, bail.
-      while (true) {
-        if (sessionState.get() != STATE_NONE) return true; // returnAfterUse will take care of this
-        if (sessionState.compareAndSet(STATE_NONE, STATE_EXPIRED)) {
-          closeAndRestartExpiredSession(isAsync);
-          return true;
-        }
-      }
-    }
-
-    private void closeAndRestartExpiredSession(boolean async) throws Exception {
-      if (async) {
-        parent.restartQueue.add(this);
-      } else {
-        parent.closeAndReopenPoolSession(this);
-      }
-    }
-
-    private boolean shouldExpire() {
-      return expirationNs != null && (System.nanoTime() - expirationNs) >= 0;
+    synchronized (openSessions) {
+      openSessions.remove(session);
     }
   }
 
-  private void addToExpirationQueue(TezSessionPoolSession session) {
-    // Expiration queue is synchronized and notified upon when adding elements. Without jitter, we
-    // wouldn't need this, and could simple look at the first element and sleep for the wait time.
-    // However, when many things are added at once, it may happen that we will see the one that
-    // expires later first, and will sleep past the earlier expiration times. When we wake up we
-    // may kill many sessions at once. To avoid this, we will add to queue under lock and recheck
-    // time before we wait. We don't have to worry about removals; at worst we'd wake up in vain.
-    // Example: expirations of 1:03:00, 1:00:00, 1:02:00 are added (in this order due to jitter).
-    // If the expiration threads sees that 1:03 first, it will sleep for 1:03, then wake up and
-    // kill all 3 sessions at once because they all have expired, removing any effect from jitter.
-    // Instead, expiration thread rechecks the first queue item and waits on the queue. If nothing
-    // is added to the queue, the item examined is still the earliest to be expired. If someone
-    // adds to the queue while it is waiting, it will notify the thread and it would wake up and
-    // recheck the queue.
-    synchronized (expirationQueue) {
-      expirationQueue.add(session);
-      expirationQueue.notifyAll();
-    }
+  @VisibleForTesting
+  public SessionExpirationTracker getExpirationTracker() {
+    return expirationTracker; // Null unless setupPool ran with session expiration configured.
+  }
 }

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolSession.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolSession.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolSession.java
new file mode 100644
index 0000000..005eeed
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/tez/TezSessionPoolSession.java
@@ -0,0 +1,198 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.tez;
+
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.util.Collection;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.security.auth.login.LoginException;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.session.SessionState.LogHelper;
+import org.apache.tez.dag.api.TezException;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * TezSession that is aware of the session pool, and also keeps track of expiration and use.
+ * It has 3 states - not in use, in use, and expired. When in the pool, it is not in use;
+ * use and expiration may compete to take the session out of the pool and change it to the
+ * corresponding states. When someone tries to get a session, they check for expiration time;
+ * if it's time, the expiration is triggered; in that case, or if it was already triggered, the
+ * caller gets a different session. If the session expires while it is in use, the expiration
+ * thread ignores it and lets the return to the pool take care of the expiration.
+ *
+ * <p>All state changes use compare-and-set on the state field: NONE to IN_USE in
+ * {@link #tryUse()}; IN_USE to NONE or EXPIRED in {@link #returnAfterUse()}; NONE to EXPIRED
+ * in {@link #tryUse()} or {@link #tryExpire(boolean)}. Nothing moves a session out of
+ * EXPIRED; closing and restarting an expired session is delegated to the
+ * {@link SessionExpirationTracker}.
+ */
+@VisibleForTesting
+class TezSessionPoolSession extends TezSessionState {
+  // Values of sessionState; see the class comment for the allowed transitions.
+  private static final int STATE_NONE = 0, STATE_IN_USE = 1, STATE_EXPIRED = 2;
+
+  /** Callback that lets the owner keep track of which pool sessions are currently open. */
+  interface OpenSessionTracker {
+    /** Called by {@link TezSessionPoolSession#openInternal} once the session is open. */
+    void registerOpenSession(TezSessionPoolSession session);
+    /** Called by {@link TezSessionPoolSession#close}, even if closing the session failed. */
+    void unregisterOpenSession(TezSessionPoolSession session);
+  }
+
+  private final AtomicInteger sessionState = new AtomicInteger(STATE_NONE);
+  // Expiration deadline in System.nanoTime() units; null means this session never expires.
+  private Long expirationNs;
+  private final OpenSessionTracker parent;
+  // May be null; close() and openInternal() then skip the expiration bookkeeping.
+  private final SessionExpirationTracker expirationTracker;
+
+  /**
+   * @param sessionId session id, passed to {@link TezSessionState}.
+   * @param parent notified whenever this session is opened or closed.
+   * @param expirationTracker handles expiration and restart of this session; may be null.
+   */
+  public TezSessionPoolSession(String sessionId, OpenSessionTracker parent,
+      SessionExpirationTracker expirationTracker) {
+    super(sessionId);
+    this.parent = parent;
+    this.expirationTracker = expirationTracker;
+  }
+
+  /** Sets the expiration deadline, in {@link System#nanoTime()} units. */
+  void setExpirationNs(long expirationNs) {
+    this.expirationNs = expirationNs;
+  }
+
+  /** @return the expiration deadline in {@link System#nanoTime()} units, or null if unset. */
+  Long getExpirationNs() {
+    return expirationNs;
+  }
+
+  /**
+   * Closes the session. Whether or not that succeeds, the session is then unregistered from
+   * the parent and, if there is an expiration tracker, removed from the expiration queue.
+   */
+  @Override
+  public void close(boolean keepTmpDir) throws Exception {
+    try {
+      super.close(keepTmpDir);
+    } finally {
+      parent.unregisterOpenSession(this);
+      if (expirationTracker != null) {
+        expirationTracker.removeFromExpirationQueue(this);
+      }
+    }
+  }
+
+  /**
+   * Opens the session, then registers it with the parent and, if there is an expiration
+   * tracker, adds it to the expiration queue. Nothing is registered if opening throws.
+   */
+  @Override
+  protected void openInternal(HiveConf conf, Collection<String> additionalFiles,
+      boolean isAsync, LogHelper console, Path scratchDir)
+          throws IOException, LoginException, URISyntaxException, TezException {
+    super.openInternal(conf, additionalFiles, isAsync, console, scratchDir);
+    parent.registerOpenSession(this);
+    if (expirationTracker != null) {
+      expirationTracker.addToExpirationQueue(this);
+    }
+  }
+
+  /** Appends the remaining lifetime to the base description when a deadline is set. */
+  @Override
+  public String toString() {
+    if (expirationNs == null) return super.toString();
+    long expiresInMs = (expirationNs - System.nanoTime()) / 1000000L;
+    return super.toString() + ", expires in " + expiresInMs + "ms";
+  }
+
+  /**
+   * Tries to use this session. When the session is in use, it will not expire.
+   * If the deadline has already passed, the session is marked as expired instead and its
+   * restart is requested asynchronously, so the caller is not blocked.
+   * @return true if the session can be used; false if it has already expired.
+   * @throws AssertionError if the session is already in use.
+   */
+  public boolean tryUse() throws Exception {
+    while (true) {
+      int oldValue = sessionState.get();
+      if (oldValue == STATE_IN_USE) throw new AssertionError(this + " is already in use");
+      if (oldValue == STATE_EXPIRED) return false;
+      int finalState = shouldExpire() ? STATE_EXPIRED : STATE_IN_USE;
+      // A failed CAS means the state changed concurrently (e.g. tryExpire); re-read it.
+      if (sessionState.compareAndSet(STATE_NONE, finalState)) {
+        if (finalState == STATE_IN_USE) return true;
+        // Restart asynchronously, don't block the caller.
+        // NOTE(review): assumes a deadline is only ever set when expirationTracker != null.
+        expirationTracker.closeAndRestartExpiredSession(this, true);
+        return false;
+      }
+    }
+  }
+
+  /**
+   * Notifies the session that it's no longer in use. If the session has expired while in use,
+   * this method will take care of the expiration.
+   * @return True if the session was returned; false if it had expired and an asynchronous
+   *         restart was requested instead.
+   * @throws AssertionError if the session was not in use.
+   */
+  public boolean returnAfterUse() throws Exception {
+    int finalState = shouldExpire() ? STATE_EXPIRED : STATE_NONE;
+    if (!sessionState.compareAndSet(STATE_IN_USE, finalState)) {
+      throw new AssertionError("Unexpected state change; currently " + sessionState.get());
+    }
+    if (finalState == STATE_NONE) return true;
+    expirationTracker.closeAndRestartExpiredSession(this, true);
+    return false;
+  }
+
+  /**
+   * Tries to expire and restart the session.
+   * @param isAsync Whether the restart should happen asynchronously.
+   * @return False if the deadline has not been reached yet. True otherwise: the session was,
+   *         or will be, restarted (if it is in use, {@link #returnAfterUse()} handles that);
+   *         also true, without any action, when no deadline is set.
+   */
+  public boolean tryExpire(boolean isAsync) throws Exception {
+    if (expirationNs == null) return true;
+    if (!shouldExpire()) return false;
+    // Try to expire the session if it's not in use; if in use, bail.
+    while (true) {
+      if (sessionState.get() != STATE_NONE) return true; // returnAfterUse will take care of this
+      if (sessionState.compareAndSet(STATE_NONE, STATE_EXPIRED)) {
+        expirationTracker.closeAndRestartExpiredSession(this, isAsync);
+        return true;
+      }
+    }
+  }
+
+  /** @return true once the expiration deadline has been reached; false if none is set. */
+  private boolean shouldExpire() {
+    // Subtract rather than compare directly: System.nanoTime() values may overflow, so only
+    // the difference of two readings is meaningful.
+    return expirationNs != null && (System.nanoTime() - expirationNs) >= 0;
+  }
+}

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/SampleTezSessionState.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/SampleTezSessionState.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/SampleTezSessionState.java
index 2d1c687..973c0cc 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/SampleTezSessionState.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/SampleTezSessionState.java
@@ -18,7 +18,6 @@
 
 package org.apache.hadoop.hive.ql.exec.tez;
 
-import org.apache.hadoop.hive.ql.exec.tez.TezSessionPoolManager.TezSessionPoolSession;
 
 import java.io.IOException;
 import java.net.URISyntaxException;
@@ -45,7 +44,7 @@ public class SampleTezSessionState extends TezSessionPoolSession {
   private boolean doAsEnabled;
 
   public SampleTezSessionState(String sessionId, TezSessionPoolManager parent) {
-    super(sessionId, parent);
+    super(sessionId, parent, parent.getExpirationTracker()); // tracker may be null
     this.sessionId = sessionId;
   }
 

http://git-wip-us.apache.org/repos/asf/hive/blob/733bc5f0/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestTezSessionPool.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestTezSessionPool.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestTezSessionPool.java
index 88c8122..d2b98c4 100644
--- a/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestTezSessionPool.java
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/tez/TestTezSessionPool.java
@@ -49,7 +49,7 @@ public class TestTezSessionPool {
     }
 
     @Override
-    public TezSessionPoolManager.TezSessionPoolSession createSession(String sessionId) {
+    public TezSessionPoolSession createSession(String sessionId) { // now a top-level class
       return new SampleTezSessionState(sessionId, this);
     }
   }