Posted to commits@spark.apache.org by do...@apache.org on 2022/01/07 17:03:15 UTC
[spark] branch master updated: [SPARK-37837][INFRA] Enable black formatter in dev Python scripts
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ead131f [SPARK-37837][INFRA] Enable black formatter in dev Python scripts
ead131f is described below
commit ead131fc6387ca510996e561d69e9fcc86067158
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Fri Jan 7 09:01:22 2022 -0800
[SPARK-37837][INFRA] Enable black formatter in dev Python scripts
### What changes were proposed in this pull request?
This PR proposes to enable the [black](https://github.com/psf/black) formatter (an automatic Python code formatter) for the `dev` directory as well.
### Why are the changes needed?
To keep a consistent style and make for a better development cycle.
### Does this PR introduce _any_ user-facing change?
No, dev-only.
### How was this patch tested?
I manually verified it as follows:
```bash
dev/reformat-python
dev/linter-python
```
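For reference, both scripts now pass the `dev` directory to black in addition to `python/pyspark` (see the `dev/reformat-python` and `dev/lint-python` hunks below). Running black directly would look roughly like this, a minimal sketch assuming black is installed in the local environment:
```bash
# Reformat in place, mirroring what dev/reformat-python invokes:
black --config dev/pyproject.toml python/pyspark dev

# Check formatting without modifying files, as dev/lint-python's black_test does:
black --config dev/pyproject.toml --check python/pyspark dev
```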
Closes #35127 from HyukjinKwon/SPARK-37837.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
.../files/util_scripts/kill_zinc_nailgun.py | 25 ++-
.../files/util_scripts/post_github_pr_comment.py | 45 +++--
.../files/util_scripts/session_lock_resource.py | 29 +--
dev/create-release/generate-contributors.py | 49 ++++--
dev/create-release/releaseutils.py | 29 +--
dev/create-release/translate-contributors.py | 46 +++--
dev/github_jira_sync.py | 57 +++---
dev/is-changed.py | 26 +--
dev/lint-python | 2 +-
dev/merge_spark_pr.py | 153 +++++++++-------
dev/pip-sanity-check.py | 7 +-
dev/reformat-python | 2 +-
dev/run-tests-jenkins.py | 131 +++++++-------
dev/run-tests.py | 195 +++++++++++++--------
dev/sparktestsupport/__init__.py | 2 +-
dev/sparktestsupport/modules.py | 164 ++++++++---------
dev/sparktestsupport/shellutils.py | 6 +-
dev/sparktestsupport/toposort.py | 26 +--
dev/sparktestsupport/utils.py | 16 +-
19 files changed, 561 insertions(+), 449 deletions(-)
diff --git a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/kill_zinc_nailgun.py b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/kill_zinc_nailgun.py
index 40887e8..3b605c9 100755
--- a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/kill_zinc_nailgun.py
+++ b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/kill_zinc_nailgun.py
@@ -12,14 +12,19 @@ def _parse_args():
zinc_port_var = "ZINC_PORT"
zinc_port_option = "--zinc-port"
parser = argparse.ArgumentParser()
- parser.add_argument(zinc_port_option,
- type=int,
- default=int(os.environ.get(zinc_port_var, "0")),
- help="Specify zinc port")
+ parser.add_argument(
+ zinc_port_option,
+ type=int,
+ default=int(os.environ.get(zinc_port_var, "0")),
+ help="Specify zinc port",
+ )
args = parser.parse_args()
if not args.zinc_port:
- parser.error("Specify either environment variable {0} or option {1}".format(
- zinc_port_var, zinc_port_option))
+ parser.error(
+ "Specify either environment variable {0} or option {1}".format(
+ zinc_port_var, zinc_port_option
+ )
+ )
return args
@@ -36,9 +41,11 @@ def _yield_processes_listening_on_port(port):
innocuous_errors = re.compile(
r"^\s*Output information may be incomplete.\s*$"
r"|^lsof: WARNING: can't stat\(\) (?:tracefs|nsfs|overlay|tmpfs|aufs|zfs) file system .*$"
- r"|^\s*$")
- lsof_process = subprocess.Popen(["lsof", "-P"], stdout=subprocess.PIPE,
- stderr=subprocess.PIPE, universal_newlines=True)
+ r"|^\s*$"
+ )
+ lsof_process = subprocess.Popen(
+ ["lsof", "-P"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True
+ )
stdout, stderr = lsof_process.communicate()
if lsof_process.returncode != 0:
raise OSError("Can't run lsof -P, stderr:\n{}".format(stderr))
diff --git a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/post_github_pr_comment.py b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/post_github_pr_comment.py
index 68e31d4..d55295d 100755
--- a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/post_github_pr_comment.py
+++ b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/post_github_pr_comment.py
@@ -15,19 +15,30 @@ def _parse_args():
github_oauth_key_var = "GITHUB_OAUTH_KEY"
github_oauth_key_option = "--github-oauth-key"
parser = argparse.ArgumentParser()
- parser.add_argument("-pr", pr_link_option,
- default=os.environ.get(pr_link_var, ""),
- help="Specify pull request link")
- parser.add_argument(github_oauth_key_option,
- default=os.environ.get(github_oauth_key_var, ""),
- help="Specify github oauth key")
+ parser.add_argument(
+ "-pr",
+ pr_link_option,
+ default=os.environ.get(pr_link_var, ""),
+ help="Specify pull request link",
+ )
+ parser.add_argument(
+ github_oauth_key_option,
+ default=os.environ.get(github_oauth_key_var, ""),
+ help="Specify github oauth key",
+ )
args = parser.parse_args()
if not args.pr_link:
- parser.error("Specify either environment variable {} or option {}".format(
- pr_link_var, pr_link_option))
+ parser.error(
+ "Specify either environment variable {} or option {}".format(
+ pr_link_var, pr_link_option
+ )
+ )
if not args.github_oauth_key:
- parser.error("Specify either environment variable {} or option {}".format(
- github_oauth_key_var, github_oauth_key_option))
+ parser.error(
+ "Specify either environment variable {} or option {}".format(
+ github_oauth_key_var, github_oauth_key_option
+ )
+ )
return args
@@ -39,12 +50,14 @@ def post_message_to_github(msg, github_oauth_key, pr_link):
url = api_url + "/issues/" + ghprb_pull_id + "/comments"
posted_message = json.dumps({"body": msg})
- request = Request(url,
- headers={
- "Authorization": "token {}".format(github_oauth_key),
- "Content-Type": "application/json"
- },
- data=posted_message.encode('utf-8'))
+ request = Request(
+ url,
+ headers={
+ "Authorization": "token {}".format(github_oauth_key),
+ "Content-Type": "application/json",
+ },
+ data=posted_message.encode("utf-8"),
+ )
try:
response = urlopen(request)
diff --git a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/session_lock_resource.py b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/session_lock_resource.py
index f5153d5..dedf538 100755
--- a/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/session_lock_resource.py
+++ b/dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/session_lock_resource.py
@@ -24,10 +24,12 @@ _LOCK_DIR = "/tmp/session_locked_resources"
def _parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser()
- parser.add_argument("-t", "--timeout-secs", type=int,
- help="How long to wait for lock acquisition, in seconds")
- parser.add_argument("-p", "--pid", type=int,
- help="PID to wait for exit (defaults to parent pid)")
+ parser.add_argument(
+ "-t", "--timeout-secs", type=int, help="How long to wait for lock acquisition, in seconds"
+ )
+ parser.add_argument(
+ "-p", "--pid", type=int, help="PID to wait for exit (defaults to parent pid)"
+ )
parser.add_argument("resource", help="Resource to lock")
return parser.parse_args()
@@ -63,13 +65,13 @@ def _daemonize(child_body):
argument, which is a function that is called with a boolean
that indicates that the child succeeded/failed in its initialization
"""
- CHILD_FAIL = '\2'
- CHILD_SUCCESS = '\0'
+ CHILD_FAIL = "\2"
+ CHILD_SUCCESS = "\0"
r_fd, w_fd = os.pipe()
if os.fork() != 0:
# We are the original script. Read success/fail from the final
# child and log an error message if needed.
- child_code = os.read(r_fd, 1) # .decode('utf-8')
+ child_code = os.read(r_fd, 1) # .decode('utf-8')
return child_code == CHILD_SUCCESS
# First child
os.setsid()
@@ -82,6 +84,7 @@ def _daemonize(child_body):
def _write_to_parent(success):
parent_message = CHILD_SUCCESS if success else CHILD_FAIL
os.write(w_fd, parent_message)
+
child_body(_write_to_parent)
os._exit(0)
@@ -114,15 +117,15 @@ def _is_pid_running(pid):
return True
-def _lock_and_wait(lock_success_callback, resource, timeout_secs,
- controlling_pid):
+def _lock_and_wait(lock_success_callback, resource, timeout_secs, controlling_pid):
"""Attempt to lock the file then wait.
lock_success_callback will be called if the locking worked.
"""
lock_filename = os.path.join(_LOCK_DIR, resource)
- lock_message = ("Session lock on " + resource
- + ", controlling pid " + str(controlling_pid) + "\n")
+ lock_message = (
+ "Session lock on " + resource + ", controlling pid " + str(controlling_pid) + "\n"
+ )
try:
f = _acquire_lock(lock_filename, timeout_secs, lock_message)
except IOError:
@@ -139,8 +142,8 @@ def main():
os.mkdir(_LOCK_DIR)
controlling_pid = args.pid or os.getppid()
child_body_func = lambda success_callback: _lock_and_wait(
- success_callback, args.resource, args.timeout_secs,
- controlling_pid)
+ success_callback, args.resource, args.timeout_secs, controlling_pid
+ )
if _daemonize(child_body_func):
return 0
else:
diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py
index ae9629a..bb63515 100755
--- a/dev/create-release/generate-contributors.py
+++ b/dev/create-release/generate-contributors.py
@@ -22,9 +22,21 @@ import os
import re
import sys
-from releaseutils import tag_exists, get_commits, yesOrNoPrompt, get_date, \
- is_valid_author, capitalize_author, JIRA, find_components, translate_issue_type, \
- translate_component, CORE_COMPONENT, contributors_file_name, nice_join
+from releaseutils import (
+ tag_exists,
+ get_commits,
+ yesOrNoPrompt,
+ get_date,
+ is_valid_author,
+ capitalize_author,
+ JIRA,
+ find_components,
+ translate_issue_type,
+ translate_component,
+ CORE_COMPONENT,
+ contributors_file_name,
+ nice_join,
+)
# You must set the following before use!
JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
@@ -37,7 +49,8 @@ while not tag_exists(RELEASE_TAG):
while not tag_exists(PREVIOUS_RELEASE_TAG):
print("Please specify the previous release tag.")
PREVIOUS_RELEASE_TAG = input(
- "For instance, if you are releasing v1.2.0, you should specify v1.1.0: ")
+ "For instance, if you are releasing v1.2.0, you should specify v1.1.0: "
+ )
# Gather commits found in the new tag but not in the old tag.
# This filters commits based on both the git hash and the PR number.
@@ -92,15 +105,16 @@ filtered_commits = []
def is_release(commit_title):
- return ("[release]" in commit_title.lower() or
- "preparing spark release" in commit_title.lower() or
- "preparing development version" in commit_title.lower() or
- "CHANGES.txt" in commit_title)
+ return (
+ "[release]" in commit_title.lower()
+ or "preparing spark release" in commit_title.lower()
+ or "preparing development version" in commit_title.lower()
+ or "CHANGES.txt" in commit_title
+ )
def is_maintenance(commit_title):
- return "maintenance" in commit_title.lower() or \
- "manually close" in commit_title.lower()
+ return "maintenance" in commit_title.lower() or "manually close" in commit_title.lower()
def has_no_jira(commit_title):
@@ -112,8 +126,7 @@ def is_revert(commit_title):
def is_docs(commit_title):
- return re.findall("docs*", commit_title.lower()) or \
- "programming guide" in commit_title.lower()
+ return re.findall("docs*", commit_title.lower()) or "programming guide" in commit_title.lower()
for c in new_commits:
@@ -215,14 +228,16 @@ for commit in filtered_commits:
author_info[author][issue_type] = set()
for component in components:
author_info[author][issue_type].add(component)
+
# Find issues and components associated with this commit
for issue in issues:
try:
jira_issue = jira_client.issue(issue)
jira_type = jira_issue.fields.issuetype.name
jira_type = translate_issue_type(jira_type, issue, warnings)
- jira_components = [translate_component(c.name, _hash, warnings)
- for c in jira_issue.fields.components]
+ jira_components = [
+ translate_component(c.name, _hash, warnings) for c in jira_issue.fields.components
+ ]
all_components = set(jira_components + commit_components)
populate(jira_type, all_components)
except Exception as e:
@@ -254,8 +269,10 @@ for author in authors:
# Otherwise, group contributions by issue types instead of modules
# e.g. Bug fixes in MLlib, Core, and Streaming; documentation in YARN
else:
- contributions = ["%s in %s" % (issue_type, nice_join(comps))
- for issue_type, comps in author_info[author].items()]
+ contributions = [
+ "%s in %s" % (issue_type, nice_join(comps))
+ for issue_type, comps in author_info[author].items()
+ ]
contribution = "; ".join(contributions)
# Do not use python's capitalize() on the whole string to preserve case
assert contribution
diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
index 49f0921..26abab9 100755
--- a/dev/create-release/releaseutils.py
+++ b/dev/create-release/releaseutils.py
@@ -24,6 +24,7 @@ from subprocess import Popen, PIPE
try:
from jira.client import JIRA # noqa: F401
+
# Old versions have JIRAError in exceptions package, new (0.5+) in utils.
try:
from jira.exceptions import JIRAError
@@ -111,11 +112,16 @@ def get_commits(tag):
commit_start_marker = "|=== COMMIT START MARKER ===|"
commit_end_marker = "|=== COMMIT END MARKER ===|"
field_end_marker = "|=== COMMIT FIELD END MARKER ===|"
- log_format =\
- commit_start_marker + "%h" +\
- field_end_marker + "%an" +\
- field_end_marker + "%s" +\
- commit_end_marker + "%b"
+ log_format = (
+ commit_start_marker
+ + "%h"
+ + field_end_marker
+ + "%an"
+ + field_end_marker
+ + "%s"
+ + commit_end_marker
+ + "%b"
+ )
output = run_cmd(["git", "log", "--quiet", "--pretty=format:" + log_format, tag])
commits = []
raw_commits = [c for c in output.split(commit_start_marker) if c]
@@ -162,7 +168,7 @@ known_issue_types = {
"documentation": "documentation",
"test": "test",
"task": "improvement",
- "sub-task": "improvement"
+ "sub-task": "improvement",
}
# Maintain a mapping for translating component names when creating the release notes
@@ -193,7 +199,7 @@ known_components = {
"streaming": "Streaming",
"web ui": "Web UI",
"windows": "Windows",
- "yarn": "YARN"
+ "yarn": "YARN",
}
@@ -204,7 +210,7 @@ def translate_issue_type(issue_type, issue_id, warnings):
if issue_type in known_issue_types:
return known_issue_types[issue_type]
else:
- warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id))
+ warnings.append('Unknown issue type "%s" (see %s)' % (issue_type, issue_id))
return issue_type
@@ -215,7 +221,7 @@ def translate_component(component, commit_hash, warnings):
if component in known_components:
return known_components[component]
else:
- warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash))
+ warnings.append('Unknown component "%s" (see %s)' % (component, commit_hash))
return component
@@ -223,8 +229,9 @@ def translate_component(component, commit_hash, warnings):
# The returned components are already filtered and translated
def find_components(commit, commit_hash):
components = re.findall(r"\[\w*\]", commit.lower())
- components = [translate_component(c, commit_hash, [])
- for c in components if c in known_components]
+ components = [
+ translate_component(c, commit_hash, []) for c in components if c in known_components
+ ]
return components
diff --git a/dev/create-release/translate-contributors.py b/dev/create-release/translate-contributors.py
index 2ac6cd0..18a114c 100755
--- a/dev/create-release/translate-contributors.py
+++ b/dev/create-release/translate-contributors.py
@@ -31,8 +31,17 @@
import os
import sys
-from releaseutils import JIRA, JIRAError, get_jira_name, Github, get_github_name, \
- contributors_file_name, is_valid_author, capitalize_author, yesOrNoPrompt
+from releaseutils import (
+ JIRA,
+ JIRAError,
+ get_jira_name,
+ Github,
+ get_github_name,
+ contributors_file_name,
+ is_valid_author,
+ capitalize_author,
+ yesOrNoPrompt,
+)
# You must set the following before use!
JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
@@ -126,10 +135,12 @@ def generate_candidates(author, issues):
display_name = jira_assignee.displayName
if display_name:
candidates.append(
- (display_name, "Full name of %s assignee %s" % (issue, user_name)))
+ (display_name, "Full name of %s assignee %s" % (issue, user_name))
+ )
else:
candidates.append(
- (NOT_FOUND, "No full name found for %s assignee %s" % (issue, user_name)))
+ (NOT_FOUND, "No full name found for %s assignee %s" % (issue, user_name))
+ )
else:
candidates.append((NOT_FOUND, "No assignee found for %s" % issue))
for i, (candidate, source) in enumerate(candidates):
@@ -152,7 +163,7 @@ for i, line in enumerate(lines):
temp_author = line.strip(" * ").split(" -- ")[0].strip()
print("Processing author %s (%d/%d)" % (temp_author, i + 1, len(lines)))
if not temp_author:
- error_msg = " ERROR: Expected the following format \" * <author> -- <contributions>\"\n"
+ error_msg = ' ERROR: Expected the following format " * <author> -- <contributions>"\n'
error_msg += " ERROR: Actual = %s" % line
print(error_msg)
warnings.append(error_msg)
@@ -207,8 +218,9 @@ for i, line in enumerate(lines):
new_author = candidate_names[response]
# In non-interactive mode, just pick the first candidate
else:
- valid_candidate_names = [name for name, _ in candidates
- if is_valid_author(name) and name != NOT_FOUND]
+ valid_candidate_names = [
+ name for name, _ in candidates if is_valid_author(name) and name != NOT_FOUND
+ ]
if valid_candidate_names:
new_author = valid_candidate_names[0]
# Finally, capitalize the author and replace the original one with it
@@ -216,15 +228,17 @@ for i, line in enumerate(lines):
if is_valid_author(new_author):
new_author = capitalize_author(new_author)
else:
- warnings.append(
- "Unable to find a valid name %s for author %s" % (author, temp_author))
+ warnings.append("Unable to find a valid name %s for author %s" % (author, temp_author))
print(" * Replacing %s with %s" % (author, new_author))
# If we are in interactive mode, prompt the user whether we want to remember this new
# mapping
- if INTERACTIVE_MODE and \
- author not in known_translations and \
- yesOrNoPrompt(
- " Add mapping %s -> %s to known translations file?" % (author, new_author)):
+ if (
+ INTERACTIVE_MODE
+ and author not in known_translations
+ and yesOrNoPrompt(
+ " Add mapping %s -> %s to known translations file?" % (author, new_author)
+ )
+ ):
known_translations_file.write("%s - %s\n" % (author, new_author))
known_translations_file.flush()
line = line.replace(temp_author, author)
@@ -257,6 +271,8 @@ if warnings:
print("\n========== Warnings encountered while translating the contributor list ===========")
for w in warnings:
print(w)
- print("Please manually correct these in the final contributors list at %s." %
- new_contributors_file_name)
+ print(
+ "Please manually correct these in the final contributors list at %s."
+ % new_contributors_file_name
+ )
print("==================================================================================\n")
diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py
index f2753b6..3163f26 100755
--- a/dev/github_jira_sync.py
+++ b/dev/github_jira_sync.py
@@ -55,7 +55,7 @@ MAX_FILE = ".github-jira-max"
def get_url(url):
try:
request = Request(url)
- request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY)
+ request.add_header("Authorization", "token %s" % GITHUB_OAUTH_KEY)
return urlopen(request)
except HTTPError:
print("Unable to fetch URL, exiting: %s" % url)
@@ -77,7 +77,7 @@ def get_jira_prs():
page_json = get_json(page)
for pull in page_json:
- jira_issues = re.findall(JIRA_PROJECT_NAME + "-[0-9]{4,5}", pull['title'])
+ jira_issues = re.findall(JIRA_PROJECT_NAME + "-[0-9]{4,5}", pull["title"])
for jira_issue in jira_issues:
result = result + [(jira_issue, pull)]
@@ -91,7 +91,7 @@ def get_jira_prs():
def set_max_pr(max_val):
- f = open(MAX_FILE, 'w')
+ f = open(MAX_FILE, "w")
f.write("%s" % max_val)
f.close()
print("Writing largest PR number seen: %s" % max_val)
@@ -99,7 +99,7 @@ def set_max_pr(max_val):
def get_max_pr():
if os.path.exists(MAX_FILE):
- result = int(open(MAX_FILE, 'r').read())
+ result = int(open(MAX_FILE, "r").read())
print("Read largest PR number previously seen: %s" % result)
return result
else:
@@ -112,23 +112,23 @@ def build_pr_component_dic(jira_prs):
for issue, pr in jira_prs:
print(issue)
page = get_json(get_url(JIRA_API_BASE + "/rest/api/2/issue/" + issue))
- jira_components = [c['name'].upper() for c in page['fields']['components']]
- if pr['number'] in dic:
- dic[pr['number']][1].update(jira_components)
+ jira_components = [c["name"].upper() for c in page["fields"]["components"]]
+ if pr["number"] in dic:
+ dic[pr["number"]][1].update(jira_components)
else:
- pr_components = set(label['name'].upper() for label in pr['labels'])
- dic[pr['number']] = (pr_components, set(jira_components))
+ pr_components = set(label["name"].upper() for label in pr["labels"])
+ dic[pr["number"]] = (pr_components, set(jira_components))
return dic
def reset_pr_labels(pr_num, jira_components):
- url = '%s/issues/%s/labels' % (GITHUB_API_BASE, pr_num)
- labels = ', '.join(('"%s"' % c) for c in jira_components)
+ url = "%s/issues/%s/labels" % (GITHUB_API_BASE, pr_num)
+ labels = ", ".join(('"%s"' % c) for c in jira_components)
try:
- request = Request(url, data=('{"labels":[%s]}' % labels).encode('utf-8'))
- request.add_header('Content-Type', 'application/json')
- request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY)
- request.get_method = lambda: 'PUT'
+ request = Request(url, data=('{"labels":[%s]}' % labels).encode("utf-8"))
+ request.add_header("Content-Type", "application/json")
+ request.add_header("Authorization", "token %s" % GITHUB_OAUTH_KEY)
+ request.get_method = lambda: "PUT"
urlopen(request)
print("Set %s with labels %s" % (pr_num, labels))
except HTTPError:
@@ -136,31 +136,30 @@ def reset_pr_labels(pr_num, jira_components):
sys.exit(-1)
-jira_client = jira.client.JIRA({'server': JIRA_API_BASE},
- basic_auth=(JIRA_USERNAME, JIRA_PASSWORD))
+jira_client = jira.client.JIRA({"server": JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD))
jira_prs = get_jira_prs()
previous_max = get_max_pr()
print("Retrieved %s JIRA PR's from GitHub" % len(jira_prs))
-jira_prs = [(k, v) for k, v in jira_prs if int(v['number']) > previous_max]
+jira_prs = [(k, v) for k, v in jira_prs if int(v["number"]) > previous_max]
print("%s PR's remain after excluding visited ones" % len(jira_prs))
num_updates = 0
considered = []
-for issue, pr in sorted(jira_prs, key=lambda kv: int(kv[1]['number'])):
+for issue, pr in sorted(jira_prs, key=lambda kv: int(kv[1]["number"])):
if num_updates >= MAX_UPDATES:
break
- pr_num = int(pr['number'])
+ pr_num = int(pr["number"])
print("Checking issue %s" % issue)
considered = considered + [pr_num]
- url = pr['html_url']
- title = "[GitHub] Pull Request #%s (%s)" % (pr['number'], pr['user']['login'])
+ url = pr["html_url"]
+ title = "[GitHub] Pull Request #%s (%s)" % (pr["number"], pr["user"]["login"])
try:
page = get_json(get_url(JIRA_API_BASE + "/rest/api/2/issue/" + issue + "/remotelink"))
- existing_links = map(lambda l: l['object']['url'], page)
+ existing_links = map(lambda l: l["object"]["url"], page)
except BaseException:
print("Failure reading JIRA %s (does it exist?)" % issue)
print(sys.exc_info()[0])
@@ -169,20 +168,22 @@ for issue, pr in sorted(jira_prs, key=lambda kv: int(kv[1]['number'])):
if url in existing_links:
continue
- icon = {"title": "Pull request #%s" % pr['number'],
- "url16x16": "https://assets-cdn.github.com/favicon.ico"}
+ icon = {
+ "title": "Pull request #%s" % pr["number"],
+ "url16x16": "https://assets-cdn.github.com/favicon.ico",
+ }
destination = {"title": title, "url": url, "icon": icon}
# For all possible fields see:
# https://developer.atlassian.com/display/JIRADEV/Fields+in+Remote+Issue+Links
# application = {"name": "GitHub pull requests", "type": "org.apache.spark.jira.github"}
jira_client.add_remote_link(issue, destination)
- comment = "User '%s' has created a pull request for this issue:" % pr['user']['login']
- comment += "\n%s" % pr['html_url']
+ comment = "User '%s' has created a pull request for this issue:" % pr["user"]["login"]
+ comment += "\n%s" % pr["html_url"]
if pr_num >= MIN_COMMENT_PR:
jira_client.add_comment(issue, comment)
- print("Added link %s <-> PR #%s" % (issue, pr['number']))
+ print("Added link %s <-> PR #%s" % (issue, pr["number"]))
num_updates += 1
if len(considered) > 0:
diff --git a/dev/is-changed.py b/dev/is-changed.py
index 5491df6..e2f50be 100755
--- a/dev/is-changed.py
+++ b/dev/is-changed.py
@@ -23,32 +23,30 @@ from argparse import ArgumentParser
from sparktestsupport.utils import (
determine_modules_for_files,
determine_modules_to_test,
- identify_changed_files_from_git_commits
+ identify_changed_files_from_git_commits,
)
import sparktestsupport.modules as modules
def parse_opts():
- parser = ArgumentParser(
- prog="is-changed"
- )
+ parser = ArgumentParser(prog="is-changed")
parser.add_argument(
- "-f", "--fail", action='store_true',
- help="Exit with 1 if there is no relevant change."
+ "-f", "--fail", action="store_true", help="Exit with 1 if there is no relevant change."
)
default_value = ",".join(sorted([m.name for m in modules.all_modules]))
parser.add_argument(
- "-m", "--modules", type=str,
+ "-m",
+ "--modules",
+ type=str,
default=default_value,
- help="A comma-separated list of modules to test "
- "(default: %s)" % default_value
+ help="A comma-separated list of modules to test " "(default: %s)" % default_value,
)
args, unknown = parser.parse_known_args()
if unknown:
- parser.error("Unsupported arguments: %s" % ' '.join(unknown))
+ parser.error("Unsupported arguments: %s" % " ".join(unknown))
return args
@@ -57,15 +55,17 @@ def main():
test_modules = opts.modules.split(",")
changed_files = identify_changed_files_from_git_commits(
- "HEAD", target_ref=os.environ["APACHE_SPARK_REF"])
+ "HEAD", target_ref=os.environ["APACHE_SPARK_REF"]
+ )
changed_modules = determine_modules_to_test(
- determine_modules_for_files(changed_files), deduplicated=False)
+ determine_modules_for_files(changed_files), deduplicated=False
+ )
module_names = [m.name for m in changed_modules]
if len(changed_modules) == 0:
print("false")
if opts.fail:
sys.exit(1)
- elif 'root' in test_modules or modules.root in changed_modules:
+ elif "root" in test_modules or modules.root in changed_modules:
print("true")
elif len(set(test_modules).intersection(module_names)) == 0:
print("false")
diff --git a/dev/lint-python b/dev/lint-python
index e60ba7b..c40198e 100755
--- a/dev/lint-python
+++ b/dev/lint-python
@@ -210,7 +210,7 @@ function black_test {
fi
echo "starting black test..."
- BLACK_REPORT=$( ($BLACK_BUILD --config dev/pyproject.toml --check python/pyspark) 2>&1)
+ BLACK_REPORT=$( ($BLACK_BUILD --config dev/pyproject.toml --check python/pyspark dev) 2>&1)
BLACK_STATUS=$?
if [ "$BLACK_STATUS" -ne 0 ]; then
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 284690b..8d09c53 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -37,6 +37,7 @@ from urllib.error import HTTPError
try:
import jira.client
+
JIRA_IMPORTED = True
except ImportError:
JIRA_IMPORTED = False
@@ -70,13 +71,15 @@ def get_json(url):
try:
request = Request(url)
if GITHUB_OAUTH_KEY:
- request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY)
+ request.add_header("Authorization", "token %s" % GITHUB_OAUTH_KEY)
return json.load(urlopen(request))
except HTTPError as e:
- if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0':
- print("Exceeded the GitHub API rate limit; see the instructions in " +
- "dev/merge_spark_pr.py to configure an OAuth token for making authenticated " +
- "GitHub requests.")
+ if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == "0":
+ print(
+ "Exceeded the GitHub API rate limit; see the instructions in "
+ + "dev/merge_spark_pr.py to configure an OAuth token for making authenticated "
+ + "GitHub requests."
+ )
else:
print("Unable to fetch URL, exiting: %s" % url)
sys.exit(-1)
@@ -91,9 +94,9 @@ def fail(msg):
def run_cmd(cmd):
print(cmd)
if isinstance(cmd, list):
- return subprocess.check_output(cmd).decode('utf-8')
+ return subprocess.check_output(cmd).decode("utf-8")
else:
- return subprocess.check_output(cmd.split(" ")).decode('utf-8')
+ return subprocess.check_output(cmd.split(" ")).decode("utf-8")
def continue_maybe(prompt):
@@ -103,7 +106,7 @@ def continue_maybe(prompt):
def clean_up():
- if 'original_head' in globals():
+ if "original_head" in globals():
print("Restoring head pointer to %s" % original_head)
run_cmd("git checkout %s" % original_head)
@@ -124,7 +127,7 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
had_conflicts = False
try:
- run_cmd(['git', 'merge', pr_branch_name, '--squash'])
+ run_cmd(["git", "merge", pr_branch_name, "--squash"])
except Exception as e:
msg = "Error merging: %s\nWould you like to manually fix-up this merge?" % e
continue_maybe(msg)
@@ -132,13 +135,15 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
continue_maybe(msg)
had_conflicts = True
- commit_authors = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name,
- '--pretty=format:%an <%ae>']).split("\n")
- distinct_authors = sorted(set(commit_authors),
- key=lambda x: commit_authors.count(x), reverse=True)
+ commit_authors = run_cmd(
+ ["git", "log", "HEAD..%s" % pr_branch_name, "--pretty=format:%an <%ae>"]
+ ).split("\n")
+ distinct_authors = sorted(
+ set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True
+ )
primary_author = input(
- "Enter primary author in the format of \"name <email>\" [%s]: " %
- distinct_authors[0])
+ 'Enter primary author in the format of "name <email>" [%s]: ' % distinct_authors[0]
+ )
if primary_author == "":
primary_author = distinct_authors[0]
else:
@@ -160,7 +165,9 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
if had_conflicts:
message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % (
- committer_name, committer_email)
+ committer_name,
+ committer_email,
+ )
merge_message_flags += ["-m", message]
# The string "Closes #%s" string is required for GitHub to correctly close the PR
@@ -174,13 +181,14 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc):
merge_message_flags += ["-m", authors]
- run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags)
+ run_cmd(["git", "commit", '--author="%s"' % primary_author] + merge_message_flags)
- continue_maybe("Merge complete (local ref %s). Push to %s?" % (
- target_branch_name, PUSH_REMOTE_NAME))
+ continue_maybe(
+ "Merge complete (local ref %s). Push to %s?" % (target_branch_name, PUSH_REMOTE_NAME)
+ )
try:
- run_cmd('git push %s %s:%s' % (PUSH_REMOTE_NAME, target_branch_name, target_ref))
+ run_cmd("git push %s %s:%s" % (PUSH_REMOTE_NAME, target_branch_name, target_ref))
except Exception as e:
clean_up()
fail("Exception while pushing: %s" % e)
@@ -210,11 +218,12 @@ def cherry_pick(pr_num, merge_hash, default_branch):
msg = "Okay, please fix any conflicts and finish the cherry-pick. Finished?"
continue_maybe(msg)
- continue_maybe("Pick complete (local ref %s). Push to %s?" % (
- pick_branch_name, PUSH_REMOTE_NAME))
+ continue_maybe(
+ "Pick complete (local ref %s). Push to %s?" % (pick_branch_name, PUSH_REMOTE_NAME)
+ )
try:
- run_cmd('git push %s %s:%s' % (PUSH_REMOTE_NAME, pick_branch_name, pick_ref))
+ run_cmd("git push %s %s:%s" % (PUSH_REMOTE_NAME, pick_branch_name, pick_ref))
except Exception as e:
clean_up()
fail("Exception while pushing: %s" % e)
@@ -237,8 +246,9 @@ def fix_version_from_branch(branch, versions):
def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
- asf_jira = jira.client.JIRA({'server': JIRA_API_BASE},
- basic_auth=(JIRA_USERNAME, JIRA_PASSWORD))
+ asf_jira = jira.client.JIRA(
+ {"server": JIRA_API_BASE}, basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)
+ )
jira_id = input("Enter a JIRA id [%s]: " % default_jira_id)
if jira_id == "":
@@ -263,17 +273,20 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
if cur_status == "Resolved" or cur_status == "Closed":
fail("JIRA issue %s already has status '%s'" % (jira_id, cur_status))
print("=== JIRA %s ===" % jira_id)
- print("summary\t\t%s\nassignee\t%s\nstatus\t\t%s\nurl\t\t%s/%s\n" %
- (cur_summary, cur_assignee, cur_status, JIRA_BASE, jira_id))
+ print(
+ "summary\t\t%s\nassignee\t%s\nstatus\t\t%s\nurl\t\t%s/%s\n"
+ % (cur_summary, cur_assignee, cur_status, JIRA_BASE, jira_id)
+ )
versions = asf_jira.project_versions("SPARK")
versions = sorted(versions, key=lambda x: x.name, reverse=True)
- versions = list(filter(lambda x: x.raw['released'] is False, versions))
+ versions = list(filter(lambda x: x.raw["released"] is False, versions))
# Consider only x.y.z versions
- versions = list(filter(lambda x: re.match(r'\d+\.\d+\.\d+', x.name), versions))
+ versions = list(filter(lambda x: re.match(r"\d+\.\d+\.\d+", x.name), versions))
- default_fix_versions = list(map(
- lambda x: fix_version_from_branch(x, versions).name, merge_branches))
+ default_fix_versions = list(
+ map(lambda x: fix_version_from_branch(x, versions).name, merge_branches)
+ )
for v in default_fix_versions:
# Handles the case where we have forked a release branch but not yet made the release.
# In this case, if the PR is committed to the master branch and the release branch, we
@@ -290,15 +303,18 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
while True:
try:
fix_versions = input(
- "Enter comma-separated fix version(s) [%s]: " % default_fix_versions)
+ "Enter comma-separated fix version(s) [%s]: " % default_fix_versions
+ )
if fix_versions == "":
fix_versions = default_fix_versions
fix_versions = fix_versions.replace(" ", "").split(",")
if set(fix_versions).issubset(available_versions):
break
else:
- print("Specified version(s) [%s] not found in the available versions, try "
- "again (or leave blank and fix manually)." % (", ".join(fix_versions)))
+ print(
+ "Specified version(s) [%s] not found in the available versions, try "
+ "again (or leave blank and fix manually)." % (", ".join(fix_versions))
+ )
except KeyboardInterrupt:
raise
except BaseException:
@@ -310,11 +326,15 @@ def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
jira_fix_versions = list(map(lambda v: get_version_json(v), fix_versions))
- resolve = list(filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id)))[0]
- resolution = list(filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions()))[0]
+ resolve = list(filter(lambda a: a["name"] == "Resolve Issue", asf_jira.transitions(jira_id)))[0]
+ resolution = list(filter(lambda r: r.raw["name"] == "Fixed", asf_jira.resolutions()))[0]
asf_jira.transition_issue(
- jira_id, resolve["id"], fixVersions=jira_fix_versions,
- comment=comment, resolution={'id': resolution.raw['id']})
+ jira_id,
+ resolve["id"],
+ fixVersions=jira_fix_versions,
+ comment=comment,
+ resolution={"id": resolution.raw["id"]},
+ )
print("Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions))
@@ -340,7 +360,8 @@ def choose_jira_assignee(issue, asf_jira):
annotations.append("Commentor")
print("[%d] %s (%s)" % (idx, author.displayName, ",".join(annotations)))
raw_assignee = input(
- "Enter number of user, or userid, to assign to (blank to leave unassigned):")
+ "Enter number of user, or userid, to assign to (blank to leave unassigned):"
+ )
if raw_assignee == "":
return None
else:
@@ -402,41 +423,41 @@ def standardize_jira_ref(text):
components = []
# If the string is compliant, no need to process any further
- if (re.search(r'^\[SPARK-[0-9]{3,6}\](\[[A-Z0-9_\s,]+\] )+\S+', text)):
+ if re.search(r"^\[SPARK-[0-9]{3,6}\](\[[A-Z0-9_\s,]+\] )+\S+", text):
return text
# Extract JIRA ref(s):
- pattern = re.compile(r'(SPARK[-\s]*[0-9]{3,6})+', re.IGNORECASE)
+ pattern = re.compile(r"(SPARK[-\s]*[0-9]{3,6})+", re.IGNORECASE)
for ref in pattern.findall(text):
# Add brackets, replace spaces with a dash, & convert to uppercase
- jira_refs.append('[' + re.sub(r'\s+', '-', ref.upper()) + ']')
- text = text.replace(ref, '')
+ jira_refs.append("[" + re.sub(r"\s+", "-", ref.upper()) + "]")
+ text = text.replace(ref, "")
# Extract spark component(s):
# Look for alphanumeric chars, spaces, dashes, periods, and/or commas
- pattern = re.compile(r'(\[[\w\s,.-]+\])', re.IGNORECASE)
+ pattern = re.compile(r"(\[[\w\s,.-]+\])", re.IGNORECASE)
for component in pattern.findall(text):
components.append(component.upper())
- text = text.replace(component, '')
+ text = text.replace(component, "")
# Cleanup any remaining symbols:
- pattern = re.compile(r'^\W+(.*)', re.IGNORECASE)
- if (pattern.search(text) is not None):
+ pattern = re.compile(r"^\W+(.*)", re.IGNORECASE)
+ if pattern.search(text) is not None:
text = pattern.search(text).groups()[0]
# Assemble full text (JIRA ref(s), module(s), remaining text)
- clean_text = ''.join(jira_refs).strip() + ''.join(components).strip() + " " + text.strip()
+ clean_text = "".join(jira_refs).strip() + "".join(components).strip() + " " + text.strip()
# Replace multiple spaces with a single space, e.g. if no jira refs and/or components were
# included
- clean_text = re.sub(r'\s+', ' ', clean_text.strip())
+ clean_text = re.sub(r"\s+", " ", clean_text.strip())
return clean_text
def get_current_ref():
ref = run_cmd("git rev-parse --abbrev-ref HEAD").strip()
- if ref == 'HEAD':
+ if ref == "HEAD":
# The current ref is a detached HEAD, so grab its SHA.
return run_cmd("git rev-parse HEAD").strip()
else:
@@ -454,7 +475,7 @@ def main():
continue_maybe("The env-vars JIRA_USERNAME and/or JIRA_PASSWORD are not set. Continue?")
branches = get_json("%s/branches" % GITHUB_API_BASE)
- branch_names = list(filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches]))
+ branch_names = list(filter(lambda x: x.startswith("branch-"), [x["name"] for x in branches]))
# Assumes branch names can be sorted lexicographically
latest_branch = sorted(branch_names, reverse=True)[0]
@@ -486,7 +507,7 @@ def main():
else:
title = pr["title"]
- modified_body = re.sub(re.compile(r'<!--[^>]*-->\n?', re.DOTALL), '', pr["body"]).lstrip()
+ modified_body = re.sub(re.compile(r"<!--[^>]*-->\n?", re.DOTALL), "", pr["body"]).lstrip()
if modified_body != pr["body"]:
print("=" * 80)
print(modified_body)
@@ -511,16 +532,19 @@ def main():
# Merged pull requests don't appear as merged in the GitHub API;
# Instead, they're closed by asfgit.
- merge_commits = \
- [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"]
+ merge_commits = [
+ e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"
+ ]
if merge_commits:
merge_hash = merge_commits[0]["commit_id"]
message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"]
print("Pull request %s has already been merged, assuming you want to backport" % pr_num)
- commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify',
- "%s^{commit}" % merge_hash]).strip() != ""
+ commit_is_downloaded = (
+ run_cmd(["git", "rev-parse", "--quiet", "--verify", "%s^{commit}" % merge_hash]).strip()
+ != ""
+ )
if not commit_is_downloaded:
fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num)
@@ -529,13 +553,14 @@ def main():
sys.exit(0)
if not bool(pr["mergeable"]):
- msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \
- "Continue? (experts only!)"
+ msg = (
+ "Pull request %s is not mergeable in its current form.\n" % pr_num
+ + "Continue? (experts only!)"
+ )
continue_maybe(msg)
print("\n=== Pull Request #%s ===" % pr_num)
- print("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" %
- (title, pr_repo_desc, target_ref, url))
+ print("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % (title, pr_repo_desc, target_ref, url))
continue_maybe("Proceed with merging pull request #%s?" % pr_num)
merged_refs = [target_ref]
@@ -549,8 +574,11 @@ def main():
if JIRA_IMPORTED:
if JIRA_USERNAME and JIRA_PASSWORD:
continue_maybe("Would you like to update an associated JIRA?")
- jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % \
- (pr_num, GITHUB_BASE, pr_num)
+ jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (
+ pr_num,
+ GITHUB_BASE,
+ pr_num,
+ )
resolve_jira_issues(title, merged_refs, jira_comment)
else:
print("JIRA_USERNAME and JIRA_PASSWORD not set")
@@ -562,6 +590,7 @@ def main():
if __name__ == "__main__":
import doctest
+
(failure_count, test_count) = doctest.testmod()
if failure_count:
sys.exit(-1)
diff --git a/dev/pip-sanity-check.py b/dev/pip-sanity-check.py
index 469e27b..fdb1107 100644
--- a/dev/pip-sanity-check.py
+++ b/dev/pip-sanity-check.py
@@ -19,14 +19,11 @@ from pyspark.sql import SparkSession
import sys
if __name__ == "__main__":
- spark = SparkSession\
- .builder\
- .appName("PipSanityCheck")\
- .getOrCreate()
+ spark = SparkSession.builder.appName("PipSanityCheck").getOrCreate()
sc = spark.sparkContext
rdd = sc.parallelize(range(100), 10)
value = rdd.reduce(lambda x, y: x + y)
- if (value != 4950):
+ if value != 4950:
print("Value {0} did not match expected value.".format(value), file=sys.stderr)
sys.exit(-1)
print("Successfully ran pip sanity check")
diff --git a/dev/reformat-python b/dev/reformat-python
index 03f7846..54e183c 100755
--- a/dev/reformat-python
+++ b/dev/reformat-python
@@ -28,4 +28,4 @@ if [ $? -ne 0 ]; then
exit 1
fi
-$BLACK_BUILD --config dev/pyproject.toml python/pyspark
+$BLACK_BUILD --config dev/pyproject.toml python/pyspark dev
diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py
index 1783210..93fbf1a 100755
--- a/dev/run-tests-jenkins.py
+++ b/dev/run-tests-jenkins.py
@@ -45,12 +45,14 @@ def post_message_to_github(msg, ghprb_pull_id):
github_oauth_key = os.environ["GITHUB_OAUTH_KEY"]
posted_message = json.dumps({"body": msg})
- request = Request(url,
- headers={
- "Authorization": "token %s" % github_oauth_key,
- "Content-Type": "application/json"
- },
- data=posted_message.encode('utf-8'))
+ request = Request(
+ url,
+ headers={
+ "Authorization": "token %s" % github_oauth_key,
+ "Content-Type": "application/json",
+ },
+ data=posted_message.encode("utf-8"),
+ )
try:
response = urlopen(request)
@@ -67,22 +69,20 @@ def post_message_to_github(msg, ghprb_pull_id):
print_err(" > data: %s" % posted_message)
-def pr_message(build_display_name,
- build_url,
- ghprb_pull_id,
- short_commit_hash,
- commit_url,
- msg,
- post_msg=''):
+def pr_message(
+ build_display_name, build_url, ghprb_pull_id, short_commit_hash, commit_url, msg, post_msg=""
+):
# align the arguments properly for string formatting
- str_args = (build_display_name,
- msg,
- build_url,
- ghprb_pull_id,
- short_commit_hash,
- commit_url,
- str(' ' + post_msg + '.') if post_msg else '.')
- return '**[Test build %s %s](%stestReport)** for PR %s at commit [`%s`](%s)%s' % str_args
+ str_args = (
+ build_display_name,
+ msg,
+ build_url,
+ ghprb_pull_id,
+ short_commit_hash,
+ commit_url,
+ str(" " + post_msg + ".") if post_msg else ".",
+ )
+ return "**[Test build %s %s](%stestReport)** for PR %s at commit [`%s`](%s)%s" % str_args
def run_pr_checks(pr_tests, ghprb_actual_commit, sha1):
@@ -92,16 +92,24 @@ def run_pr_checks(pr_tests, ghprb_actual_commit, sha1):
@return a list of messages to post back to GitHub
"""
# Ensure we save off the current HEAD to revert to
- current_pr_head = run_cmd(['git', 'rev-parse', 'HEAD'], return_output=True).strip()
+ current_pr_head = run_cmd(["git", "rev-parse", "HEAD"], return_output=True).strip()
pr_results = list()
for pr_test in pr_tests:
- test_name = pr_test + '.sh'
- pr_results.append(run_cmd(['bash', os.path.join(SPARK_HOME, 'dev', 'tests', test_name),
- ghprb_actual_commit, sha1],
- return_output=True).rstrip())
+ test_name = pr_test + ".sh"
+ pr_results.append(
+ run_cmd(
+ [
+ "bash",
+ os.path.join(SPARK_HOME, "dev", "tests", test_name),
+ ghprb_actual_commit,
+ sha1,
+ ],
+ return_output=True,
+ ).rstrip()
+ )
# Ensure, after each test, that we're back on the current PR
- run_cmd(['git', 'checkout', '-f', current_pr_head])
+ run_cmd(["git", "checkout", "-f", current_pr_head])
return pr_results
@@ -112,37 +120,38 @@ def run_tests(tests_timeout):
@return a tuple containing the test result code and the result note to post to GitHub
"""
- test_result_code = subprocess.Popen(['timeout',
- tests_timeout,
- os.path.join(SPARK_HOME, 'dev', 'run-tests')]).wait()
+ test_result_code = subprocess.Popen(
+ ["timeout", tests_timeout, os.path.join(SPARK_HOME, "dev", "run-tests")]
+ ).wait()
failure_note_by_errcode = {
# error to denote run-tests script failures:
- 1: 'executing the `dev/run-tests` script',
- ERROR_CODES["BLOCK_GENERAL"]: 'some tests',
- ERROR_CODES["BLOCK_RAT"]: 'RAT tests',
- ERROR_CODES["BLOCK_SCALA_STYLE"]: 'Scala style tests',
- ERROR_CODES["BLOCK_JAVA_STYLE"]: 'Java style tests',
- ERROR_CODES["BLOCK_PYTHON_STYLE"]: 'Python style tests',
- ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests',
- ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation',
- ERROR_CODES["BLOCK_BUILD"]: 'to build',
- ERROR_CODES["BLOCK_BUILD_TESTS"]: 'build dependency tests',
- ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests',
- ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests',
- ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests',
- ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: 'PySpark pip packaging tests',
- ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: 'SparkR unit tests',
- ERROR_CODES["BLOCK_TIMEOUT"]: 'from timeout after a configured wait of `%s`' % (
- tests_timeout)
+ 1: "executing the `dev/run-tests` script",
+ ERROR_CODES["BLOCK_GENERAL"]: "some tests",
+ ERROR_CODES["BLOCK_RAT"]: "RAT tests",
+ ERROR_CODES["BLOCK_SCALA_STYLE"]: "Scala style tests",
+ ERROR_CODES["BLOCK_JAVA_STYLE"]: "Java style tests",
+ ERROR_CODES["BLOCK_PYTHON_STYLE"]: "Python style tests",
+ ERROR_CODES["BLOCK_R_STYLE"]: "R style tests",
+ ERROR_CODES["BLOCK_DOCUMENTATION"]: "to generate documentation",
+ ERROR_CODES["BLOCK_BUILD"]: "to build",
+ ERROR_CODES["BLOCK_BUILD_TESTS"]: "build dependency tests",
+ ERROR_CODES["BLOCK_MIMA"]: "MiMa tests",
+ ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: "Spark unit tests",
+ ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: "PySpark unit tests",
+ ERROR_CODES["BLOCK_PYSPARK_PIP_TESTS"]: "PySpark pip packaging tests",
+ ERROR_CODES["BLOCK_SPARKR_UNIT_TESTS"]: "SparkR unit tests",
+ ERROR_CODES["BLOCK_TIMEOUT"]: "from timeout after a configured wait of `%s`"
+ % (tests_timeout),
}
if test_result_code == 0:
- test_result_note = ' * This patch passes all tests.'
+ test_result_note = " * This patch passes all tests."
else:
note = failure_note_by_errcode.get(
- test_result_code, "due to an unknown error code, %s" % test_result_code)
- test_result_note = ' * This patch **fails %s**.' % note
+ test_result_code, "due to an unknown error code, %s" % test_result_code
+ )
+ test_result_note = " * This patch **fails %s**." % note
return [test_result_code, test_result_note]
@@ -202,30 +211,24 @@ def main():
# hash, the second the GitHub SHA1 hash, and the final the current PR hash
# * and, lastly, return string output to be included in the pr message output that will
# be posted to GitHub
- pr_tests = [
- "pr_merge_ability",
- "pr_public_classes"
- ]
+ pr_tests = ["pr_merge_ability", "pr_public_classes"]
# `bind_message_base` returns a function to generate messages for GitHub posting
- github_message = functools.partial(pr_message,
- build_display_name,
- build_url,
- ghprb_pull_id,
- short_commit_hash,
- commit_url)
+ github_message = functools.partial(
+ pr_message, build_display_name, build_url, ghprb_pull_id, short_commit_hash, commit_url
+ )
# post start message
- post_message_to_github(github_message('has started'), ghprb_pull_id)
+ post_message_to_github(github_message("has started"), ghprb_pull_id)
pr_check_results = run_pr_checks(pr_tests, ghprb_actual_commit, sha1)
test_result_code, test_result_note = run_tests(tests_timeout)
# post end message
- result_message = github_message('has finished')
- result_message += '\n' + test_result_note + '\n'
- result_message += '\n'.join(pr_check_results)
+ result_message = github_message("has finished")
+ result_message += "\n" + test_result_note + "\n"
+ result_message += "\n".join(pr_check_results)
post_message_to_github(result_message, ghprb_pull_id)
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 266aafec..d943277 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -30,7 +30,7 @@ from sparktestsupport.utils import (
determine_modules_for_files,
determine_modules_to_test,
determine_tags_to_exclude,
- identify_changed_files_from_git_commits
+ identify_changed_files_from_git_commits,
)
import sparktestsupport.modules as modules
@@ -70,9 +70,9 @@ def determine_java_executable():
def set_title_and_block(title, err_block):
os.environ["CURRENT_BLOCK"] = str(ERROR_CODES[err_block])
- line_str = '=' * 72
+ line_str = "=" * 72
- print('')
+ print("")
print(line_str)
print(title)
print(line_str)
@@ -126,8 +126,10 @@ def build_spark_documentation():
bundle_bin = which("bundle")
if not bundle_bin:
- print("[error] Cannot find a version of `bundle` on the system; please",
- " install one with `gem install bundler` and retry to build documentation.")
+ print(
+ "[error] Cannot find a version of `bundle` on the system; please",
+ " install one with `gem install bundler` and retry to build documentation.",
+ )
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
else:
run_cmd([bundle_bin, "install"])
@@ -150,22 +152,20 @@ def exec_sbt(sbt_args=()):
sbt_cmd = [os.path.join(SPARK_HOME, "build", "sbt")] + sbt_args
- sbt_output_filter = re.compile(b"^.*[info].*Resolving" + b"|" +
- b"^.*[warn].*Merging" + b"|" +
- b"^.*[info].*Including")
+ sbt_output_filter = re.compile(
+ b"^.*[info].*Resolving" + b"|" + b"^.*[warn].*Merging" + b"|" + b"^.*[info].*Including"
+ )
# NOTE: echo "q" is needed because sbt on encountering a build file
# with failure (either resolution or compilation) prompts the user for
# input either q, r, etc to quit or retry. This echo is there to make it
# not block.
- echo_proc = subprocess.Popen(["echo", "\"q\n\""], stdout=subprocess.PIPE)
- sbt_proc = subprocess.Popen(sbt_cmd,
- stdin=echo_proc.stdout,
- stdout=subprocess.PIPE)
+ echo_proc = subprocess.Popen(["echo", '"q\n"'], stdout=subprocess.PIPE)
+ sbt_proc = subprocess.Popen(sbt_cmd, stdin=echo_proc.stdout, stdout=subprocess.PIPE)
echo_proc.wait()
- for line in iter(sbt_proc.stdout.readline, b''):
+ for line in iter(sbt_proc.stdout.readline, b""):
if not sbt_output_filter.match(line):
- print(line.decode('utf-8'), end='')
+ print(line.decode("utf-8"), end="")
retcode = sbt_proc.wait()
if retcode != 0:
@@ -188,8 +188,13 @@ def get_scala_profiles(scala_version):
if scala_version in sbt_maven_scala_profiles:
return sbt_maven_scala_profiles[scala_version]
else:
- print("[error] Could not find", scala_version, "in the list. Valid options",
- " are", sbt_maven_scala_profiles.keys())
+ print(
+ "[error] Could not find",
+ scala_version,
+ "in the list. Valid options",
+ " are",
+ sbt_maven_scala_profiles.keys(),
+ )
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
@@ -197,8 +202,7 @@ def switch_scala_version(scala_version):
"""
Switch the code base to use the given Scala version.
"""
- set_title_and_block(
- "Switch the Scala version to %s" % scala_version, "BLOCK_SCALA_VERSION")
+ set_title_and_block("Switch the Scala version to %s" % scala_version, "BLOCK_SCALA_VERSION")
assert scala_version is not None
ver_num = scala_version[-4:] # Simply extract. e.g.) 2.13 from scala2.13
@@ -220,8 +224,13 @@ def get_hadoop_profiles(hadoop_version):
if hadoop_version in sbt_maven_hadoop_profiles:
return sbt_maven_hadoop_profiles[hadoop_version]
else:
- print("[error] Could not find", hadoop_version, "in the list. Valid options",
- " are", sbt_maven_hadoop_profiles.keys())
+ print(
+ "[error] Could not find",
+ hadoop_version,
+ "in the list. Valid options",
+ " are",
+ sbt_maven_hadoop_profiles.keys(),
+ )
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
@@ -239,8 +248,10 @@ def build_spark_maven(extra_profiles):
def build_spark_sbt(extra_profiles):
# Enable all of the profiles for the build:
build_profiles = extra_profiles + modules.root.build_profile_flags
- sbt_goals = ["test:package", # Build test jars as some tests depend on them
- "streaming-kinesis-asl-assembly/assembly"]
+ sbt_goals = [
+ "test:package", # Build test jars as some tests depend on them
+ "streaming-kinesis-asl-assembly/assembly",
+ ]
profiles_and_goals = build_profiles + sbt_goals
print("[info] Building Spark using SBT with these arguments: ", " ".join(profiles_and_goals))
@@ -255,8 +266,10 @@ def build_spark_unidoc_sbt(extra_profiles):
sbt_goals = ["unidoc"]
profiles_and_goals = build_profiles + sbt_goals
- print("[info] Building Spark unidoc using SBT with these arguments: ",
- " ".join(profiles_and_goals))
+ print(
+ "[info] Building Spark unidoc using SBT with these arguments: ",
+ " ".join(profiles_and_goals),
+ )
exec_sbt(profiles_and_goals)
@@ -266,8 +279,10 @@ def build_spark_assembly_sbt(extra_profiles, checkstyle=False):
build_profiles = extra_profiles + modules.root.build_profile_flags
sbt_goals = ["assembly/package"]
profiles_and_goals = build_profiles + sbt_goals
- print("[info] Building Spark assembly using SBT with these arguments: ",
- " ".join(profiles_and_goals))
+ print(
+ "[info] Building Spark assembly using SBT with these arguments: ",
+ " ".join(profiles_and_goals),
+ )
exec_sbt(profiles_and_goals)
if checkstyle:
@@ -295,8 +310,10 @@ def detect_binary_inop_with_mima(extra_profiles):
build_profiles = extra_profiles + modules.root.build_profile_flags
set_title_and_block("Detecting binary incompatibilities with MiMa", "BLOCK_MIMA")
profiles = " ".join(build_profiles)
- print("[info] Detecting binary incompatibilities with MiMa using SBT with these profiles: ",
- profiles)
+ print(
+ "[info] Detecting binary incompatibilities with MiMa using SBT with these profiles: ",
+ profiles,
+ )
run_cmd([os.path.join(SPARK_HOME, "dev", "mima"), profiles])
@@ -305,8 +322,10 @@ def run_scala_tests_maven(test_profiles):
profiles_and_goals = test_profiles + mvn_test_goals
- print("[info] Running Spark tests using Maven with these arguments: ",
- " ".join(profiles_and_goals))
+ print(
+ "[info] Running Spark tests using Maven with these arguments: ",
+ " ".join(profiles_and_goals),
+ )
exec_maven(profiles_and_goals)
@@ -320,8 +339,9 @@ def run_scala_tests_sbt(test_modules, test_profiles):
profiles_and_goals = test_profiles + sbt_test_goals
- print("[info] Running Spark tests using SBT with these arguments: ",
- " ".join(profiles_and_goals))
+ print(
+ "[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)
+ )
exec_sbt(profiles_and_goals)
@@ -334,20 +354,21 @@ def run_scala_tests(build_tool, extra_profiles, test_modules, excluded_tags, inc
# Remove duplicates while keeping the test module order
test_modules = list(dict.fromkeys(test_modules))
- test_profiles = extra_profiles + \
- list(set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules)))
+ test_profiles = extra_profiles + list(
+ set(itertools.chain.from_iterable(m.build_profile_flags for m in test_modules))
+ )
if included_tags:
- test_profiles += ['-Dtest.include.tags=' + ",".join(included_tags)]
+ test_profiles += ["-Dtest.include.tags=" + ",".join(included_tags)]
if excluded_tags:
- test_profiles += ['-Dtest.exclude.tags=' + ",".join(excluded_tags)]
+ test_profiles += ["-Dtest.exclude.tags=" + ",".join(excluded_tags)]
# set up java11 env if this is a pull request build with 'test-java11' in the title
if "ghprbPullTitle" in os.environ:
if "test-java11" in os.environ["ghprbPullTitle"].lower():
os.environ["JAVA_HOME"] = "/usr/java/jdk-11.0.1"
os.environ["PATH"] = "%s/bin:%s" % (os.environ["JAVA_HOME"], os.environ["PATH"])
- test_profiles += ['-Djava.version=11']
+ test_profiles += ["-Djava.version=11"]
if build_tool == "maven":
run_scala_tests_maven(test_profiles)
@@ -368,7 +389,7 @@ def run_python_tests(test_modules, parallelism, with_coverage=False):
script = "run-tests"
command = [os.path.join(SPARK_HOME, "python", script)]
if test_modules != [modules.root]:
- command.append("--modules=%s" % ','.join(m.name for m in test_modules))
+ command.append("--modules=%s" % ",".join(m.name for m in test_modules))
command.append("--parallelism=%i" % parallelism)
run_cmd(command)
@@ -395,35 +416,42 @@ def run_sparkr_tests():
def parse_opts():
- parser = ArgumentParser(
- prog="run-tests"
- )
+ parser = ArgumentParser(prog="run-tests")
parser.add_argument(
- "-p", "--parallelism", type=int, default=8,
- help="The number of suites to test in parallel (default %(default)d)"
+ "-p",
+ "--parallelism",
+ type=int,
+ default=8,
+ help="The number of suites to test in parallel (default %(default)d)",
)
parser.add_argument(
- "-m", "--modules", type=str,
+ "-m",
+ "--modules",
+ type=str,
default=None,
help="A comma-separated list of modules to test "
- "(default: %s)" % ",".join(sorted([m.name for m in modules.all_modules]))
+ "(default: %s)" % ",".join(sorted([m.name for m in modules.all_modules])),
)
parser.add_argument(
- "-e", "--excluded-tags", type=str,
+ "-e",
+ "--excluded-tags",
+ type=str,
default=None,
help="A comma-separated list of tags to exclude in the tests, "
- "e.g., org.apache.spark.tags.ExtendedHiveTest "
+ "e.g., org.apache.spark.tags.ExtendedHiveTest ",
)
parser.add_argument(
- "-i", "--included-tags", type=str,
+ "-i",
+ "--included-tags",
+ type=str,
default=None,
help="A comma-separated list of tags to include in the tests, "
- "e.g., org.apache.spark.tags.ExtendedHiveTest "
+ "e.g., org.apache.spark.tags.ExtendedHiveTest ",
)
args, unknown = parser.parse_known_args()
if unknown:
- parser.error("Unsupported arguments: %s" % ' '.join(unknown))
+ parser.error("Unsupported arguments: %s" % " ".join(unknown))
if args.parallelism < 1:
parser.error("Parallelism cannot be less than 1")
return args
@@ -433,8 +461,10 @@ def main():
opts = parse_opts()
# Ensure the user home directory (HOME) is valid and is an absolute directory
if not USER_HOME or not os.path.isabs(USER_HOME):
- print("[error] Cannot determine your home directory as an absolute path;",
- " ensure the $HOME environment variable is set properly.")
+ print(
+ "[error] Cannot determine your home directory as an absolute path;",
+ " ensure the $HOME environment variable is set properly.",
+ )
sys.exit(1)
os.chdir(SPARK_HOME)
@@ -448,8 +478,10 @@ def main():
java_exe = determine_java_executable()
if not java_exe:
- print("[error] Cannot find a version of `java` on the system; please",
- " install one and retry.")
+ print(
+ "[error] Cannot find a version of `java` on the system; please",
+ " install one and retry.",
+ )
sys.exit(2)
# Install SparkR
@@ -490,8 +522,12 @@ def main():
extra_profiles = get_hadoop_profiles(hadoop_version) + get_scala_profiles(scala_version)
- print("[info] Using build tool", build_tool, "with profiles",
- *(extra_profiles + ["under environment", test_env]))
+ print(
+ "[info] Using build tool",
+ build_tool,
+ "with profiles",
+ *(extra_profiles + ["under environment", test_env]),
+ )
changed_modules = []
changed_files = []
@@ -509,13 +545,16 @@ def main():
if test_env == "github_actions" and (is_apache_spark_ref or is_github_prev_sha):
if is_apache_spark_ref:
changed_files = identify_changed_files_from_git_commits(
- "HEAD", target_ref=os.environ["APACHE_SPARK_REF"])
+ "HEAD", target_ref=os.environ["APACHE_SPARK_REF"]
+ )
elif is_github_prev_sha:
changed_files = identify_changed_files_from_git_commits(
- os.environ["GITHUB_SHA"], target_ref=os.environ["GITHUB_PREV_SHA"])
+ os.environ["GITHUB_SHA"], target_ref=os.environ["GITHUB_PREV_SHA"]
+ )
modules_to_test = determine_modules_to_test(
- determine_modules_for_files(changed_files), deduplicated=False)
+ determine_modules_for_files(changed_files), deduplicated=False
+ )
if modules.root not in modules_to_test:
# If root module is not found, only test the intersected modules.
@@ -547,8 +586,7 @@ def main():
if opts.included_tags:
included_tags.extend([t.strip() for t in opts.included_tags.split(",")])
- print("[info] Found the following changed modules:",
- ", ".join(x.name for x in changed_modules))
+ print("[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules))
# setup environment variables
# note - the 'root' module doesn't collect environment variables for all modules. Because the
@@ -569,25 +607,26 @@ def main():
run_apache_rat_checks()
# style checks
- if not changed_files or any(f.endswith(".scala")
- or f.endswith("scalastyle-config.xml")
- for f in changed_files):
+ if not changed_files or any(
+ f.endswith(".scala") or f.endswith("scalastyle-config.xml") for f in changed_files
+ ):
run_scala_style_checks(extra_profiles)
- if not changed_files or any(f.endswith(".java")
- or f.endswith("checkstyle.xml")
- or f.endswith("checkstyle-suppressions.xml")
- for f in changed_files):
+ if not changed_files or any(
+ f.endswith(".java")
+ or f.endswith("checkstyle.xml")
+ or f.endswith("checkstyle-suppressions.xml")
+ for f in changed_files
+ ):
# Run SBT Checkstyle after the build to prevent a side-effect to the build.
should_run_java_style_checks = True
- if not changed_files or any(f.endswith("lint-python")
- or f.endswith("tox.ini")
- or f.endswith(".py")
- for f in changed_files):
+ if not changed_files or any(
+ f.endswith("lint-python") or f.endswith("tox.ini") or f.endswith(".py")
+ for f in changed_files
+ ):
run_python_style_checks()
- if not changed_files or any(f.endswith(".R")
- or f.endswith("lint-r")
- or f.endswith(".lintr")
- for f in changed_files):
+ if not changed_files or any(
+ f.endswith(".R") or f.endswith("lint-r") or f.endswith(".lintr") for f in changed_files
+ ):
run_sparkr_style_checks()
# determine if docs were changed and if we're inside the amplab environment
@@ -618,7 +657,8 @@ def main():
run_python_tests(
modules_with_python_tests,
opts.parallelism,
- with_coverage=os.environ.get("PYSPARK_CODECOV", "false") == "true")
+ with_coverage=os.environ.get("PYSPARK_CODECOV", "false") == "true",
+ )
run_python_packaging_tests()
if any(m.should_run_r_tests for m in test_modules):
run_sparkr_tests()
@@ -627,6 +667,7 @@ def main():
def _test():
import doctest
import sparktestsupport.utils
+
failure_count = doctest.testmod(sparktestsupport.utils)[0] + doctest.testmod()[0]
if failure_count:
sys.exit(-1)
diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py
index fa251a4..b4edb6b 100644
--- a/dev/sparktestsupport/__init__.py
+++ b/dev/sparktestsupport/__init__.py
@@ -35,5 +35,5 @@ ERROR_CODES = {
"BLOCK_BUILD_TESTS": 22,
"BLOCK_PYSPARK_PIP_TESTS": 23,
"BLOCK_SCALA_VERSION": 24,
- "BLOCK_TIMEOUT": 124
+ "BLOCK_TIMEOUT": 124,
}
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index e72635e..7cd5bd1 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -32,10 +32,20 @@ class Module(object):
files have changed.
"""
- def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(),
- environ=None, sbt_test_goals=(), python_test_goals=(),
- excluded_python_implementations=(), test_tags=(), should_run_r_tests=False,
- should_run_build_tests=False):
+ def __init__(
+ self,
+ name,
+ dependencies,
+ source_file_regexes,
+ build_profile_flags=(),
+ environ=None,
+ sbt_test_goals=(),
+ python_test_goals=(),
+ excluded_python_implementations=(),
+ test_tags=(),
+ should_run_r_tests=False,
+ should_run_build_tests=False,
+ ):
"""
Define a new module.
@@ -100,7 +110,7 @@ tags = Module(
dependencies=[],
source_file_regexes=[
"common/tags/",
- ]
+ ],
)
kvstore = Module(
@@ -178,9 +188,9 @@ catalyst = Module(
sbt_test_goals=[
"catalyst/test",
],
- environ=None if "GITHUB_ACTIONS" not in os.environ else {
- "ENABLE_DOCKER_INTEGRATION_TESTS": "1"
- },
+ environ=None
+ if "GITHUB_ACTIONS" not in os.environ
+ else {"ENABLE_DOCKER_INTEGRATION_TESTS": "1"},
)
sql = Module(
@@ -192,9 +202,9 @@ sql = Module(
sbt_test_goals=[
"sql/test",
],
- environ=None if "GITHUB_ACTIONS" not in os.environ else {
- "ENABLE_DOCKER_INTEGRATION_TESTS": "1"
- },
+ environ=None
+ if "GITHUB_ACTIONS" not in os.environ
+ else {"ENABLE_DOCKER_INTEGRATION_TESTS": "1"},
)
hive = Module(
@@ -210,9 +220,7 @@ hive = Module(
sbt_test_goals=[
"hive/test",
],
- test_tags=[
- "org.apache.spark.tags.ExtendedHiveTest"
- ]
+ test_tags=["org.apache.spark.tags.ExtendedHiveTest"],
)
repl = Module(
@@ -238,7 +246,7 @@ hive_thriftserver = Module(
],
sbt_test_goals=[
"hive-thriftserver/test",
- ]
+ ],
)
avro = Module(
@@ -249,7 +257,7 @@ avro = Module(
],
sbt_test_goals=[
"avro/test",
- ]
+ ],
)
sql_kafka = Module(
@@ -260,7 +268,7 @@ sql_kafka = Module(
],
sbt_test_goals=[
"sql-kafka-0-10/test",
- ]
+ ],
)
sketch = Module(
@@ -269,9 +277,7 @@ sketch = Module(
source_file_regexes=[
"common/sketch/",
],
- sbt_test_goals=[
- "sketch/test"
- ]
+ sbt_test_goals=["sketch/test"],
)
graphx = Module(
@@ -280,9 +286,7 @@ graphx = Module(
source_file_regexes=[
"graphx/",
],
- sbt_test_goals=[
- "graphx/test"
- ]
+ sbt_test_goals=["graphx/test"],
)
streaming = Module(
@@ -293,7 +297,7 @@ streaming = Module(
],
sbt_test_goals=[
"streaming/test",
- ]
+ ],
)
@@ -311,12 +315,10 @@ streaming_kinesis_asl = Module(
build_profile_flags=[
"-Pkinesis-asl",
],
- environ={
- "ENABLE_KINESIS_TESTS": "1"
- },
+ environ={"ENABLE_KINESIS_TESTS": "1"},
sbt_test_goals=[
"streaming-kinesis-asl/test",
- ]
+ ],
)
@@ -329,10 +331,7 @@ streaming_kafka_0_10 = Module(
"external/kafka-0-10-assembly",
"external/kafka-0-10-token-provider",
],
- sbt_test_goals=[
- "streaming-kafka-0-10/test",
- "token-provider-kafka-0-10/test"
- ]
+ sbt_test_goals=["streaming-kafka-0-10/test", "token-provider-kafka-0-10/test"],
)
@@ -344,7 +343,7 @@ mllib_local = Module(
],
sbt_test_goals=[
"mllib-local/test",
- ]
+ ],
)
@@ -357,7 +356,7 @@ mllib = Module(
],
sbt_test_goals=[
"mllib/test",
- ]
+ ],
)
@@ -369,15 +368,13 @@ examples = Module(
],
sbt_test_goals=[
"examples/test",
- ]
+ ],
)
pyspark_core = Module(
name="pyspark-core",
dependencies=[core],
- source_file_regexes=[
- "python/(?!pyspark/(ml|mllib|sql|streaming))"
- ],
+ source_file_regexes=["python/(?!pyspark/(ml|mllib|sql|streaming))"],
python_test_goals=[
# doctests
"pyspark.rdd",
@@ -407,15 +404,13 @@ pyspark_core = Module(
"pyspark.tests.test_taskcontext",
"pyspark.tests.test_util",
"pyspark.tests.test_worker",
- ]
+ ],
)
pyspark_sql = Module(
name="pyspark-sql",
dependencies=[pyspark_core, hive, avro],
- source_file_regexes=[
- "python/pyspark/sql"
- ],
+ source_file_regexes=["python/pyspark/sql"],
python_test_goals=[
# doctests
"pyspark.sql.types",
@@ -467,35 +462,25 @@ pyspark_sql = Module(
"pyspark.sql.tests.test_udf",
"pyspark.sql.tests.test_udf_profiler",
"pyspark.sql.tests.test_utils",
- ]
+ ],
)
pyspark_resource = Module(
name="pyspark-resource",
- dependencies=[
- pyspark_core
- ],
- source_file_regexes=[
- "python/pyspark/resource"
- ],
+ dependencies=[pyspark_core],
+ source_file_regexes=["python/pyspark/resource"],
python_test_goals=[
# unittests
"pyspark.resource.tests.test_resources",
- ]
+ ],
)
pyspark_streaming = Module(
name="pyspark-streaming",
- dependencies=[
- pyspark_core,
- streaming,
- streaming_kinesis_asl
- ],
- source_file_regexes=[
- "python/pyspark/streaming"
- ],
+ dependencies=[pyspark_core, streaming, streaming_kinesis_asl],
+ source_file_regexes=["python/pyspark/streaming"],
python_test_goals=[
# doctests
"pyspark.streaming.util",
@@ -504,16 +489,14 @@ pyspark_streaming = Module(
"pyspark.streaming.tests.test_dstream",
"pyspark.streaming.tests.test_kinesis",
"pyspark.streaming.tests.test_listener",
- ]
+ ],
)
pyspark_mllib = Module(
name="pyspark-mllib",
dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib],
- source_file_regexes=[
- "python/pyspark/mllib"
- ],
+ source_file_regexes=["python/pyspark/mllib"],
python_test_goals=[
# doctests
"pyspark.mllib.classification",
@@ -540,16 +523,14 @@ pyspark_mllib = Module(
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
- ]
+ ],
)
pyspark_ml = Module(
name="pyspark-ml",
dependencies=[pyspark_core, pyspark_mllib],
- source_file_regexes=[
- "python/pyspark/ml/"
- ],
+ source_file_regexes=["python/pyspark/ml/"],
python_test_goals=[
# doctests
"pyspark.ml.classification",
@@ -582,15 +563,13 @@ pyspark_ml = Module(
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
- ]
+ ],
)
pyspark_pandas = Module(
name="pyspark-pandas",
dependencies=[pyspark_core, pyspark_sql],
- source_file_regexes=[
- "python/pyspark/pandas/"
- ],
+ source_file_regexes=["python/pyspark/pandas/"],
python_test_goals=[
# doctests
"pyspark.pandas.accessors",
@@ -670,16 +649,14 @@ pyspark_pandas = Module(
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
- # they aren't available there
- ]
+ # they aren't available there
+ ],
)
pyspark_pandas_slow = Module(
name="pyspark-pandas-slow",
dependencies=[pyspark_core, pyspark_sql],
- source_file_regexes=[
- "python/pyspark/pandas/"
- ],
+ source_file_regexes=["python/pyspark/pandas/"],
python_test_goals=[
# doctests
"pyspark.pandas.frame",
@@ -699,7 +676,7 @@ pyspark_pandas_slow = Module(
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
# they aren't available there
- ]
+ ],
)
sparkr = Module(
@@ -708,7 +685,7 @@ sparkr = Module(
source_file_regexes=[
"R/",
],
- should_run_r_tests=True
+ should_run_r_tests=True,
)
@@ -717,7 +694,7 @@ docs = Module(
dependencies=[],
source_file_regexes=[
"docs/",
- ]
+ ],
)
build = Module(
@@ -727,7 +704,7 @@ build = Module(
".*pom.xml",
"dev/test-dependencies.sh",
],
- should_run_build_tests=True
+ should_run_build_tests=True,
)
yarn = Module(
@@ -742,9 +719,7 @@ yarn = Module(
"yarn/test",
"network-yarn/test",
],
- test_tags=[
- "org.apache.spark.tags.ExtendedYarnTest"
- ]
+ test_tags=["org.apache.spark.tags.ExtendedYarnTest"],
)
mesos = Module(
@@ -752,7 +727,7 @@ mesos = Module(
dependencies=[],
source_file_regexes=["resource-managers/mesos/"],
build_profile_flags=["-Pmesos"],
- sbt_test_goals=["mesos/test"]
+ sbt_test_goals=["mesos/test"],
)
kubernetes = Module(
@@ -760,7 +735,7 @@ kubernetes = Module(
dependencies=[],
source_file_regexes=["resource-managers/kubernetes"],
build_profile_flags=["-Pkubernetes"],
- sbt_test_goals=["kubernetes/test"]
+ sbt_test_goals=["kubernetes/test"],
)
hadoop_cloud = Module(
@@ -768,7 +743,7 @@ hadoop_cloud = Module(
dependencies=[],
source_file_regexes=["hadoop-cloud"],
build_profile_flags=["-Phadoop-cloud"],
- sbt_test_goals=["hadoop-cloud/test"]
+ sbt_test_goals=["hadoop-cloud/test"],
)
spark_ganglia_lgpl = Module(
@@ -777,7 +752,7 @@ spark_ganglia_lgpl = Module(
build_profile_flags=["-Pspark-ganglia-lgpl"],
source_file_regexes=[
"external/spark-ganglia-lgpl",
- ]
+ ],
)
docker_integration_tests = Module(
@@ -786,12 +761,10 @@ docker_integration_tests = Module(
build_profile_flags=["-Pdocker-integration-tests"],
source_file_regexes=["external/docker-integration-tests"],
sbt_test_goals=["docker-integration-tests/test"],
- environ=None if "GITHUB_ACTIONS" not in os.environ else {
- "ENABLE_DOCKER_INTEGRATION_TESTS": "1"
- },
- test_tags=[
- "org.apache.spark.tags.DockerTest"
- ]
+ environ=None
+ if "GITHUB_ACTIONS" not in os.environ
+ else {"ENABLE_DOCKER_INTEGRATION_TESTS": "1"},
+ test_tags=["org.apache.spark.tags.DockerTest"],
)
# The root module is a dummy module which is used to run all of the tests.
@@ -801,12 +774,13 @@ root = Module(
dependencies=[build, core], # Changes to build should trigger all tests.
source_file_regexes=[],
# In order to run all of the tests, enable every test profile:
- build_profile_flags=list(set(
- itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))),
+ build_profile_flags=list(
+ set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))
+ ),
sbt_test_goals=[
"test",
],
python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)),
should_run_r_tests=True,
- should_run_build_tests=True
+ should_run_build_tests=True,
)
diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py
index d9cb8aa4..1d40ae9 100644
--- a/dev/sparktestsupport/shellutils.py
+++ b/dev/sparktestsupport/shellutils.py
@@ -25,9 +25,9 @@ subprocess_check_output = subprocess.check_output
def exit_from_command_with_retcode(cmd, retcode):
if retcode < 0:
- print("[error] running", ' '.join(cmd), "; process was terminated by signal", -retcode)
+ print("[error] running", " ".join(cmd), "; process was terminated by signal", -retcode)
else:
- print("[error] running", ' '.join(cmd), "; received return code", retcode)
+ print("[error] running", " ".join(cmd), "; received return code", retcode)
sys.exit(int(os.environ.get("CURRENT_BLOCK", 255)))
@@ -53,7 +53,7 @@ def run_cmd(cmd, return_output=False):
cmd = cmd.split()
try:
if return_output:
- return subprocess_check_output(cmd).decode('utf-8')
+ return subprocess_check_output(cmd).decode("utf-8")
else:
return subprocess.run(cmd, universal_newlines=True, check=True)
except subprocess.CalledProcessError as e:
diff --git a/dev/sparktestsupport/toposort.py b/dev/sparktestsupport/toposort.py
index 6785e48..dcc8d4a 100644
--- a/dev/sparktestsupport/toposort.py
+++ b/dev/sparktestsupport/toposort.py
@@ -35,16 +35,15 @@
from functools import reduce as _reduce
-__all__ = ['toposort', 'toposort_flatten']
+__all__ = ["toposort", "toposort_flatten"]
def toposort(data):
"""Dependencies are expressed as a dictionary whose keys are items
-and whose values are a set of dependent items. Output is a list of
-sets in topological order. The first set consists of items with no
-dependencies, each subsequent set consists of items that depend upon
-items in the preceding sets.
-"""
+ and whose values are a set of dependent items. Output is a list of
+ sets in topological order. The first set consists of items with no
+ dependencies, each subsequent set consists of items that depend upon
+ items in the preceding sets."""
# Special case empty input.
if len(data) == 0:
@@ -65,18 +64,19 @@ items in the preceding sets.
if not ordered:
break
yield ordered
- data = {item: (dep - ordered)
- for item, dep in data.items()
- if item not in ordered}
+ data = {item: (dep - ordered) for item, dep in data.items() if item not in ordered}
if len(data) != 0:
- raise ValueError('Cyclic dependencies exist among these items: {}'.format(
- ', '.join(repr(x) for x in data.items())))
+ raise ValueError(
+ "Cyclic dependencies exist among these items: {}".format(
+ ", ".join(repr(x) for x in data.items())
+ )
+ )
def toposort_flatten(data, sort=True):
"""Returns a single list of dependencies. For any set returned by
-toposort(), those items are sorted and appended to the result (just to
-make the results deterministic)."""
+ toposort(), those items are sorted and appended to the result (just to
+ make the results deterministic)."""
result = []
for d in toposort(data):
diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py
index 1dccc9a..94928fa 100755
--- a/dev/sparktestsupport/utils.py
+++ b/dev/sparktestsupport/utils.py
@@ -77,13 +77,14 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe
raise AttributeError("must specify either target_branch or target_ref, not both")
if target_branch is not None:
diff_target = target_branch
- run_cmd(['git', 'fetch', 'origin', str(target_branch + ':' + target_branch)])
+ run_cmd(["git", "fetch", "origin", str(target_branch + ":" + target_branch)])
else:
diff_target = target_ref
- raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target],
- universal_newlines=True)
+ raw_output = subprocess.check_output(
+ ["git", "diff", "--name-only", patch_sha, diff_target], universal_newlines=True
+ )
# Remove any empty strings
- return [f for f in raw_output.split('\n') if f]
+ return [f for f in raw_output.split("\n") if f]
def determine_modules_to_test(changed_modules, deduplicated=True):
@@ -128,7 +129,8 @@ def determine_modules_to_test(changed_modules, deduplicated=True):
modules_to_test = set()
for module in changed_modules:
modules_to_test = modules_to_test.union(
- determine_modules_to_test(module.dependent_modules, deduplicated))
+ determine_modules_to_test(module.dependent_modules, deduplicated)
+ )
modules_to_test = modules_to_test.union(set(changed_modules))
if not deduplicated:
@@ -138,7 +140,8 @@ def determine_modules_to_test(changed_modules, deduplicated=True):
if modules.root in modules_to_test:
return [modules.root]
return toposort_flatten(
- {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True)
+ {m: set(m.dependencies).intersection(modules_to_test) for m in modules_to_test}, sort=True
+ )
def determine_tags_to_exclude(changed_modules):
@@ -151,6 +154,7 @@ def determine_tags_to_exclude(changed_modules):
def _test():
import doctest
+
failure_count = doctest.testmod()[0]
if failure_count:
sys.exit(-1)
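As a minimal standalone sketch of the conventions black applies throughout these dev scripts (double-quoted strings, calls exploded to one argument per line, trailing comma on the last element), the snippet below reduces the argument-parsing pattern seen above to a runnable example. It is an illustration only; the names and defaults here are assumptions, not code taken from the repository.

```python
# Illustrative sketch of black-style formatting as seen in this diff.
# The helper name and defaults are hypothetical, not from the repo.
from argparse import ArgumentParser


def build_parser():
    # Double quotes everywhere; once a call no longer fits on one line,
    # black places one argument per line and keeps a trailing comma.
    parser = ArgumentParser(prog="example")
    parser.add_argument(
        "-p",
        "--parallelism",
        type=int,
        default=8,
        help="The number of suites to test in parallel (default %(default)d)",
    )
    return parser


if __name__ == "__main__":
    # With no CLI arguments, the default parallelism of 8 is reported.
    print("[info] Parsed options:", build_parser().parse_args([]))
```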