Posted to commits@zeppelin.apache.org by zj...@apache.org on 2018/02/02 06:00:52 UTC
[07/10] zeppelin git commit: ZEPPELIN-3111. Refactor SparkInterpreter
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java
new file mode 100644
index 0000000..80ea03b
--- /dev/null
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/ZeppelinRContext.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import org.apache.spark.SparkContext;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SQLContext;
+
+/**
+ * Contains the Spark and Zeppelin Contexts made available to SparkR.
+ */
+public class ZeppelinRContext {
+ private static SparkContext sparkContext;
+ private static SQLContext sqlContext;
+ private static SparkZeppelinContext zeppelinContext;
+ private static Object sparkSession;
+ private static JavaSparkContext javaSparkContext;
+
+ public static void setSparkContext(SparkContext sparkContext) {
+ ZeppelinRContext.sparkContext = sparkContext;
+ }
+
+ public static void setZeppelinContext(SparkZeppelinContext zeppelinContext) {
+ ZeppelinRContext.zeppelinContext = zeppelinContext;
+ }
+
+ public static void setSqlContext(SQLContext sqlContext) {
+ ZeppelinRContext.sqlContext = sqlContext;
+ }
+
+ public static void setSparkSession(Object sparkSession) {
+ ZeppelinRContext.sparkSession = sparkSession;
+ }
+
+ public static SparkContext getSparkContext() {
+ return sparkContext;
+ }
+
+ public static SQLContext getSqlContext() {
+ return sqlContext;
+ }
+
+ public static SparkZeppelinContext getZeppelinContext() {
+ return zeppelinContext;
+ }
+
+ public static Object getSparkSession() {
+ return sparkSession;
+ }
+
+ public static void setJavaSparkContext(JavaSparkContext jsc) {
+ javaSparkContext = jsc;
+ }
+
+ public static JavaSparkContext getJavaSparkContext() {
+ return javaSparkContext;
+ }
+}
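Note for readers: the static setters above are the hand-off point between the JVM and SparkR — the interpreter stores the Spark objects here, and the R side reads them back with SparkR:::callJStatic (see zeppelin_sparkr.R below). A minimal wiring sketch, assuming a SparkContext, SQLContext and SparkZeppelinContext already exist; the class and method names in the sketch are illustrative, not part of this commit:

    // Illustrative wiring; the real calls live inside the Spark interpreter setup.
    import org.apache.spark.SparkContext;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.SQLContext;

    public class RContextWiringSketch {
      static void wire(SparkContext sc, SQLContext sqlContext, SparkZeppelinContext z) {
        ZeppelinRContext.setSparkContext(sc);
        ZeppelinRContext.setJavaSparkContext(new JavaSparkContext(sc));
        ZeppelinRContext.setSqlContext(sqlContext);
        ZeppelinRContext.setZeppelinContext(z);
        // sparkSession stays typed as Object so this class also loads on Spark 1.x,
        // where org.apache.spark.sql.SparkSession does not exist.
      }
    }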
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java
new file mode 100644
index 0000000..0235fc6
--- /dev/null
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyContext.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark.dep;
+
+import java.io.File;
+import java.net.MalformedURLException;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.zeppelin.dep.Booter;
+import org.apache.zeppelin.dep.Dependency;
+import org.apache.zeppelin.dep.Repository;
+
+import org.sonatype.aether.RepositorySystem;
+import org.sonatype.aether.RepositorySystemSession;
+import org.sonatype.aether.artifact.Artifact;
+import org.sonatype.aether.collection.CollectRequest;
+import org.sonatype.aether.graph.DependencyFilter;
+import org.sonatype.aether.repository.RemoteRepository;
+import org.sonatype.aether.repository.Authentication;
+import org.sonatype.aether.resolution.ArtifactResolutionException;
+import org.sonatype.aether.resolution.ArtifactResult;
+import org.sonatype.aether.resolution.DependencyRequest;
+import org.sonatype.aether.resolution.DependencyResolutionException;
+import org.sonatype.aether.util.artifact.DefaultArtifact;
+import org.sonatype.aether.util.artifact.JavaScopes;
+import org.sonatype.aether.util.filter.DependencyFilterUtils;
+import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter;
+
+
+/**
+ * Maintains dependency and repository declarations for the dep interpreter and resolves them into local artifact files.
+ */
+public class SparkDependencyContext {
+ List<Dependency> dependencies = new LinkedList<>();
+ List<Repository> repositories = new LinkedList<>();
+
+ List<File> files = new LinkedList<>();
+ List<File> filesDist = new LinkedList<>();
+ private RepositorySystem system = Booter.newRepositorySystem();
+ private RepositorySystemSession session;
+ private RemoteRepository mavenCentral = Booter.newCentralRepository();
+ private RemoteRepository mavenLocal = Booter.newLocalRepository();
+ private List<RemoteRepository> additionalRepos = new LinkedList<>();
+
+ public SparkDependencyContext(String localRepoPath, String additionalRemoteRepository) {
+ session = Booter.newRepositorySystemSession(system, localRepoPath);
+ addRepoFromProperty(additionalRemoteRepository);
+ }
+
+ public Dependency load(String lib) {
+ Dependency dep = new Dependency(lib);
+
+ if (dependencies.contains(dep)) {
+ dependencies.remove(dep);
+ }
+ dependencies.add(dep);
+ return dep;
+ }
+
+ public Repository addRepo(String name) {
+ Repository rep = new Repository(name);
+ repositories.add(rep);
+ return rep;
+ }
+
+ public void reset() {
+ dependencies = new LinkedList<>();
+ repositories = new LinkedList<>();
+
+ files = new LinkedList<>();
+ filesDist = new LinkedList<>();
+ }
+
+ private void addRepoFromProperty(String listOfRepo) {
+ if (listOfRepo != null) {
+ String[] repos = listOfRepo.split(";");
+ for (String repo : repos) {
+ String[] parts = repo.split(",");
+ if (parts.length == 3) {
+ String id = parts[0].trim();
+ String url = parts[1].trim();
+ boolean isSnapshot = Boolean.parseBoolean(parts[2].trim());
+ if (id.length() > 1 && url.length() > 1) {
+ RemoteRepository rr = new RemoteRepository(id, "default", url);
+ rr.setPolicy(isSnapshot, null);
+ additionalRepos.add(rr);
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Fetch all registered dependencies and collect the resolved artifact files.
+ * @return the list of local files for all resolved artifacts
+ * @throws MalformedURLException
+ * @throws ArtifactResolutionException
+ * @throws DependencyResolutionException
+ */
+ public List<File> fetch() throws MalformedURLException,
+ DependencyResolutionException, ArtifactResolutionException {
+
+ for (Dependency dep : dependencies) {
+ if (!dep.isLocalFsArtifact()) {
+ List<ArtifactResult> artifacts = fetchArtifactWithDep(dep);
+ for (ArtifactResult artifact : artifacts) {
+ if (dep.isDist()) {
+ filesDist.add(artifact.getArtifact().getFile());
+ }
+ files.add(artifact.getArtifact().getFile());
+ }
+ } else {
+ if (dep.isDist()) {
+ filesDist.add(new File(dep.getGroupArtifactVersion()));
+ }
+ files.add(new File(dep.getGroupArtifactVersion()));
+ }
+ }
+
+ return files;
+ }
+
+ private List<ArtifactResult> fetchArtifactWithDep(Dependency dep)
+ throws DependencyResolutionException, ArtifactResolutionException {
+ Artifact artifact = new DefaultArtifact(
+ SparkDependencyResolver.inferScalaVersion(dep.getGroupArtifactVersion()));
+
+ DependencyFilter classpathFilter = DependencyFilterUtils
+ .classpathFilter(JavaScopes.COMPILE);
+ PatternExclusionsDependencyFilter exclusionFilter = new PatternExclusionsDependencyFilter(
+ SparkDependencyResolver.inferScalaVersion(dep.getExclusions()));
+
+ CollectRequest collectRequest = new CollectRequest();
+ collectRequest.setRoot(new org.sonatype.aether.graph.Dependency(artifact,
+ JavaScopes.COMPILE));
+
+ collectRequest.addRepository(mavenCentral);
+ collectRequest.addRepository(mavenLocal);
+ for (RemoteRepository repo : additionalRepos) {
+ collectRequest.addRepository(repo);
+ }
+ for (Repository repo : repositories) {
+ RemoteRepository rr = new RemoteRepository(repo.getId(), "default", repo.getUrl());
+ rr.setPolicy(repo.isSnapshot(), null);
+ Authentication auth = repo.getAuthentication();
+ if (auth != null) {
+ rr.setAuthentication(auth);
+ }
+ collectRequest.addRepository(rr);
+ }
+
+ DependencyRequest dependencyRequest = new DependencyRequest(collectRequest,
+ DependencyFilterUtils.andFilter(exclusionFilter, classpathFilter));
+
+ return system.resolveDependencies(session, dependencyRequest).getArtifactResults();
+ }
+
+ public List<File> getFiles() {
+ return files;
+ }
+
+ public List<File> getFilesDist() {
+ return filesDist;
+ }
+}
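The repository string accepted by the constructor follows the same 'id,remote-repository-URL,is-snapshot;' format as the zeppelin.dep.additionalRemoteRepository default further down in interpreter-setting.json. A usage sketch with illustrative coordinates and paths:

    import java.io.File;
    import java.util.List;

    public class DepContextSketch {
      public static void main(String[] args) throws Exception {
        SparkDependencyContext dep = new SparkDependencyContext(
            "/tmp/local-repo",  // hypothetical local repository path
            "spark-packages,http://dl.bintray.com/spark-packages/maven,false;");
        dep.load("org.apache.commons:commons-csv:1.5");  // hypothetical artifact
        List<File> jars = dep.fetch();  // resolves transitively, returns local jar files
        System.out.println(jars);
      }
    }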
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java
new file mode 100644
index 0000000..46224a8
--- /dev/null
+++ b/spark/interpreter/src/main/java/org/apache/zeppelin/spark/dep/SparkDependencyResolver.java
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark.dep;
+
+import java.io.File;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.net.URL;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.spark.SparkContext;
+import org.apache.zeppelin.dep.AbstractDependencyResolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.sonatype.aether.artifact.Artifact;
+import org.sonatype.aether.collection.CollectRequest;
+import org.sonatype.aether.graph.Dependency;
+import org.sonatype.aether.graph.DependencyFilter;
+import org.sonatype.aether.repository.RemoteRepository;
+import org.sonatype.aether.resolution.ArtifactResult;
+import org.sonatype.aether.resolution.DependencyRequest;
+import org.sonatype.aether.util.artifact.DefaultArtifact;
+import org.sonatype.aether.util.artifact.JavaScopes;
+import org.sonatype.aether.util.filter.DependencyFilterUtils;
+import org.sonatype.aether.util.filter.PatternExclusionsDependencyFilter;
+
+import scala.Some;
+import scala.collection.IndexedSeq;
+import scala.reflect.io.AbstractFile;
+import scala.tools.nsc.Global;
+import scala.tools.nsc.backend.JavaPlatform;
+import scala.tools.nsc.util.ClassPath;
+import scala.tools.nsc.util.MergedClassPath;
+
+/**
+ * Dependency resolver.
+ * Adds new dependencies from a Maven repository (at runtime) to the Spark interpreter group.
+ */
+public class SparkDependencyResolver extends AbstractDependencyResolver {
+ Logger logger = LoggerFactory.getLogger(SparkDependencyResolver.class);
+ private Global global;
+ private ClassLoader runtimeClassLoader;
+ private SparkContext sc;
+
+ private final String[] exclusions = new String[] {"org.scala-lang:scala-library",
+ "org.scala-lang:scala-compiler",
+ "org.scala-lang:scala-reflect",
+ "org.scala-lang:scalap",
+ "org.apache.zeppelin:zeppelin-zengine",
+ "org.apache.zeppelin:zeppelin-spark",
+ "org.apache.zeppelin:zeppelin-server"};
+
+ public SparkDependencyResolver(Global global,
+ ClassLoader runtimeClassLoader,
+ SparkContext sc,
+ String localRepoPath,
+ String additionalRemoteRepository) {
+ super(localRepoPath);
+ this.global = global;
+ this.runtimeClassLoader = runtimeClassLoader;
+ this.sc = sc;
+ addRepoFromProperty(additionalRemoteRepository);
+ }
+
+ private void addRepoFromProperty(String listOfRepo) {
+ if (listOfRepo != null) {
+ String[] repos = listOfRepo.split(";");
+ for (String repo : repos) {
+ String[] parts = repo.split(",");
+ if (parts.length == 3) {
+ String id = parts[0].trim();
+ String url = parts[1].trim();
+ boolean isSnapshot = Boolean.parseBoolean(parts[2].trim());
+ if (id.length() > 1 && url.length() > 1) {
+ addRepo(id, url, isSnapshot);
+ }
+ }
+ }
+ }
+ }
+
+ private void updateCompilerClassPath(URL[] urls) throws IllegalAccessException,
+ IllegalArgumentException, InvocationTargetException {
+
+ JavaPlatform platform = (JavaPlatform) global.platform();
+ MergedClassPath<AbstractFile> newClassPath = mergeUrlsIntoClassPath(platform, urls);
+
+ Method[] methods = platform.getClass().getMethods();
+ for (Method m : methods) {
+ if (m.getName().endsWith("currentClassPath_$eq")) {
+ m.invoke(platform, new Some(newClassPath));
+ break;
+ }
+ }
+
+ // NOTE: Must use reflection until this is exposed/fixed upstream in Scala
+ List<String> classPaths = new LinkedList<>();
+ for (URL url : urls) {
+ classPaths.add(url.getPath());
+ }
+
+ // Reload all jars specified into our compiler
+ global.invalidateClassPathEntries(scala.collection.JavaConversions.asScalaBuffer(classPaths)
+ .toList());
+ }
+
+ // For Spark up to 1.1.x
+ // check https://github.com/apache/spark/commit/191d7cf2a655d032f160b9fa181730364681d0e7
+ private void updateRuntimeClassPath_1_x(URL[] urls) throws SecurityException,
+ IllegalAccessException, IllegalArgumentException,
+ InvocationTargetException, NoSuchMethodException {
+ Method addURL;
+ addURL = runtimeClassLoader.getClass().getDeclaredMethod("addURL", new Class[] {URL.class});
+ addURL.setAccessible(true);
+ for (URL url : urls) {
+ addURL.invoke(runtimeClassLoader, url);
+ }
+ }
+
+ private void updateRuntimeClassPath_2_x(URL[] urls) throws SecurityException,
+ IllegalAccessException, IllegalArgumentException,
+ InvocationTargetException, NoSuchMethodException {
+ Method addURL;
+ addURL = runtimeClassLoader.getClass().getDeclaredMethod("addNewUrl", new Class[] {URL.class});
+ addURL.setAccessible(true);
+ for (URL url : urls) {
+ addURL.invoke(runtimeClassLoader, url);
+ }
+ }
+
+ private MergedClassPath<AbstractFile> mergeUrlsIntoClassPath(JavaPlatform platform, URL[] urls) {
+ IndexedSeq<ClassPath<AbstractFile>> entries =
+ ((MergedClassPath<AbstractFile>) platform.classPath()).entries();
+ List<ClassPath<AbstractFile>> cp = new LinkedList<>();
+
+ for (int i = 0; i < entries.size(); i++) {
+ cp.add(entries.apply(i));
+ }
+
+ for (URL url : urls) {
+ AbstractFile file;
+ if ("file".equals(url.getProtocol())) {
+ File f = new File(url.getPath());
+ if (f.isDirectory()) {
+ file = AbstractFile.getDirectory(scala.reflect.io.File.jfile2path(f));
+ } else {
+ file = AbstractFile.getFile(scala.reflect.io.File.jfile2path(f));
+ }
+ } else {
+ file = AbstractFile.getURL(url);
+ }
+
+ ClassPath<AbstractFile> newcp = platform.classPath().context().newClassPath(file);
+
+ // skip duplicate classpath entries
+ if (!cp.contains(newcp)) {
+ cp.add(newcp);
+ }
+ }
+
+ return new MergedClassPath(scala.collection.JavaConversions.asScalaBuffer(cp).toIndexedSeq(),
+ platform.classPath().context());
+ }
+
+ public List<String> load(String artifact,
+ boolean addSparkContext) throws Exception {
+ return load(artifact, new LinkedList<String>(), addSparkContext);
+ }
+
+ public List<String> load(String artifact, Collection<String> excludes,
+ boolean addSparkContext) throws Exception {
+ if (StringUtils.isBlank(artifact)) {
+ throw new RuntimeException("Invalid artifact to load");
+ }
+
+ // <groupId>:<artifactId>[:<extension>[:<classifier>]]:<version>
+ int numSplits = artifact.split(":").length;
+ if (numSplits >= 3 && numSplits <= 6) {
+ return loadFromMvn(artifact, excludes, addSparkContext);
+ } else {
+ loadFromFs(artifact, addSparkContext);
+ LinkedList<String> libs = new LinkedList<>();
+ libs.add(artifact);
+ return libs;
+ }
+ }
+
+ private void loadFromFs(String artifact, boolean addSparkContext) throws Exception {
+ File jarFile = new File(artifact);
+
+ global.new Run();
+
+ if (sc.version().startsWith("1.1")) {
+ updateRuntimeClassPath_1_x(new URL[] {jarFile.toURI().toURL()});
+ } else {
+ updateRuntimeClassPath_2_x(new URL[] {jarFile.toURI().toURL()});
+ }
+
+ if (addSparkContext) {
+ sc.addJar(jarFile.getAbsolutePath());
+ }
+ }
+
+ private List<String> loadFromMvn(String artifact, Collection<String> excludes,
+ boolean addSparkContext) throws Exception {
+ List<String> loadedLibs = new LinkedList<>();
+ Collection<String> allExclusions = new LinkedList<>();
+ allExclusions.addAll(excludes);
+ allExclusions.addAll(Arrays.asList(exclusions));
+
+ List<ArtifactResult> listOfArtifact;
+ listOfArtifact = getArtifactsWithDep(artifact, allExclusions);
+
+ Iterator<ArtifactResult> it = listOfArtifact.iterator();
+ while (it.hasNext()) {
+ Artifact a = it.next().getArtifact();
+ String gav = a.getGroupId() + ":" + a.getArtifactId() + ":" + a.getVersion();
+ for (String exclude : allExclusions) {
+ if (gav.startsWith(exclude)) {
+ it.remove();
+ break;
+ }
+ }
+ }
+
+ List<URL> newClassPathList = new LinkedList<>();
+ List<File> files = new LinkedList<>();
+ for (ArtifactResult artifactResult : listOfArtifact) {
+ logger.info("Load " + artifactResult.getArtifact().getGroupId() + ":"
+ + artifactResult.getArtifact().getArtifactId() + ":"
+ + artifactResult.getArtifact().getVersion());
+ newClassPathList.add(artifactResult.getArtifact().getFile().toURI().toURL());
+ files.add(artifactResult.getArtifact().getFile());
+ loadedLibs.add(artifactResult.getArtifact().getGroupId() + ":"
+ + artifactResult.getArtifact().getArtifactId() + ":"
+ + artifactResult.getArtifact().getVersion());
+ }
+
+ global.new Run();
+ if (sc.version().startsWith("1.1")) {
+ updateRuntimeClassPath_1_x(newClassPathList.toArray(new URL[0]));
+ } else {
+ updateRuntimeClassPath_2_x(newClassPathList.toArray(new URL[0]));
+ }
+ updateCompilerClassPath(newClassPathList.toArray(new URL[0]));
+
+ if (addSparkContext) {
+ for (File f : files) {
+ sc.addJar(f.getAbsolutePath());
+ }
+ }
+
+ return loadedLibs;
+ }
+
+ /**
+ * Resolves the given artifact together with its transitive dependencies.
+ *
+ * @param dependency artifact coordinate to resolve
+ * @param excludes list of exclusion patterns, each of the form groupId:artifactId
+ * @return the resolved artifact results
+ * @throws Exception
+ */
+ @Override
+ public List<ArtifactResult> getArtifactsWithDep(String dependency,
+ Collection<String> excludes) throws Exception {
+ Artifact artifact = new DefaultArtifact(inferScalaVersion(dependency));
+ DependencyFilter classpathFilter = DependencyFilterUtils.classpathFilter(JavaScopes.COMPILE);
+ PatternExclusionsDependencyFilter exclusionFilter =
+ new PatternExclusionsDependencyFilter(inferScalaVersion(excludes));
+
+ CollectRequest collectRequest = new CollectRequest();
+ collectRequest.setRoot(new Dependency(artifact, JavaScopes.COMPILE));
+
+ synchronized (repos) {
+ for (RemoteRepository repo : repos) {
+ collectRequest.addRepository(repo);
+ }
+ }
+ DependencyRequest dependencyRequest = new DependencyRequest(collectRequest,
+ DependencyFilterUtils.andFilter(exclusionFilter, classpathFilter));
+ return system.resolveDependencies(session, dependencyRequest).getArtifactResults();
+ }
+
+ public static Collection<String> inferScalaVersion(Collection<String> artifact) {
+ List<String> list = new LinkedList<>();
+ for (String a : artifact) {
+ list.add(inferScalaVersion(a));
+ }
+ return list;
+ }
+
+ public static String inferScalaVersion(String artifact) {
+ int pos = artifact.indexOf(":");
+ if (pos < 0 || pos + 2 >= artifact.length()) {
+ // failed to infer
+ return artifact;
+ }
+
+ if (':' == artifact.charAt(pos + 1)) {
+ // "groupId::artifactId" means the Scala binary version is appended to the artifactId
+ String rest = "";
+ String versionSep = ":";
+
+ String groupId = artifact.substring(0, pos);
+ int nextPos = artifact.indexOf(":", pos + 2);
+ if (nextPos < 0) {
+ if (artifact.charAt(artifact.length() - 1) == '*') {
+ nextPos = artifact.length() - 1;
+ versionSep = "";
+ rest = "*";
+ } else {
+ versionSep = "";
+ nextPos = artifact.length();
+ }
+ }
+
+ String artifactId = artifact.substring(pos + 2, nextPos);
+ if (nextPos < artifact.length()) {
+ if (!rest.equals("*")) {
+ rest = artifact.substring(nextPos + 1);
+ }
+ }
+
+ String[] version = scala.util.Properties.versionNumberString().split("[.]");
+ String scalaVersion = version[0] + "." + version[1];
+
+ return groupId + ":" + artifactId + "_" + scalaVersion + versionSep + rest;
+ } else {
+ return artifact;
+ }
+ }
+}
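The '::' handling in inferScalaVersion mirrors SBT's %% convention: a double colon after the groupId appends the running Scala binary version to the artifactId before resolution. A behavior sketch — the suffix depends on the Scala runtime on the classpath (2.11 is assumed here), and the coordinates are illustrative:

    public class InferScalaVersionSketch {
      public static void main(String[] args) {
        // A plain coordinate passes through unchanged.
        System.out.println(SparkDependencyResolver.inferScalaVersion(
            "org.apache.commons:commons-csv:1.5"));  // org.apache.commons:commons-csv:1.5
        // '::' expands to '_<scala binary version>' on the artifactId.
        System.out.println(SparkDependencyResolver.inferScalaVersion(
            "com.example::some-lib:1.0"));  // com.example:some-lib_2.11:1.0
      }
    }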
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R b/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R
new file mode 100644
index 0000000..525c6c5
--- /dev/null
+++ b/spark/interpreter/src/main/resources/R/zeppelin_sparkr.R
@@ -0,0 +1,105 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+args <- commandArgs(trailingOnly = TRUE)
+
+hashCode <- as.integer(args[1])
+port <- as.integer(args[2])
+libPath <- args[3]
+version <- as.integer(args[4])
+rm(args)
+
+print(paste("Port ", toString(port)))
+print(paste("LibPath ", libPath))
+
+.libPaths(c(file.path(libPath), .libPaths()))
+library(SparkR)
+
+
+SparkR:::connectBackend("localhost", port, 6000)
+
+# scStartTime is needed by R/pkg/R/sparkR.R
+assign(".scStartTime", as.integer(Sys.time()), envir = SparkR:::.sparkREnv)
+
+# look up the ZeppelinR instance for this interpreter session
+.zeppelinR = SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinR", "getZeppelinR", hashCode)
+
+# setup spark env
+assign(".sc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkContext"), envir = SparkR:::.sparkREnv)
+assign("sc", get(".sc", envir = SparkR:::.sparkREnv), envir=.GlobalEnv)
+if (version >= 20000) {
+ assign(".sparkRsession", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSparkSession"), envir = SparkR:::.sparkREnv)
+ assign("spark", get(".sparkRsession", envir = SparkR:::.sparkREnv), envir = .GlobalEnv)
+ assign(".sparkRjsc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getJavaSparkContext"), envir = SparkR:::.sparkREnv)
+}
+assign(".sqlc", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getSqlContext"), envir = SparkR:::.sparkREnv)
+assign("sqlContext", get(".sqlc", envir = SparkR:::.sparkREnv), envir = .GlobalEnv)
+assign(".zeppelinContext", SparkR:::callJStatic("org.apache.zeppelin.spark.ZeppelinRContext", "getZeppelinContext"), envir = .GlobalEnv)
+
+z.put <- function(name, object) {
+ SparkR:::callJMethod(.zeppelinContext, "put", name, object)
+}
+z.get <- function(name) {
+ SparkR:::callJMethod(.zeppelinContext, "get", name)
+}
+z.input <- function(name, value) {
+ SparkR:::callJMethod(.zeppelinContext, "input", name, value)
+}
+
+# notify the JVM side that the script has initialized
+SparkR:::callJMethod(.zeppelinR, "onScriptInitialized")
+
+while (TRUE) {
+ req <- SparkR:::callJMethod(.zeppelinR, "getRequest")
+ type <- SparkR:::callJMethod(req, "getType")
+ stmt <- SparkR:::callJMethod(req, "getStmt")
+ value <- SparkR:::callJMethod(req, "getValue")
+
+ if (type == "eval") {
+ tryCatch({
+ ret <- eval(parse(text=stmt))
+ SparkR:::callJMethod(.zeppelinR, "setResponse", "", FALSE)
+ }, error = function(e) {
+ SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE)
+ })
+ } else if (type == "set") {
+ tryCatch({
+ ret <- assign(stmt, value)
+ SparkR:::callJMethod(.zeppelinR, "setResponse", "", FALSE)
+ }, error = function(e) {
+ SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE)
+ })
+ } else if (type == "get") {
+ tryCatch({
+ ret <- eval(parse(text=stmt))
+ SparkR:::callJMethod(.zeppelinR, "setResponse", ret, FALSE)
+ }, error = function(e) {
+ SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE)
+ })
+ } else if (type == "getS") {
+ tryCatch({
+ ret <- eval(parse(text=stmt))
+ SparkR:::callJMethod(.zeppelinR, "setResponse", toString(ret), FALSE)
+ }, error = function(e) {
+ SparkR:::callJMethod(.zeppelinR, "setResponse", toString(e), TRUE)
+ })
+ } else {
+ # unsupported type
+ SparkR:::callJMethod(.zeppelinR, "setResponse", paste("Unsupported type ", type), TRUE)
+ }
+}
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/resources/interpreter-setting.json
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/resources/interpreter-setting.json b/spark/interpreter/src/main/resources/interpreter-setting.json
new file mode 100644
index 0000000..7e647d7
--- /dev/null
+++ b/spark/interpreter/src/main/resources/interpreter-setting.json
@@ -0,0 +1,233 @@
+[
+ {
+ "group": "spark",
+ "name": "spark",
+ "className": "org.apache.zeppelin.spark.SparkInterpreter",
+ "defaultInterpreter": true,
+ "properties": {
+ "spark.executor.memory": {
+ "envName": null,
+ "propertyName": "spark.executor.memory",
+ "defaultValue": "",
+ "description": "Executor memory per worker instance. ex) 512m, 32g",
+ "type": "string"
+ },
+ "args": {
+ "envName": null,
+ "propertyName": null,
+ "defaultValue": "",
+ "description": "spark commandline args",
+ "type": "textarea"
+ },
+ "zeppelin.spark.useHiveContext": {
+ "envName": "ZEPPELIN_SPARK_USEHIVECONTEXT",
+ "propertyName": "zeppelin.spark.useHiveContext",
+ "defaultValue": true,
+ "description": "Use HiveContext instead of SQLContext if set to true.",
+ "type": "checkbox"
+ },
+ "spark.app.name": {
+ "envName": "SPARK_APP_NAME",
+ "propertyName": "spark.app.name",
+ "defaultValue": "Zeppelin",
+ "description": "The name of spark application.",
+ "type": "string"
+ },
+ "zeppelin.spark.printREPLOutput": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.printREPLOutput",
+ "defaultValue": true,
+ "description": "Print REPL output",
+ "type": "checkbox"
+ },
+ "spark.cores.max": {
+ "envName": null,
+ "propertyName": "spark.cores.max",
+ "defaultValue": "",
+ "description": "Total number of cores to use. An empty value uses all available cores.",
+ "type": "number"
+ },
+ "zeppelin.spark.maxResult": {
+ "envName": "ZEPPELIN_SPARK_MAXRESULT",
+ "propertyName": "zeppelin.spark.maxResult",
+ "defaultValue": "1000",
+ "description": "Max number of Spark SQL results to display.",
+ "type": "number"
+ },
+ "master": {
+ "envName": "MASTER",
+ "propertyName": "spark.master",
+ "defaultValue": "local[*]",
+ "description": "Spark master uri. ex) spark://masterhost:7077",
+ "type": "string"
+ },
+ "zeppelin.spark.enableSupportedVersionCheck": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.enableSupportedVersionCheck",
+ "defaultValue": true,
+ "description": "Do not change - developer only setting, not for production use",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.uiWebUrl": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.uiWebUrl",
+ "defaultValue": "",
+ "description": "Override Spark UI default URL",
+ "type": "string"
+ },
+ "zeppelin.spark.useNew": {
+ "envName": null,
+ "propertyName": "zeppelin.spark.useNew",
+ "defaultValue": "false",
+ "description": "Whether to use the new Spark interpreter implementation",
+ "type": "checkbox"
+ }
+ },
+ "editor": {
+ "language": "scala",
+ "editOnDblClick": false,
+ "completionKey": "TAB"
+ }
+ },
+ {
+ "group": "spark",
+ "name": "sql",
+ "className": "org.apache.zeppelin.spark.SparkSqlInterpreter",
+ "properties": {
+ "zeppelin.spark.concurrentSQL": {
+ "envName": "ZEPPELIN_SPARK_CONCURRENTSQL",
+ "propertyName": "zeppelin.spark.concurrentSQL",
+ "defaultValue": false,
+ "description": "Execute multiple SQL statements concurrently if set to true.",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.sql.stacktrace": {
+ "envName": "ZEPPELIN_SPARK_SQL_STACKTRACE",
+ "propertyName": "zeppelin.spark.sql.stacktrace",
+ "defaultValue": false,
+ "description": "Show full exception stacktrace for SQL queries if set to true.",
+ "type": "checkbox"
+ },
+ "zeppelin.spark.maxResult": {
+ "envName": "ZEPPELIN_SPARK_MAXRESULT",
+ "propertyName": "zeppelin.spark.maxResult",
+ "defaultValue": "1000",
+ "description": "Max number of Spark SQL results to display.",
+ "type": "number"
+ },
+ "zeppelin.spark.importImplicit": {
+ "envName": "ZEPPELIN_SPARK_IMPORTIMPLICIT",
+ "propertyName": "zeppelin.spark.importImplicit",
+ "defaultValue": true,
+ "description": "Import implicits, the UDF collection, and sql if set to true (true by default).",
+ "type": "checkbox"
+ }
+ },
+ "editor": {
+ "language": "sql",
+ "editOnDblClick": false,
+ "completionKey": "TAB"
+ }
+ },
+ {
+ "group": "spark",
+ "name": "dep",
+ "className": "org.apache.zeppelin.spark.DepInterpreter",
+ "properties": {
+ "zeppelin.dep.localrepo": {
+ "envName": "ZEPPELIN_DEP_LOCALREPO",
+ "propertyName": null,
+ "defaultValue": "local-repo",
+ "description": "local repository for dependency loader",
+ "type": "string"
+ },
+ "zeppelin.dep.additionalRemoteRepository": {
+ "envName": null,
+ "propertyName": null,
+ "defaultValue": "spark-packages,http://dl.bintray.com/spark-packages/maven,false;",
+ "description": "A list of 'id,remote-repository-URL,is-snapshot;' for each remote repository.",
+ "type": "textarea"
+ }
+ },
+ "editor": {
+ "language": "scala",
+ "editOnDblClick": false,
+ "completionKey": "TAB"
+ }
+ },
+ {
+ "group": "spark",
+ "name": "pyspark",
+ "className": "org.apache.zeppelin.spark.PySparkInterpreter",
+ "properties": {
+ "zeppelin.pyspark.python": {
+ "envName": "PYSPARK_PYTHON",
+ "propertyName": null,
+ "defaultValue": "python",
+ "description": "Python command to run pyspark with",
+ "type": "string"
+ },
+ "zeppelin.pyspark.useIPython": {
+ "envName": null,
+ "propertyName": "zeppelin.pyspark.useIPython",
+ "defaultValue": true,
+ "description": "Whether to use IPython when it is available",
+ "type": "checkbox"
+ }
+ },
+ "editor": {
+ "language": "python",
+ "editOnDblClick": false,
+ "completionKey": "TAB"
+ }
+ },
+ {
+ "group": "spark",
+ "name": "ipyspark",
+ "className": "org.apache.zeppelin.spark.IPySparkInterpreter",
+ "properties": {},
+ "editor": {
+ "language": "python",
+ "editOnDblClick": false
+ }
+ },
+ {
+ "group": "spark",
+ "name": "r",
+ "className": "org.apache.zeppelin.spark.SparkRInterpreter",
+ "properties": {
+ "zeppelin.R.knitr": {
+ "envName": "ZEPPELIN_R_KNITR",
+ "propertyName": "zeppelin.R.knitr",
+ "defaultValue": true,
+ "description": "Whether to use knitr or not",
+ "type": "checkbox"
+ },
+ "zeppelin.R.cmd": {
+ "envName": "ZEPPELIN_R_CMD",
+ "propertyName": "zeppelin.R.cmd",
+ "defaultValue": "R",
+ "description": "R repl path",
+ "type": "string"
+ },
+ "zeppelin.R.image.width": {
+ "envName": "ZEPPELIN_R_IMAGE_WIDTH",
+ "propertyName": "zeppelin.R.image.width",
+ "defaultValue": "100%",
+ "description": "",
+ "type": "number"
+ },
+ "zeppelin.R.render.options": {
+ "envName": "ZEPPELIN_R_RENDER_OPTIONS",
+ "propertyName": "zeppelin.R.render.options",
+ "defaultValue": "out.format = 'html', comment = NA, echo = FALSE, results = 'asis', message = F, warning = F, fig.retina = 2",
+ "description": "",
+ "type": "textarea"
+ }
+ },
+ "editor": {
+ "language": "r",
+ "editOnDblClick": false
+ }
+ }
+]
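Roughly, each entry above reaches the interpreter as a plain java.util.Properties entry (envName and propertyName control how values may also arrive via environment variables or system properties). A hedged sketch of how such a property is typically consumed, using the zeppelin.spark.maxResult key and default from the JSON above; the surrounding class is illustrative:

    import java.util.Properties;

    public class PropertySketch {
      static int maxResult(Properties properties) {
        // "1000" mirrors the defaultValue declared in interpreter-setting.json
        return Integer.parseInt(properties.getProperty("zeppelin.spark.maxResult", "1000"));
      }
    }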
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py
new file mode 100644
index 0000000..324f481
--- /dev/null
+++ b/spark/interpreter/src/main/resources/python/zeppelin_ipyspark.py
@@ -0,0 +1,53 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+
+# for backward compatibility
+from pyspark.sql import SQLContext
+
+# start JVM gateway
+client = GatewayClient(port=${JVM_GATEWAY_PORT})
+gateway = JavaGateway(client, auto_convert=True)
+
+java_import(gateway.jvm, "org.apache.spark.SparkEnv")
+java_import(gateway.jvm, "org.apache.spark.SparkConf")
+java_import(gateway.jvm, "org.apache.spark.api.java.*")
+java_import(gateway.jvm, "org.apache.spark.api.python.*")
+java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
+
+intp = gateway.entry_point
+jsc = intp.getJavaSparkContext()
+
+java_import(gateway.jvm, "org.apache.spark.sql.*")
+java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+java_import(gateway.jvm, "scala.Tuple2")
+
+jconf = jsc.getConf()
+conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
+sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
+
+if intp.isSpark2():
+ from pyspark.sql import SparkSession
+
+ spark = __zSpark__ = SparkSession(sc, intp.getSparkSession())
+ sqlContext = sqlc = __zSqlc__ = __zSpark__._wrapped
+else:
+ sqlContext = sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext())
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
new file mode 100644
index 0000000..c10855a
--- /dev/null
+++ b/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py
@@ -0,0 +1,393 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os, sys, getopt, traceback, json, re
+
+from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.protocol import Py4JJavaError
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+import ast
+import warnings
+
+# for backward compatibility
+from pyspark.sql import SQLContext, HiveContext, Row
+
+class Logger(object):
+ def __init__(self):
+ pass
+
+ def write(self, message):
+ intp.appendOutput(message)
+
+ def reset(self):
+ pass
+
+ def flush(self):
+ pass
+
+
+class PyZeppelinContext(dict):
+ def __init__(self, zc):
+ self.z = zc
+ self._displayhook = lambda *args: None
+
+ def show(self, obj):
+ from pyspark.sql import DataFrame
+ if isinstance(obj, DataFrame):
+ print(self.z.showData(obj._jdf))
+ else:
+ print(str(obj))
+
+ # Implementing these special methods makes the context more Pythonic to operate on
+ def __setitem__(self, key, item):
+ self.z.put(key, item)
+
+ def __getitem__(self, key):
+ return self.z.get(key)
+
+ def __delitem__(self, key):
+ self.z.remove(key)
+
+ def __contains__(self, item):
+ return self.z.containsKey(item)
+
+ def add(self, key, value):
+ self.__setitem__(key, value)
+
+ def put(self, key, value):
+ self.__setitem__(key, value)
+
+ def get(self, key):
+ return self.__getitem__(key)
+
+ def getInterpreterContext(self):
+ return self.z.getInterpreterContext()
+
+ def input(self, name, defaultValue=""):
+ return self.z.input(name, defaultValue)
+
+ def textbox(self, name, defaultValue=""):
+ return self.z.textbox(name, defaultValue)
+
+ def noteTextbox(self, name, defaultValue=""):
+ return self.z.noteTextbox(name, defaultValue)
+
+ def select(self, name, options, defaultValue=""):
+ # auto_convert to ArrayList doesn't match the method signature on JVM side
+ return self.z.select(name, defaultValue, self.getParamOptions(options))
+
+ def noteSelect(self, name, options, defaultValue=""):
+ return self.z.noteSelect(name, defaultValue, self.getParamOptions(options))
+
+ def checkbox(self, name, options, defaultChecked=None):
+ optionsIterable = self.getParamOptions(options)
+ defaultCheckedIterables = self.getDefaultChecked(defaultChecked)
+ checkedItems = gateway.jvm.scala.collection.JavaConversions.seqAsJavaList(self.z.checkbox(name, defaultCheckedIterables, optionsIterable))
+ result = []
+ for checkedItem in checkedItems:
+ result.append(checkedItem)
+ return result
+
+ def noteCheckbox(self, name, options, defaultChecked=None):
+ optionsIterable = self.getParamOptions(options)
+ defaultCheckedIterables = self.getDefaultChecked(defaultChecked)
+ checkedItems = gateway.jvm.scala.collection.JavaConversions.seqAsJavaList(self.z.noteCheckbox(name, defaultCheckedIterables, optionsIterable))
+ result = []
+ for checkedItem in checkedItems:
+ result.append(checkedItem)
+ return result
+
+ def getParamOptions(self, options):
+ tuples = list(map(lambda items: self.__tupleToScalaTuple2(items), options))
+ return gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(tuples)
+
+ def getDefaultChecked(self, defaultChecked):
+ if defaultChecked is None:
+ defaultChecked = []
+ return gateway.jvm.scala.collection.JavaConversions.collectionAsScalaIterable(defaultChecked)
+
+ def registerHook(self, event, cmd, replName=None):
+ if replName is None:
+ self.z.registerHook(event, cmd)
+ else:
+ self.z.registerHook(event, cmd, replName)
+
+ def unregisterHook(self, event, replName=None):
+ if replName is None:
+ self.z.unregisterHook(event)
+ else:
+ self.z.unregisterHook(event, replName)
+
+ def getHook(self, event, replName=None):
+ if replName is None:
+ return self.z.getHook(event)
+ return self.z.getHook(event, replName)
+
+ def _setup_matplotlib(self):
+ # If we don't have matplotlib installed don't bother continuing
+ try:
+ import matplotlib
+ except ImportError:
+ return
+
+ # Make sure custom backends are available in the PYTHONPATH
+ rootdir = os.environ.get('ZEPPELIN_HOME', os.getcwd())
+ mpl_path = os.path.join(rootdir, 'interpreter', 'lib', 'python')
+ if mpl_path not in sys.path:
+ sys.path.append(mpl_path)
+
+ # Finally check if backend exists, and if so configure as appropriate
+ try:
+ matplotlib.use('module://backend_zinline')
+ import backend_zinline
+
+ # Everything looks good so make config assuming that we are using
+ # an inline backend
+ self._displayhook = backend_zinline.displayhook
+ self.configure_mpl(width=600, height=400, dpi=72, fontsize=10,
+ interactive=True, format='png', context=self.z)
+ except ImportError:
+ # Fall back to Agg if no custom backend installed
+ matplotlib.use('Agg')
+ warnings.warn("Unable to load inline matplotlib backend, "
+ "falling back to Agg")
+
+ def configure_mpl(self, **kwargs):
+ import mpl_config
+ mpl_config.configure(**kwargs)
+
+ def __tupleToScalaTuple2(self, items):
+ if len(items) == 2:
+ return gateway.jvm.scala.Tuple2(items[0], items[1])
+ else:
+ raise IndexError("options must be a list of 2-tuples")
+
+
+class SparkVersion(object):
+ SPARK_1_4_0 = 10400
+ SPARK_1_3_0 = 10300
+ SPARK_2_0_0 = 20000
+
+ def __init__(self, versionNumber):
+ self.version = versionNumber
+
+ def isAutoConvertEnabled(self):
+ return self.version >= self.SPARK_1_4_0
+
+ def isImportAllPackageUnderSparkSql(self):
+ return self.version >= self.SPARK_1_3_0
+
+ def isSpark2(self):
+ return self.version >= self.SPARK_2_0_0
+
+class PySparkCompletion:
+ def __init__(self, interpreterObject):
+ self.interpreterObject = interpreterObject
+
+ def getGlobalCompletion(self):
+ objectDefList = []
+ try:
+ for completionItem in list(globals().keys()):
+ objectDefList.append(completionItem)
+ except:
+ return None
+ else:
+ return objectDefList
+
+ def getMethodCompletion(self, text_value):
+ execResult = locals()
+ if text_value is None:
+ return None
+ completion_target = text_value
+ try:
+ if len(completion_target) <= 0:
+ return None
+ if text_value[-1] == ".":
+ completion_target = text_value[:-1]
+ exec("{} = dir({})".format("objectDefList", completion_target), globals(), execResult)
+ except:
+ return None
+ else:
+ return list(execResult['objectDefList'])
+
+
+ def getCompletion(self, text_value):
+ completionList = set()
+
+ globalCompletionList = self.getGlobalCompletion()
+ if globalCompletionList is not None:
+ for completionItem in list(globalCompletionList):
+ completionList.add(completionItem)
+
+ if text_value is not None:
+ objectCompletionList = self.getMethodCompletion(text_value)
+ if objectCompletionList is not None:
+ for completionItem in list(objectCompletionList):
+ completionList.add(completionItem)
+ if len(completionList) <= 0:
+ self.interpreterObject.setStatementsFinished("", False)
+ else:
+ result = json.dumps(list(filter(lambda x: not re.match("^__.*", x), list(completionList))))
+ self.interpreterObject.setStatementsFinished(result, False)
+
+client = GatewayClient(port=int(sys.argv[1]))
+sparkVersion = SparkVersion(int(sys.argv[2]))
+if sparkVersion.isSpark2():
+ from pyspark.sql import SparkSession
+else:
+ from pyspark.sql import SchemaRDD
+
+if sparkVersion.isAutoConvertEnabled():
+ gateway = JavaGateway(client, auto_convert=True)
+else:
+ gateway = JavaGateway(client)
+
+java_import(gateway.jvm, "org.apache.spark.SparkEnv")
+java_import(gateway.jvm, "org.apache.spark.SparkConf")
+java_import(gateway.jvm, "org.apache.spark.api.java.*")
+java_import(gateway.jvm, "org.apache.spark.api.python.*")
+java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
+
+intp = gateway.entry_point
+output = Logger()
+sys.stdout = output
+sys.stderr = output
+intp.onPythonScriptInitialized(os.getpid())
+
+jsc = intp.getJavaSparkContext()
+
+if sparkVersion.isImportAllPackageUnderSparkSql():
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
+else:
+ java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+
+
+java_import(gateway.jvm, "scala.Tuple2")
+
+_zcUserQueryNameSpace = {}
+
+jconf = intp.getSparkConf()
+conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
+sc = _zsc_ = SparkContext(jsc=jsc, gateway=gateway, conf=conf)
+_zcUserQueryNameSpace["_zsc_"] = _zsc_
+_zcUserQueryNameSpace["sc"] = sc
+
+if sparkVersion.isSpark2():
+ spark = __zSpark__ = SparkSession(sc, intp.getSparkSession())
+ sqlc = __zSqlc__ = __zSpark__._wrapped
+ _zcUserQueryNameSpace["sqlc"] = sqlc
+ _zcUserQueryNameSpace["__zSqlc__"] = __zSqlc__
+ _zcUserQueryNameSpace["spark"] = spark
+ _zcUserQueryNameSpace["__zSpark__"] = __zSpark__
+else:
+ sqlc = __zSqlc__ = SQLContext(sparkContext=sc, sqlContext=intp.getSQLContext())
+ _zcUserQueryNameSpace["sqlc"] = sqlc
+ _zcUserQueryNameSpace["__zSqlc__"] = sqlc
+
+sqlContext = __zSqlc__
+_zcUserQueryNameSpace["sqlContext"] = sqlContext
+
+completion = __zeppelin_completion__ = PySparkCompletion(intp)
+_zcUserQueryNameSpace["completion"] = completion
+_zcUserQueryNameSpace["__zeppelin_completion__"] = __zeppelin_completion__
+
+z = __zeppelin__ = PyZeppelinContext(intp.getZeppelinContext())
+__zeppelin__._setup_matplotlib()
+_zcUserQueryNameSpace["z"] = z
+_zcUserQueryNameSpace["__zeppelin__"] = __zeppelin__
+
+while True:
+ req = intp.getStatements()
+ try:
+ stmts = req.statements().split("\n")
+ jobGroup = req.jobGroup()
+ jobDesc = req.jobDescription()
+
+ # Get post-execute hooks
+ try:
+ global_hook = intp.getHook('post_exec_dev')
+ except:
+ global_hook = None
+
+ try:
+ user_hook = __zeppelin__.getHook('post_exec')
+ except:
+ user_hook = None
+
+ nhooks = 0
+ for hook in (global_hook, user_hook):
+ if hook:
+ nhooks += 1
+
+ if stmts:
+ # use exec mode to compile the statements except the last statement,
+ # so that the last statement's evaluation will be printed to stdout
+ sc.setJobGroup(jobGroup, jobDesc)
+ code = compile('\n'.join(stmts), '<stdin>', 'exec', ast.PyCF_ONLY_AST, 1)
+ to_run_hooks = []
+ if nhooks > 0:
+ to_run_hooks = code.body[-nhooks:]
+ to_run_exec, to_run_single = (code.body[:-(nhooks + 1)],
+ [code.body[-(nhooks + 1)]])
+
+ try:
+ for node in to_run_exec:
+ mod = ast.Module([node])
+ code = compile(mod, '<stdin>', 'exec')
+ exec(code, _zcUserQueryNameSpace)
+
+ for node in to_run_single:
+ mod = ast.Interactive([node])
+ code = compile(mod, '<stdin>', 'single')
+ exec(code, _zcUserQueryNameSpace)
+
+ for node in to_run_hooks:
+ mod = ast.Module([node])
+ code = compile(mod, '<stdin>', 'exec')
+ exec(code, _zcUserQueryNameSpace)
+
+ intp.setStatementsFinished("", False)
+ except Py4JJavaError:
+ # raise it to outside try except
+ raise
+ except:
+ exception = traceback.format_exc()
+ m = re.search("File \"<stdin>\", line (\d+).*", exception)
+ if m:
+ line_no = int(m.group(1))
+ intp.setStatementsFinished(
+ "Failed to execute line {}: {}\n".format(line_no, stmts[line_no - 1]) + exception, True)
+ else:
+ intp.setStatementsFinished(exception, True)
+ else:
+ intp.setStatementsFinished("", False)
+
+ except Py4JJavaError:
+ excInnerError = traceback.format_exc() # format_tb() does not return the inner exception
+ innerErrorStart = excInnerError.find("Py4JJavaError:")
+ if innerErrorStart > -1:
+ excInnerError = excInnerError[innerErrorStart:]
+ intp.setStatementsFinished(excInnerError + str(sys.exc_info()), True)
+ except:
+ intp.setStatementsFinished(traceback.format_exc(), True)
+
+ output.reset()
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala b/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala
new file mode 100644
index 0000000..05f1ac0
--- /dev/null
+++ b/spark/interpreter/src/main/scala/org/apache/spark/SparkRBackend.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark
+
+import org.apache.spark.api.r.RBackend
+
+object SparkRBackend {
+ val backend: RBackend = new RBackend()
+ private var started = false
+ private var portNumber = 0
+
+ val backendThread: Thread = new Thread("SparkRBackend") {
+ override def run() {
+ backend.run()
+ }
+ }
+
+ def init(): Int = {
+ portNumber = backend.init()
+ portNumber
+ }
+
+ def start(): Unit = {
+ backendThread.start()
+ started = true
+ }
+
+ def close(): Unit = {
+ backend.close()
+ backendThread.join()
+ }
+
+ def isStarted(): Boolean = {
+ started
+ }
+
+ def port(): Int = {
+ portNumber
+ }
+}
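Since SparkRBackend is a Scala object without a companion class, its methods compile to static forwarders and are callable directly from Java. A lifecycle sketch under that assumption (launching the R process itself is elided):

    import org.apache.spark.SparkRBackend;

    public class RBackendLifecycleSketch {
      public static void main(String[] args) {
        int port = SparkRBackend.init();  // bind the RBackend and obtain its port
        SparkRBackend.start();            // serve SparkR connections on the backend thread
        // ... launch R and pass the port so SparkR:::connectBackend("localhost", port, ...) can attach ...
        SparkRBackend.close();            // shut down the backend and join the thread
      }
    }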
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala
new file mode 100644
index 0000000..a9014c2
--- /dev/null
+++ b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/ZeppelinRDisplay.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark
+
+import org.apache.zeppelin.interpreter.InterpreterResult.Code
+import org.apache.zeppelin.interpreter.InterpreterResult.Code.SUCCESS
+import org.apache.zeppelin.interpreter.InterpreterResult.Type
+import org.apache.zeppelin.interpreter.InterpreterResult.Type.{TEXT, HTML, TABLE, IMG}
+import org.jsoup.Jsoup
+import org.jsoup.nodes.Element
+import org.jsoup.nodes.Document.OutputSettings
+import org.jsoup.safety.Whitelist
+
+import scala.collection.JavaConversions._
+import scala.util.matching.Regex
+
+case class RDisplay(content: String, `type`: Type, code: Code)
+
+object ZeppelinRDisplay {
+
+ val pattern = new Regex("""^ *\[\d*\] """)
+
+ def render(html: String, imageWidth: String): RDisplay = {
+
+ val document = Jsoup.parse(html)
+ document.outputSettings().prettyPrint(false)
+
+ val body = document.body()
+
+ if (body.getElementsByTag("p").isEmpty) return RDisplay(body.html(), HTML, SUCCESS)
+
+ val bodyHtml = body.html()
+
+ if (!bodyHtml.contains("<img")
+ && !bodyHtml.contains("<script")
+ && !bodyHtml.contains("%html ")
+ && !bodyHtml.contains("%table ")
+ && !bodyHtml.contains("%img ")) {
+ return textDisplay(body)
+ }
+
+ if (bodyHtml.contains("%table")) {
+ return tableDisplay(body)
+ }
+
+ if (bodyHtml.contains("%img")) {
+ return imgDisplay(body)
+ }
+
+ return htmlDisplay(body, imageWidth)
+ }
+
+ private def textDisplay(body: Element): RDisplay = {
+ // remove HTML tag while preserving whitespaces and newlines
+ val text = Jsoup.clean(body.html(), "",
+ Whitelist.none(), new OutputSettings().prettyPrint(false))
+ RDisplay(text, TEXT, SUCCESS)
+ }
+
+ private def tableDisplay(body: Element): RDisplay = {
+ val p = body.getElementsByTag("p").first().html.replace("“%table " , "").replace("”", "")
+ val r = (pattern findFirstIn p).getOrElse("")
+ val table = p.replace(r, "").replace("\\t", "\t").replace("\\n", "\n")
+ RDisplay(table, TABLE, SUCCESS)
+ }
+
+ private def imgDisplay(body: Element): RDisplay = {
+ val p = body.getElementsByTag("p").first().html.replace("“%img " , "").replace("”", "")
+ val r = (pattern findFirstIn p).getOrElse("")
+ val img = p.replace(r, "")
+ RDisplay(img, IMG, SUCCESS)
+ }
+
+ private def htmlDisplay(body: Element, imageWidth: String): RDisplay = {
+ var div = ""
+
+ for (element <- body.children) {
+
+ val eHtml = element.html()
+ var eOuterHtml = element.outerHtml()
+
+ eOuterHtml = eOuterHtml.replace("“%html " , "").replace("”", "")
+
+ val r = (pattern findFirstIn eHtml).getOrElse("")
+
+ div = div + eOuterHtml.replace(r, "")
+ }
+
+ val content = div
+ .replaceAll("src=\"//", "src=\"http://")
+ .replaceAll("href=\"//", "href=\"http://")
+
+ body.html(content)
+
+ for (image <- body.getElementsByTag("img")) {
+ image.attr("width", imageWidth)
+ }
+
+ RDisplay(body.html, HTML, SUCCESS)
+ }
+}
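render dispatches on magic markers embedded in the knitr output: a body with no <p> tags is returned as HTML, plain paragraphs become TEXT, and '%table'/'%img' payloads are unwrapped into TABLE/IMG results. Called from Java via the object's static forwarder, a minimal sketch with an illustrative input:

    public class RDisplaySketch {
      public static void main(String[] args) {
        // A plain knitr paragraph is stripped of tags and comes back as a TEXT result.
        RDisplay d = ZeppelinRDisplay.render("<p>[1] 42</p>", "100%");
        System.out.println(d.code());  // SUCCESS
      }
    }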
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala
new file mode 100644
index 0000000..8181434
--- /dev/null
+++ b/spark/interpreter/src/main/scala/org/apache/zeppelin/spark/utils/DisplayUtils.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark.utils
+
+import java.lang.StringBuilder
+
+import org.apache.spark.rdd.RDD
+
+import scala.collection.IterableLike
+
+object DisplayUtils {
+
+ implicit def toDisplayRDDFunctions[T <: Product](rdd: RDD[T]): DisplayRDDFunctions[T] = new DisplayRDDFunctions[T](rdd)
+
+ implicit def toDisplayTraversableFunctions[T <: Product](traversable: Traversable[T]): DisplayTraversableFunctions[T] = new DisplayTraversableFunctions[T](traversable)
+
+ def html(htmlContent: String = "") = s"%html $htmlContent"
+
+ def img64(base64Content: String = "") = s"%img $base64Content"
+
+ def img(url: String) = s"<img src='$url' />"
+}
+
+trait DisplayCollection[T <: Product] {
+
+ def printFormattedData(traversable: Traversable[T], columnLabels: String*): Unit = {
+ val providedLabelCount: Int = columnLabels.size
+ var maxColumnCount: Int = 1
+ val headers = new StringBuilder("%table ")
+
+ val data = new StringBuilder("")
+
+ traversable.foreach(tuple => {
+ maxColumnCount = math.max(maxColumnCount, tuple.productArity)
+ data.append(tuple.productIterator.mkString("\t")).append("\n")
+ })
+
+ if (providedLabelCount > maxColumnCount) {
+ headers.append(columnLabels.take(maxColumnCount).mkString("\t")).append("\n")
+ } else if (providedLabelCount < maxColumnCount) {
+ val missingColumnHeaders = ((providedLabelCount+1) to maxColumnCount).foldLeft[String](""){
+ (stringAccumulator, index) => if (index == 1) s"Column$index" else s"$stringAccumulator\tColumn$index"
+ }
+
+ headers.append(columnLabels.mkString("\t")).append(missingColumnHeaders).append("\n")
+ } else {
+ headers.append(columnLabels.mkString("\t")).append("\n")
+ }
+
+ headers.append(data)
+
+ print(headers.toString)
+ }
+
+}
+
+class DisplayRDDFunctions[T <: Product] (val rdd: RDD[T]) extends DisplayCollection[T] {
+
+ def display(columnLabels: String*)(implicit sparkMaxResult: SparkMaxResult): Unit = {
+ printFormattedData(rdd.take(sparkMaxResult.maxResult), columnLabels: _*)
+ }
+
+ def display(sparkMaxResult: Int, columnLabels: String*): Unit = {
+ printFormattedData(rdd.take(sparkMaxResult), columnLabels: _*)
+ }
+}
+
+class DisplayTraversableFunctions[T <: Product] (val traversable: Traversable[T]) extends DisplayCollection[T] {
+
+ def display(columnLabels: String*): Unit = {
+ printFormattedData(traversable, columnLabels: _*)
+ }
+}
+
+class SparkMaxResult(val maxResult: Int) extends Serializable
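For context, the implicits above let any RDD or local collection of tuples (anything <: Product, so case classes too) print itself in %table form. A rough sketch of notebook-style usage, assuming a SparkContext sc is in scope; the data and labels are illustrative:

  import org.apache.zeppelin.spark.utils.DisplayUtils._
  import org.apache.zeppelin.spark.utils.SparkMaxResult

  // cap how many rows display() pulls back to the driver
  implicit val limit: SparkMaxResult = new SparkMaxResult(100)

  val people = sc.parallelize(Seq(("alice", 34), ("bob", 29)))
  people.display("name", "age")
  // prints: %table name\tage\nalice\t34\nbob\t29

  // local collections take the same labels, with no implicit limit needed
  Seq(("x", 1), ("y", 2)).display("label", "count")

If fewer labels are supplied than the widest row, printFormattedData pads the header with generated ColumnN names; surplus labels are truncated to the column count.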
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java
new file mode 100644
index 0000000..e177d49
--- /dev/null
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/DepInterpreterTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+import static org.junit.Assert.assertEquals;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.Properties;
+
+import org.apache.zeppelin.display.AngularObjectRegistry;
+import org.apache.zeppelin.user.AuthenticationInfo;
+import org.apache.zeppelin.display.GUI;
+import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.InterpreterResult.Code;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+public class DepInterpreterTest {
+
+ @Rule
+ public TemporaryFolder tmpDir = new TemporaryFolder();
+
+ private DepInterpreter dep;
+ private InterpreterContext context;
+
+ private Properties getTestProperties() throws IOException {
+ Properties p = new Properties();
+ p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath());
+ p.setProperty("zeppelin.dep.additionalRemoteRepository", "spark-packages,http://dl.bintray.com/spark-packages/maven,false;");
+ return p;
+ }
+
+ @Before
+ public void setUp() throws Exception {
+ Properties p = getTestProperties();
+
+ dep = new DepInterpreter(p);
+ dep.open();
+
+ InterpreterGroup intpGroup = new InterpreterGroup();
+ intpGroup.put("note", new LinkedList<Interpreter>());
+ intpGroup.get("note").add(new SparkInterpreter(p));
+ intpGroup.get("note").add(dep);
+ dep.setInterpreterGroup(intpGroup);
+
+ context = new InterpreterContext("note", "id", null, "title", "text", new AuthenticationInfo(),
+ new HashMap<String, Object>(), new GUI(), new GUI(),
+ new AngularObjectRegistry(intpGroup.getId(), null),
+ null,
+ new LinkedList<InterpreterContextRunner>(), null);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ dep.close();
+ }
+
+ @Test
+ public void testDefault() {
+ dep.getDependencyContext().reset();
+ InterpreterResult ret = dep.interpret("z.load(\"org.apache.commons:commons-csv:1.1\")", context);
+ assertEquals(Code.SUCCESS, ret.code());
+
+ assertEquals(1, dep.getDependencyContext().getFiles().size());
+ assertEquals(1, dep.getDependencyContext().getFilesDist().size());
+
+ // Also exercise the spark-packages repo configured via additionalRemoteRepository
+ ret = dep.interpret("z.load(\"amplab:spark-indexedrdd:0.3\")", context);
+ assertEquals(Code.SUCCESS, ret.code());
+
+ // Reset at the end of the test
+ dep.getDependencyContext().reset();
+ }
+}
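For context on what this test drives: zeppelin.dep.additionalRemoteRepository takes semicolon-terminated id,url,snapshotFlag triples (here the spark-packages repository), and %dep paragraphs then load artifacts through the z helper before the Spark repl starts. A sketch of the equivalent notebook usage, reusing the coordinates from the test:

  %dep
  z.reset()
  // resolved from Maven central
  z.load("org.apache.commons:commons-csv:1.1")
  // resolved via the additional spark-packages repository
  z.load("amplab:spark-indexedrdd:0.3")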
http://git-wip-us.apache.org/repos/asf/zeppelin/blob/d762b528/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
----------------------------------------------------------------------
diff --git a/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
new file mode 100644
index 0000000..765237c
--- /dev/null
+++ b/spark/interpreter/src/test/java/org/apache/zeppelin/spark/IPySparkInterpreterTest.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.zeppelin.spark;
+
+
+import com.google.common.io.Files;
+import org.apache.zeppelin.display.GUI;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterResult;
+import org.apache.zeppelin.interpreter.InterpreterResultMessage;
+import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
+import org.apache.zeppelin.python.IPythonInterpreterTest;
+import org.apache.zeppelin.user.AuthenticationInfo;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.net.URL;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Properties;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class IPySparkInterpreterTest {
+
+ private IPySparkInterpreter iPySparkInterpreter;
+ private InterpreterGroup intpGroup;
+
+ @Before
+ public void setup() throws InterpreterException {
+ Properties p = new Properties();
+ p.setProperty("spark.master", "local[4]");
+ p.setProperty("master", "local[4]");
+ p.setProperty("spark.submit.deployMode", "client");
+ p.setProperty("spark.app.name", "Zeppelin Test");
+ p.setProperty("zeppelin.spark.useHiveContext", "true");
+ p.setProperty("zeppelin.spark.maxResult", "1000");
+ p.setProperty("zeppelin.spark.importImplicit", "true");
+ p.setProperty("zeppelin.pyspark.python", "python");
+ p.setProperty("zeppelin.dep.localrepo", Files.createTempDir().getAbsolutePath());
+
+ intpGroup = new InterpreterGroup();
+ intpGroup.put("session_1", new LinkedList<Interpreter>());
+
+ SparkInterpreter sparkInterpreter = new SparkInterpreter(p);
+ intpGroup.get("session_1").add(sparkInterpreter);
+ sparkInterpreter.setInterpreterGroup(intpGroup);
+ sparkInterpreter.open();
+
+ iPySparkInterpreter = new IPySparkInterpreter(p);
+ intpGroup.get("session_1").add(iPySparkInterpreter);
+ iPySparkInterpreter.setInterpreterGroup(intpGroup);
+ iPySparkInterpreter.open();
+ }
+
+
+ @After
+ public void tearDown() throws InterpreterException {
+ if (iPySparkInterpreter != null) {
+ iPySparkInterpreter.close();
+ }
+ }
+
+ @Test
+ public void testBasics() throws InterruptedException, IOException, InterpreterException {
+ // all the ipython test should pass too.
+ IPythonInterpreterTest.testInterpreter(iPySparkInterpreter);
+
+ // rdd
+ InterpreterContext context = getInterpreterContext();
+ InterpreterResult result = iPySparkInterpreter.interpret("sc.range(1,10).sum()", context);
+ Thread.sleep(100);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ List<InterpreterResultMessage> interpreterResultMessages = context.out.getInterpreterResultMessages();
+ assertEquals("45", interpreterResultMessages.get(0).getData());
+
+ context = getInterpreterContext();
+ result = iPySparkInterpreter.interpret("sc.version", context);
+ Thread.sleep(100);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ interpreterResultMessages = context.out.getInterpreterResultMessages();
+ // spark sql
+ context = getInterpreterContext();
+ if (interpreterResultMessages.get(0).getData().startsWith("'1.") ||
+ interpreterResultMessages.get(0).getData().startsWith("u'1.")) {
+ result = iPySparkInterpreter.interpret("df = sqlContext.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ interpreterResultMessages = context.out.getInterpreterResultMessages();
+ assertEquals(
+ "+---+---+\n" +
+ "| _1| _2|\n" +
+ "+---+---+\n" +
+ "| 1| a|\n" +
+ "| 2| b|\n" +
+ "+---+---+\n\n", interpreterResultMessages.get(0).getData());
+ } else {
+ result = iPySparkInterpreter.interpret("df = spark.createDataFrame([(1,'a'),(2,'b')])\ndf.show()", context);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ interpreterResultMessages = context.out.getInterpreterResultMessages();
+ assertEquals(
+ "+---+---+\n" +
+ "| _1| _2|\n" +
+ "+---+---+\n" +
+ "| 1| a|\n" +
+ "| 2| b|\n" +
+ "+---+---+\n\n", interpreterResultMessages.get(0).getData());
+ }
+
+ // cancel
+ final InterpreterContext context2 = getInterpreterContext();
+
+ Thread thread = new Thread() {
+ @Override
+ public void run() {
+ InterpreterResult result = iPySparkInterpreter.interpret("import time\nsc.range(1,10).foreach(lambda x: time.sleep(1))", context2);
+ assertEquals(InterpreterResult.Code.ERROR, result.code());
+ List<InterpreterResultMessage> interpreterResultMessages = null;
+ try {
+ interpreterResultMessages = context2.out.getInterpreterResultMessages();
+ assertTrue(interpreterResultMessages.get(0).getData().contains("KeyboardInterrupt"));
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ };
+ thread.start();
+
+ // sleep 1 second to wait for the spark job to start
+ Thread.sleep(1000);
+ iPySparkInterpreter.cancel(context2);
+ thread.join();
+
+ // completions
+ List<InterpreterCompletion> completions = iPySparkInterpreter.completion("sc.ran", 6, getInterpreterContext());
+ assertEquals(1, completions.size());
+ assertEquals("range", completions.get(0).getValue());
+
+ // pyspark streaming
+
+ Class<?> klass = py4j.GatewayServer.class;
+ URL location = klass.getResource('/' + klass.getName().replace('.', '/') + ".class");
+ System.out.println("py4j location: " + location);
+ context = getInterpreterContext();
+ result = iPySparkInterpreter.interpret(
+ "from pyspark.streaming import StreamingContext\n" +
+ "import time\n" +
+ "ssc = StreamingContext(sc, 1)\n" +
+ "rddQueue = []\n" +
+ "for i in range(5):\n" +
+ " rddQueue += [ssc.sparkContext.parallelize([j for j in range(1, 1001)], 10)]\n" +
+ "inputStream = ssc.queueStream(rddQueue)\n" +
+ "mappedStream = inputStream.map(lambda x: (x % 10, 1))\n" +
+ "reducedStream = mappedStream.reduceByKey(lambda a, b: a + b)\n" +
+ "reducedStream.pprint()\n" +
+ "ssc.start()\n" +
+ "time.sleep(6)\n" +
+ "ssc.stop(stopSparkContext=False, stopGraceFully=True)", context);
+ Thread.sleep(1000);
+ assertEquals(InterpreterResult.Code.SUCCESS, result.code());
+ interpreterResultMessages = context.out.getInterpreterResultMessages();
+ assertEquals(1, interpreterResultMessages.size());
+// assertTrue(interpreterResultMessages.get(0).getData().contains("(0, 100)"));
+ }
+
+ private InterpreterContext getInterpreterContext() {
+ return new InterpreterContext(
+ "noteId",
+ "paragraphId",
+ "replName",
+ "paragraphTitle",
+ "paragraphText",
+ new AuthenticationInfo(),
+ new HashMap<String, Object>(),
+ new GUI(),
+ new GUI(),
+ null,
+ null,
+ null,
+ new InterpreterOutput(null));
+ }
+}