Posted to commits@hive.apache.org by xu...@apache.org on 2014/12/10 04:14:58 UTC

svn commit: r1644322 [2/2] - in /hive/branches/spark: ./ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/ ql/src/java/org/apache/hadoop/hive/ql/exec/spark/status/impl/ spark-client/ spark-client/src/main/java/org/apache/hive/spark/client/ spark-client...

Added: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcDispatcher.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcDispatcher.java?rev=1644322&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcDispatcher.java (added)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcDispatcher.java Wed Dec 10 03:14:57 2014
@@ -0,0 +1,203 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Maps;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.SimpleChannelInboundHandler;
+import io.netty.util.concurrent.Promise;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+
+/**
+ * An implementation of ChannelInboundHandler that dispatches incoming messages to instance
+ * methods selected by the runtime type of each message.
+ * <p/>
+ * A handler's signature must be of the form:
+ * <p/>
+ * <blockquote><tt>protected ReplyType handle(ChannelHandlerContext, MessageType)</tt></blockquote>
+ * <p/>
+ * Where "MessageType" must exactly match the type of the message to handle; polymorphism is not
+ * supported. "ReplyType" may be any type (including void); the handler's return value becomes
+ * the RPC reply, and if a null is returned, a reply is still sent, with an empty payload.
+ */
+@InterfaceAudience.Private
+public abstract class RpcDispatcher extends SimpleChannelInboundHandler<Object> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(RpcDispatcher.class);
+
+  private final Map<Class<?>, Method> handlers = Maps.newConcurrentMap();
+  private final Collection<OutstandingRpc> rpcs = new ConcurrentLinkedQueue<OutstandingRpc>();
+
+  private volatile Rpc.MessageHeader lastHeader;
+
+  /** Override this to add a name to the dispatcher, for debugging purposes. */
+  protected String name() {
+    return getClass().getSimpleName();
+  }
+
+  @Override
+  protected final void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
+    if (lastHeader == null) {
+      if (!(msg instanceof Rpc.MessageHeader)) {
+        LOG.warn("[{}] Expected RPC header, got {} instead.", name(),
+            msg != null ? msg.getClass().getName() : null);
+        throw new IllegalArgumentException("Expected RPC header.");
+      }
+      lastHeader = (Rpc.MessageHeader) msg;
+    } else {
+      LOG.debug("[{}] Received RPC message: type={} id={} payload={}", name(),
+        lastHeader.type, lastHeader.id, msg != null ? msg.getClass().getName() : null);
+      try {
+        switch (lastHeader.type) {
+        case CALL:
+          handleCall(ctx, msg);
+          break;
+        case REPLY:
+          handleReply(ctx, msg, findRpc(lastHeader.id));
+          break;
+        case ERROR:
+          handleError(ctx, msg, findRpc(lastHeader.id));
+          break;
+        default:
+          throw new IllegalArgumentException("Unknown RPC message type: " + lastHeader.type);
+        }
+      } finally {
+        lastHeader = null;
+      }
+    }
+  }
+
+  private OutstandingRpc findRpc(long id) {
+    for (Iterator<OutstandingRpc> it = rpcs.iterator(); it.hasNext(); ) {
+      OutstandingRpc rpc = it.next();
+      if (rpc.id == id) {
+        it.remove();
+        return rpc;
+      }
+    }
+    throw new IllegalArgumentException(String.format(
+        "Received RPC reply for unknown RPC (%d).", id));
+  }
+
+  private void handleCall(ChannelHandlerContext ctx, Object msg) throws Exception {
+    Method handler = handlers.get(msg.getClass());
+    if (handler == null) {
+      handler = getClass().getDeclaredMethod("handle", ChannelHandlerContext.class,
+          msg.getClass());
+      handler.setAccessible(true);
+      handlers.put(msg.getClass(), handler);
+    }
+
+    Rpc.MessageType replyType;
+    Object replyPayload;
+    try {
+      replyPayload = handler.invoke(this, ctx, msg);
+      if (replyPayload == null) {
+        replyPayload = new Rpc.NullMessage();
+      }
+      replyType = Rpc.MessageType.REPLY;
+    } catch (InvocationTargetException ite) {
+      LOG.debug(String.format("[%s] Error in RPC handler.", name()), ite.getCause());
+      replyPayload = Throwables.getStackTraceAsString(ite.getCause());
+      replyType = Rpc.MessageType.ERROR;
+    }
+    ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, replyType));
+    ctx.channel().writeAndFlush(replyPayload);
+  }
+
+  private void handleReply(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
+      throws Exception {
+    rpc.future.setSuccess(msg instanceof Rpc.NullMessage ? null : msg);
+  }
+
+  private void handleError(ChannelHandlerContext ctx, Object msg, OutstandingRpc rpc)
+      throws Exception {
+    if (msg instanceof String) {
+      rpc.future.setFailure(new RpcException((String) msg));
+    } else {
+      String error = String.format("Received error with unexpected payload (%s).",
+          msg != null ? msg.getClass().getName() : null);
+      LOG.warn(String.format("[%s] %s", name(), error));
+      rpc.future.setFailure(new IllegalArgumentException(error));
+      ctx.close();
+    }
+  }
+
+  @Override
+  public final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug(String.format("[%s] Caught exception in channel pipeline.", name()), cause);
+    } else {
+      LOG.info("[{}] Closing channel due to exception in pipeline ({}).", name(),
+          cause.getMessage());
+    }
+
+    if (lastHeader != null) {
+      // There's an RPC waiting for a reply. Exception was most probably caught while processing
+      // the RPC, so send an error.
+      ctx.channel().write(new Rpc.MessageHeader(lastHeader.id, Rpc.MessageType.ERROR));
+      ctx.channel().writeAndFlush(Throwables.getStackTraceAsString(cause));
+      lastHeader = null;
+    }
+
+    ctx.close();
+  }
+
+  @Override
+  public final void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    if (rpcs.size() > 0) {
+      LOG.warn("[{}] Closing RPC channel with {} outstanding RPCs.", name(), rpcs.size());
+      for (OutstandingRpc rpc : rpcs) {
+        rpc.future.cancel(true);
+      }
+    }
+    super.channelInactive(ctx);
+  }
+
+  void registerRpc(long id, Promise promise, String type) {
+    LOG.debug("[{}] Registered outstanding rpc {} ({}).", name(), id, type);
+    rpcs.add(new OutstandingRpc(id, promise));
+  }
+
+  void discardRpc(long id) {
+    LOG.debug("[{}] Discarding failed RPC {}.", name(), id);
+    findRpc(id);
+  }
+
+  private static class OutstandingRpc {
+    final long id;
+    final Promise future;
+
+    OutstandingRpc(long id, Promise future) {
+      this.id = id;
+      this.future = future;
+    }
+  }
+
+}

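For illustration, a minimal dispatcher following the handler convention described in the
javadoc above might look like the sketch below. EchoMessage is a hypothetical message class;
the handle method is located reflectively by its second parameter's type, and its return
value is sent back as the RPC reply:

    import io.netty.channel.ChannelHandlerContext;

    class EchoDispatcher extends RpcDispatcher {

      // Invoked reflectively for incoming messages whose class is exactly EchoMessage.
      // Returning the message makes it the reply payload for the call.
      protected EchoMessage handle(ChannelHandlerContext ctx, EchoMessage msg) {
        return msg;
      }

    }
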
Added: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcException.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcException.java?rev=1644322&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcException.java (added)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcException.java Wed Dec 10 03:14:57 2014
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+
+@InterfaceAudience.Private
+public class RpcException extends RuntimeException {
+
+  RpcException(String remoteStackTrace) {
+    super(remoteStackTrace);
+  }
+
+}

Added: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java?rev=1644322&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java (added)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/client/rpc/RpcServer.java Wed Dec 10 03:14:57 2014
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.security.SecureRandom;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Optional;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.util.concurrent.Future;
+import io.netty.util.concurrent.GenericFutureListener;
+import io.netty.util.concurrent.Promise;
+import io.netty.util.concurrent.ScheduledFuture;
+
+import org.apache.hadoop.hive.common.classification.InterfaceAudience;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * An RPC server. The server matches remote clients based on a secret generated on the
+ * server side; the secret has to be delivered to the client through some other mechanism
+ * for the handshake to work.
+ */
+@InterfaceAudience.Private
+public class RpcServer implements Closeable {
+
+  private static final Logger LOG = LoggerFactory.getLogger(RpcServer.class);
+  private static final SecureRandom RND = new SecureRandom();
+
+  private final String address;
+  private final Channel channel;
+  private final EventLoopGroup group;
+  private final int port;
+  private final Collection<ClientInfo> pendingClients;
+  private final RpcConfiguration config;
+
+  public RpcServer(Map<String, String> config) throws IOException, InterruptedException {
+    this.config = new RpcConfiguration(config);
+    this.group = new NioEventLoopGroup(
+        this.config.getRpcThreadCount(),
+        new ThreadFactoryBuilder()
+            .setNameFormat("RPC-Handler-%d")
+            .setDaemon(true)
+            .build());
+    this.channel = new ServerBootstrap()
+      .group(group)
+      .channel(NioServerSocketChannel.class)
+      .childHandler(new ChannelInitializer<SocketChannel>() {
+          @Override
+          public void initChannel(SocketChannel ch) throws Exception {
+            HelloDispatcher dispatcher = new HelloDispatcher();
+            final Rpc newRpc = Rpc.createRpc(RpcServer.this.config, ch, dispatcher, group);
+            dispatcher.rpc = newRpc;
+
+            Runnable cancelTask = new Runnable() {
+                @Override
+                public void run() {
+                  LOG.warn("Timed out waiting for hello from client.");
+                  newRpc.close();
+                }
+            };
+            dispatcher.cancelTask = group.schedule(cancelTask,
+                RpcServer.this.config.getServerConnectTimeoutMs(),
+                TimeUnit.MILLISECONDS);
+          }
+      })
+      .option(ChannelOption.SO_BACKLOG, 1)
+      .option(ChannelOption.SO_REUSEADDR, true)
+      .childOption(ChannelOption.SO_KEEPALIVE, true)
+      .bind(0)
+      .sync()
+      .channel();
+    this.port = ((InetSocketAddress) channel.localAddress()).getPort();
+    this.pendingClients = new ConcurrentLinkedQueue<ClientInfo>();
+    this.address = this.config.getServerAddress();
+  }
+
+  /**
+   * Tells the RPC server to expect a connection from a new client.
+   *
+   * @param secret The secret the client will send to the server to identify itself.
+   * @param serverDispatcher The dispatcher to use when setting up the RPC instance.
+   * @return A future that can be used to wait for the client connection; on success it
+   *         yields the Rpc instance for talking to the new client.
+   */
+  public Future<Rpc> registerClient(String secret, RpcDispatcher serverDispatcher) {
+    final Promise<Rpc> promise = group.next().newPromise();
+
+    Runnable timeout = new Runnable() {
+      @Override
+      public void run() {
+        promise.setFailure(new TimeoutException("Timed out waiting for client connection."));
+      }
+    };
+    ScheduledFuture<?> timeoutFuture = group.schedule(timeout,
+        config.getServerConnectTimeoutMs(),
+        TimeUnit.MILLISECONDS);
+    final ClientInfo client = new ClientInfo(promise, secret, serverDispatcher, timeoutFuture);
+    pendingClients.add(client);
+
+    promise.addListener(new GenericFutureListener<Promise<Rpc>>() {
+      @Override
+      public void operationComplete(Promise<Rpc> p) {
+        if (p.isCancelled()) {
+          pendingClients.remove(client);
+        }
+      }
+    });
+
+    return promise;
+  }
+
+  /**
+   * Creates a secret for identifying a client connection.
+   */
+  public String createSecret() {
+    byte[] secret = new byte[config.getSecretBits() / 8];
+    RND.nextBytes(secret);
+
+    StringBuilder sb = new StringBuilder();
+    for (byte b : secret) {
+      // "%02x" pads to two hex digits and avoids the sign extension that
+      // Integer.toHexString() applies to negative bytes.
+      sb.append(String.format("%02x", b));
+    }
+    return sb.toString();
+  }
+
+  public String getAddress() {
+    return address;
+  }
+
+  public int getPort() {
+    return port;
+  }
+
+  @Override
+  public void close() {
+    try {
+      channel.close();
+      for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext(); ) {
+        ClientInfo client = clients.next();
+        clients.remove();
+        client.promise.cancel(true);
+      }
+    } finally {
+      group.shutdownGracefully();
+    }
+  }
+
+  private class HelloDispatcher extends RpcDispatcher {
+
+    private Rpc rpc;
+    private ScheduledFuture<?> cancelTask;
+
+    protected void handle(ChannelHandlerContext ctx, Rpc.Hello msg) {
+      cancelTask.cancel(true);
+
+      for (Iterator<ClientInfo> clients = pendingClients.iterator(); clients.hasNext(); ) {
+        ClientInfo client = clients.next();
+        if (client.secret.equals(msg.secret)) {
+          rpc.replaceDispatcher(client.dispatcher);
+          client.timeoutFuture.cancel(true);
+          client.promise.setSuccess(rpc);
+          return;
+        }
+      }
+
+      LOG.debug("Closing channel because secret '{}' does not match any pending client.",
+          msg.secret);
+      ctx.close();
+    }
+
+  }
+
+  private static class ClientInfo {
+
+    final Promise<Rpc> promise;
+    final String secret;
+    final RpcDispatcher dispatcher;
+    final ScheduledFuture<?> timeoutFuture;
+
+    private ClientInfo(Promise<Rpc> promise, String secret, RpcDispatcher dispatcher,
+        ScheduledFuture<?> timeoutFuture) {
+      this.promise = promise;
+      this.secret = secret;
+      this.dispatcher = dispatcher;
+      this.timeoutFuture = timeoutFuture;
+    }
+
+  }
+
+}

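To sketch how these pieces fit together: the server generates a secret and registers the
expected client, the secret reaches the remote process through some out-of-band channel,
and the client connects with it. The calls below match the API added in this commit (the
client-side Rpc.createClient() call appears in TestRpc later in this mail); MyDispatcher
stands in for a hypothetical application-defined RpcDispatcher subclass, and conf for the
configuration map:

    // Server side: expect a client identified by a fresh secret.
    RpcServer server = new RpcServer(conf);            // conf is a Map<String, String>
    String secret = server.createSecret();
    Future<Rpc> serverSide = server.registerClient(secret, new MyDispatcher());

    // The secret, server.getAddress() and server.getPort() must be handed to the
    // remote process out of band before it can connect.

    // Client side: connect using the same secret.
    Future<Rpc> clientSide = Rpc.createClient(conf, new NioEventLoopGroup(),
        server.getAddress(), server.getPort(), secret, new MyDispatcher());
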
Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounter.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounter.java?rev=1644322&r1=1644321&r2=1644322&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounter.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounter.java Wed Dec 10 03:14:57 2014
@@ -28,9 +28,23 @@ public class SparkCounter implements Ser
   private String name;
   private String displayName;
   private Accumulator<Long> accumulator;
-  // Values of accumulators can only be read on the SparkContext side
-  // In case of RSC, we have to keep the data here
-  private long accumValue = -1;
+
+  // Values of accumulators can only be read on the SparkContext side. This field is used when
+  // creating a snapshot to be sent to the RSC client.
+  private long accumValue;
+
+  public SparkCounter() {
+    // For serialization.
+  }
+
+  private SparkCounter(
+      String name,
+      String displayName,
+      long value) {
+    this.name = name;
+    this.displayName = displayName;
+    this.accumValue = value;
+  }
 
   public SparkCounter(
     String name,
@@ -47,9 +61,9 @@ public class SparkCounter implements Ser
   }
 
   public long getValue() {
-    try {
+    if (accumulator != null) {
       return accumulator.value();
-    } catch (UnsupportedOperationException e) {
+    } else {
       return accumValue;
     }
   }
@@ -70,8 +84,8 @@ public class SparkCounter implements Ser
     this.displayName = displayName;
   }
 
-  public void dumpValue() {
-    accumValue = accumulator.value();
+  SparkCounter snapshot() {
+    return new SparkCounter(name, displayName, accumulator.value());
   }
 
   class LongAccumulatorParam implements AccumulatorParam<Long> {

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounterGroup.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounterGroup.java?rev=1644322&r1=1644321&r2=1644322&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounterGroup.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounterGroup.java Wed Dec 10 03:14:57 2014
@@ -34,15 +34,18 @@ public class SparkCounterGroup implement
 
   private transient JavaSparkContext javaSparkContext;
 
-  public SparkCounterGroup(
-    String groupName,
-    String groupDisplayName,
-    JavaSparkContext javaSparkContext) {
+  private SparkCounterGroup() {
+    // For serialization.
+  }
 
+  public SparkCounterGroup(
+      String groupName,
+      String groupDisplayName,
+      JavaSparkContext javaSparkContext) {
     this.groupName = groupName;
     this.groupDisplayName = groupDisplayName;
     this.javaSparkContext = javaSparkContext;
-    sparkCounters = new HashMap<String, SparkCounter>();
+    this.sparkCounters = new HashMap<String, SparkCounter>();
   }
 
   public void createCounter(String name, long initValue) {
@@ -69,4 +72,14 @@ public class SparkCounterGroup implement
   public Map<String, SparkCounter> getSparkCounters() {
     return sparkCounters;
   }
+
+  SparkCounterGroup snapshot() {
+    SparkCounterGroup snapshot = new SparkCounterGroup(getGroupName(), getGroupDisplayName(), null);
+    for (SparkCounter counter : sparkCounters.values()) {
+      SparkCounter counterSnapshot = counter.snapshot();
+      snapshot.sparkCounters.put(counterSnapshot.getName(), counterSnapshot);
+    }
+    return snapshot;
+  }
+
 }

Modified: hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounters.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounters.java?rev=1644322&r1=1644321&r2=1644322&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounters.java (original)
+++ hive/branches/spark/spark-client/src/main/java/org/apache/hive/spark/counter/SparkCounters.java Wed Dec 10 03:14:57 2014
@@ -46,13 +46,15 @@ public class SparkCounters implements Se
 
   private Map<String, SparkCounterGroup> sparkCounterGroups;
 
-  private transient JavaSparkContext javaSparkContext;
-  private transient Configuration hiveConf;
+  private final transient JavaSparkContext javaSparkContext;
 
-  public SparkCounters(JavaSparkContext javaSparkContext, Configuration hiveConf) {
+  private SparkCounters() {
+    this(null);
+  }
+
+  public SparkCounters(JavaSparkContext javaSparkContext) {
     this.javaSparkContext = javaSparkContext;
-    this.hiveConf = hiveConf;
-    sparkCounterGroups = new HashMap<String, SparkCounterGroup>();
+    this.sparkCounterGroups = new HashMap<String, SparkCounterGroup>();
   }
 
   public void createCounter(Enum<?> key) {
@@ -143,14 +145,17 @@ public class SparkCounters implements Se
   }
 
   /**
-   * Dump all SparkCounter values.
-   * RSC should call this method before sending back the counters to client
+   * Creates a snapshot of the current counters to send back to the client. This copies the values
+   * of all current counters into a new SparkCounters instance that cannot be used to update the
+   * counters, but will serialize cleanly when sent back to the RSC client.
    */
-  public void dumpAllCounters() {
-    for (SparkCounterGroup counterGroup : sparkCounterGroups.values()) {
-      for (SparkCounter counter : counterGroup.getSparkCounters().values()) {
-        counter.dumpValue();
-      }
+  public SparkCounters snapshot() {
+    SparkCounters snapshot = new SparkCounters();
+    for (SparkCounterGroup group : sparkCounterGroups.values()) {
+      SparkCounterGroup groupSnapshot = group.snapshot();
+      snapshot.sparkCounterGroups.put(groupSnapshot.getGroupName(), groupSnapshot);
     }
+    return snapshot;
   }
+
 }

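The intended snapshot flow, as exercised by the new testCounters test later in this commit:
counters accumulate on the Spark side, and snapshot() copies the accumulator values (which
are only readable on the SparkContext side) into plain fields so they survive serialization.
A sketch, assuming jc is a JobContext as in the test:

    SparkCounters counters = new SparkCounters(jc.sc());
    counters.createCounter("group1", "counter1");
    counters.getCounter("group1", "counter1").increment(5L);

    // Copy current values into a serializable, context-free instance before
    // sending them back to the RSC client.
    SparkCounters toSend = counters.snapshot();
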
Modified: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java?rev=1644322&r1=1644321&r2=1644322&view=diff
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java (original)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/TestSparkClient.java Wed Dec 10 03:14:57 2014
@@ -44,6 +44,8 @@ import org.junit.Before;
 import org.junit.Test;
 import static org.junit.Assert.*;
 
+import org.apache.hive.spark.counter.SparkCounters;
+
 public class TestSparkClient {
 
   // Timeouts are bad... mmmkay.
@@ -52,7 +54,7 @@ public class TestSparkClient {
   private Map<String, String> createConf(boolean local) {
     Map<String, String> conf = new HashMap<String, String>();
     if (local) {
-      conf.put(ClientUtils.CONF_KEY_IN_PROCESS, "true");
+      conf.put(SparkClientFactory.CONF_KEY_IN_PROCESS, "true");
       conf.put("spark.master", "local");
       conf.put("spark.app.name", "SparkClientSuite Local App");
     } else {
@@ -194,21 +196,23 @@ public class TestSparkClient {
   }
 
   @Test
-  public void testKryoSerializer() throws Exception {
+  public void testCounters() throws Exception {
     runTest(true, new TestFunction() {
       @Override
       public void call(SparkClient client) throws Exception {
-        JobHandle<Long> handle = client.submit(new SparkJob());
-        assertEquals(Long.valueOf(5L), handle.get(TIMEOUT, TimeUnit.SECONDS));
-      }
+        JobHandle<?> job = client.submit(new CounterIncrementJob());
+        job.get(TIMEOUT, TimeUnit.SECONDS);
+
+        SparkCounters counters = job.getSparkCounters();
+        assertNotNull(counters);
 
-      @Override void config(Map<String, String> conf) {
-        conf.put(ClientUtils.CONF_KEY_SERIALIZER, "kryo");
+        long expected = 1 + 2 + 3 + 4 + 5;
+        assertEquals(expected, counters.getCounter("group1", "counter1").getValue());
+        assertEquals(expected, counters.getCounter("group2", "counter2").getValue());
       }
     });
   }
 
-
   private void runTest(boolean local, TestFunction test) throws Exception {
     Map<String, String> conf = createConf(local);
     SparkClientFactory.initialize(conf);
@@ -294,6 +298,10 @@ public class TestSparkClient {
 
     private final String fileName;
 
+    FileJob() {
+      this(null);
+    }
+
     FileJob(String fileName) {
       this.fileName = fileName;
     }
@@ -312,6 +320,30 @@ public class TestSparkClient {
     }
 
   }
+
+  private static class CounterIncrementJob implements Job<String>, VoidFunction<Integer> {
+
+    private SparkCounters counters;
+
+    @Override
+    public String call(JobContext jc) {
+      counters = new SparkCounters(jc.sc());
+      counters.createCounter("group1", "counter1");
+      counters.createCounter("group2", "counter2");
+
+      jc.monitor(jc.sc().parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).foreachAsync(this),
+          counters);
+
+      return null;
+    }
+
+    @Override
+    public void call(Integer l) throws Exception {
+      counters.getCounter("group1", "counter1").increment(l.longValue());
+      counters.getCounter("group2", "counter2").increment(l.longValue());
+    }
+
+  }
 
   private static abstract class TestFunction {
     abstract void call(SparkClient client) throws Exception;

Added: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java?rev=1644322&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java (added)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestKryoMessageCodec.java Wed Dec 10 03:14:57 2014
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.UnpooledByteBufAllocator;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.logging.LoggingHandler;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+public class TestKryoMessageCodec {
+
+  @Test
+  public void testKryoCodec() throws Exception {
+    ByteBuf buf = newBuffer();
+    Object message = "Hello World!";
+
+    KryoMessageCodec codec = new KryoMessageCodec(0);
+    codec.encode(null, message, buf);
+
+    List<Object> objects = Lists.newArrayList();
+    codec.decode(null, buf, objects);
+
+    assertEquals(1, objects.size());
+    assertEquals(message, objects.get(0));
+  }
+
+  @Test
+  public void testFragmentation() throws Exception {
+    ByteBuf buf = newBuffer();
+    Object[] messages = { "msg1", "msg2" };
+    int[] indices = new int[messages.length];
+
+    KryoMessageCodec codec = new KryoMessageCodec(0);
+
+    for (int i = 0; i < messages.length; i++) {
+      codec.encode(null, messages[i], buf);
+      indices[i] = buf.writerIndex();
+    }
+
+    List<Object> objects = Lists.newArrayList();
+
+    // Don't read enough data for the first message to be decoded.
+    codec.decode(null, buf.slice(0, indices[0] - 1), objects);
+    assertEquals(0, objects.size());
+
+    // Read enough data for just the first message to be decoded.
+    codec.decode(null, buf.slice(0, indices[0] + 1), objects);
+    assertEquals(1, objects.size());
+  }
+
+  @Test
+  public void testEmbeddedChannel() throws Exception {
+    Object message = "Hello World!";
+    EmbeddedChannel c = new EmbeddedChannel(
+      new LoggingHandler(getClass()),
+      new KryoMessageCodec(0));
+    c.writeAndFlush(message);
+    assertEquals(1, c.outboundMessages().size());
+    assertFalse(message.getClass().equals(c.outboundMessages().peek().getClass()));
+    c.writeInbound(c.readOutbound());
+    assertEquals(1, c.inboundMessages().size());
+    assertEquals(message, c.readInbound());
+    c.close();
+  }
+
+  @Test
+  public void testAutoRegistration() throws Exception {
+    KryoMessageCodec codec = new KryoMessageCodec(0, TestMessage.class);
+    ByteBuf buf = newBuffer();
+    codec.encode(null, new TestMessage(), buf);
+
+    List<Object> out = Lists.newArrayList();
+    codec.decode(null, buf, out);
+
+    assertEquals(1, out.size());
+    assertTrue(out.get(0) instanceof TestMessage);
+  }
+
+  @Test
+  public void testMaxMessageSize() throws Exception {
+    KryoMessageCodec codec = new KryoMessageCodec(1024);
+    ByteBuf buf = newBuffer();
+    codec.encode(null, new TestMessage(new byte[512]), buf);
+
+    try {
+      codec.encode(null, new TestMessage(new byte[1025]), buf);
+      fail("Should have failed to encode large message.");
+    } catch (IllegalArgumentException e) {
+      assertTrue(e.getMessage().indexOf("maximum allowed size") >= 0);
+    }
+
+    KryoMessageCodec unlimited = new KryoMessageCodec(0);
+    buf = newBuffer();
+    unlimited.encode(null, new TestMessage(new byte[1025]), buf);
+
+    try {
+      List<Object> out = Lists.newArrayList();
+      codec.decode(null, buf, out);
+      fail("Should have failed to decode large message.");
+    } catch (IllegalArgumentException e) {
+      assertTrue(e.getMessage().indexOf("maximum allowed size") >= 0);
+    }
+  }
+
+  @Test
+  public void testNegativeMessageSize() throws Exception {
+    KryoMessageCodec codec = new KryoMessageCodec(1024);
+    ByteBuf buf = newBuffer();
+    buf.writeInt(-1);
+
+    try {
+      List<Object> out = Lists.newArrayList();
+      codec.decode(null, buf, out);
+      fail("Should have failed to decode message with negative size.");
+    } catch (IllegalArgumentException e) {
+      assertTrue(e.getMessage().indexOf("must be positive") >= 0);
+    }
+  }
+
+  private ByteBuf newBuffer() {
+    return UnpooledByteBufAllocator.DEFAULT.buffer(1024);
+  }
+
+  private static class TestMessage {
+    byte[] data;
+
+    TestMessage() {
+      this(null);
+    }
+
+    TestMessage(byte[] data) {
+      this.data = data;
+    }
+  }
+
+}

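Taken together, the tests above pin down the codec's contract: an apparent 4-byte length
prefix (testNegativeMessageSize writes one with writeInt), an optional maximum message size
(0 means unlimited), and optional up-front class registration with Kryo. A combined usage
sketch based on the constructors exercised above, with MyMessage as a hypothetical
registered class:

    KryoMessageCodec codec = new KryoMessageCodec(64 * 1024, MyMessage.class);

    ByteBuf buf = UnpooledByteBufAllocator.DEFAULT.buffer(1024);
    codec.encode(null, new MyMessage(), buf);   // fails if the encoded size exceeds 64 KB

    List<Object> out = Lists.newArrayList();
    codec.decode(null, buf, out);               // out.get(0) is the decoded MyMessage
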
Added: hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java
URL: http://svn.apache.org/viewvc/hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java?rev=1644322&view=auto
==============================================================================
--- hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java (added)
+++ hive/branches/spark/spark-client/src/test/java/org/apache/hive/spark/client/rpc/TestRpc.java Wed Dec 10 03:14:57 2014
@@ -0,0 +1,260 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hive.spark.client.rpc;
+
+import java.io.Closeable;
+import java.util.Collection;
+import java.util.Map;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.util.concurrent.Future;
+import org.apache.commons.io.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+public class TestRpc {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TestRpc.class);
+
+  private Collection<Closeable> closeables;
+  private Map<String, String> emptyConfig =
+      ImmutableMap.of(RpcConfiguration.RPC_CHANNEL_LOG_LEVEL_KEY, "DEBUG");
+
+  @Before
+  public void setUp() {
+    closeables = Lists.newArrayList();
+  }
+
+  @After
+  public void cleanUp() throws Exception {
+    for (Closeable c : closeables) {
+      IOUtils.closeQuietly(c);
+    }
+  }
+
+  private <T extends Closeable> T autoClose(T closeable) {
+    closeables.add(closeable);
+    return closeable;
+  }
+
+  @Test
+  public void testRpcDispatcher() throws Exception {
+    Rpc serverRpc = autoClose(Rpc.createEmbedded(new TestDispatcher()));
+    Rpc clientRpc = autoClose(Rpc.createEmbedded(new TestDispatcher()));
+
+    TestMessage outbound = new TestMessage("Hello World!");
+    Future<TestMessage> call = clientRpc.call(outbound, TestMessage.class);
+
+    LOG.debug("Transferring messages...");
+    transfer(serverRpc, clientRpc);
+
+    TestMessage reply = call.get(10, TimeUnit.SECONDS);
+    assertEquals(outbound.message, reply.message);
+  }
+
+  @Test
+  public void testClientServer() throws Exception {
+    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    Rpc[] rpcs = createRpcConnection(server);
+    Rpc serverRpc = rpcs[0];
+    Rpc client = rpcs[1];
+
+    TestMessage outbound = new TestMessage("Hello World!");
+    Future<TestMessage> call = client.call(outbound, TestMessage.class);
+    TestMessage reply = call.get(10, TimeUnit.SECONDS);
+    assertEquals(outbound.message, reply.message);
+
+    TestMessage another = new TestMessage("Hello again!");
+    Future<TestMessage> anotherCall = client.call(another, TestMessage.class);
+    TestMessage anotherReply = anotherCall.get(10, TimeUnit.SECONDS);
+    assertEquals(another.message, anotherReply.message);
+
+    String errorMsg = "This is an error.";
+    try {
+      client.call(new ErrorCall(errorMsg)).get(10, TimeUnit.SECONDS);
+      fail("Should have failed to run the error call.");
+    } catch (ExecutionException ee) {
+      assertTrue(ee.getCause() instanceof RpcException);
+      assertTrue(ee.getCause().getMessage().indexOf(errorMsg) >= 0);
+    }
+
+    // Test from server to client too.
+    TestMessage serverMsg = new TestMessage("Hello from the server!");
+    Future<TestMessage> serverCall = serverRpc.call(serverMsg, TestMessage.class);
+    TestMessage serverReply = serverCall.get(10, TimeUnit.SECONDS);
+    assertEquals(serverMsg.message, serverReply.message);
+  }
+
+  @Test
+  public void testBadHello() throws Exception {
+    RpcServer server = autoClose(new RpcServer(emptyConfig));
+
+    Future<Rpc> serverRpcFuture = server.registerClient("newClient", new TestDispatcher());
+    NioEventLoopGroup eloop = new NioEventLoopGroup();
+
+    Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
+        "localhost", server.getPort(), "wrongClient", new TestDispatcher());
+
+    try {
+      autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
+      fail("Should have failed to create client with wrong secret.");
+    } catch (ExecutionException ee) {
+      // On failure, the server will close the channel. This will cause the client's promise
+      // to be cancelled.
+      assertTrue("Unexpected exception: " + ee.getCause(),
+        ee.getCause() instanceof CancellationException);
+    }
+
+    serverRpcFuture.cancel(true);
+  }
+
+  @Test
+  public void testCloseListener() throws Exception {
+    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    Rpc[] rpcs = createRpcConnection(server);
+    Rpc client = rpcs[1];
+
+    final AtomicInteger closeCount = new AtomicInteger();
+    client.addListener(new Rpc.Listener() {
+        @Override
+        public void rpcClosed(Rpc rpc) {
+          closeCount.incrementAndGet();
+        }
+    });
+
+    client.close();
+    client.close();
+    assertEquals(1, closeCount.get());
+  }
+
+  @Test
+  public void testNotDeserializableRpc() throws Exception {
+    RpcServer server = autoClose(new RpcServer(emptyConfig));
+    Rpc[] rpcs = createRpcConnection(server);
+    Rpc client = rpcs[1];
+
+    try {
+      client.call(new NotDeserializable(42)).get(10, TimeUnit.SECONDS);
+      fail("Should have failed to call with a non-deserializable message.");
+    } catch (ExecutionException ee) {
+      assertTrue(ee.getCause() instanceof RpcException);
+      assertTrue(ee.getCause().getMessage().indexOf("KryoException") >= 0);
+    }
+  }
+
+  private void transfer(Rpc serverRpc, Rpc clientRpc) {
+    EmbeddedChannel client = (EmbeddedChannel) clientRpc.getChannel();
+    EmbeddedChannel server = (EmbeddedChannel) serverRpc.getChannel();
+
+    int count = 0;
+    while (!client.outboundMessages().isEmpty()) {
+      server.writeInbound(client.readOutbound());
+      count++;
+    }
+    server.flush();
+    LOG.debug("Transferred {} outbound client messages.", count);
+
+    count = 0;
+    while (!server.outboundMessages().isEmpty()) {
+      client.writeInbound(server.readOutbound());
+      count++;
+    }
+    client.flush();
+    LOG.debug("Transferred {} outbound server messages.", count);
+  }
+
+  /**
+   * Establishes an RPC connection between the given server and a new client.
+   *
+   * @return two-tuple (server rpc, client rpc)
+   */
+  private Rpc[] createRpcConnection(RpcServer server) throws Exception {
+    String secret = server.createSecret();
+    Future<Rpc> serverRpcFuture = server.registerClient(secret, new TestDispatcher());
+    NioEventLoopGroup eloop = new NioEventLoopGroup();
+    Future<Rpc> clientRpcFuture = Rpc.createClient(emptyConfig, eloop,
+        "localhost", server.getPort(), secret, new TestDispatcher());
+
+    Rpc serverRpc = autoClose(serverRpcFuture.get(10, TimeUnit.SECONDS));
+    Rpc clientRpc = autoClose(clientRpcFuture.get(10, TimeUnit.SECONDS));
+    return new Rpc[] { serverRpc, clientRpc };
+  }
+
+  private static class TestMessage {
+
+    final String message;
+
+    public TestMessage() {
+      this(null);
+    }
+
+    public TestMessage(String message) {
+      this.message = message;
+    }
+
+  }
+
+  private static class ErrorCall {
+
+    final String error;
+
+    public ErrorCall() {
+      this(null);
+    }
+
+    public ErrorCall(String error) {
+      this.error = error;
+    }
+
+  }
+
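+  // Kryo's default instantiation needs a no-arg constructor; this class lacks one, so the
+  // receiving side fails to decode it and replies with an error (see exceptionCaught in
+  // RpcDispatcher).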
+  private static class NotDeserializable {
+
+    NotDeserializable(int unused) {
+    }
+
+  }
+
+  private static class TestDispatcher extends RpcDispatcher {
+    protected TestMessage handle(ChannelHandlerContext ctx, TestMessage msg) {
+      return msg;
+    }
+
+    protected void handle(ChannelHandlerContext ctx, ErrorCall msg) {
+      throw new IllegalArgumentException(msg.error);
+    }
+
+    protected void handle(ChannelHandlerContext ctx, NotDeserializable msg) {
+      // No op. Shouldn't actually be called, if it is, the test will fail.
+    }
+
+  }
+
+}