sandeep-krishnamurthy closed pull request #11991: [MXNET-644] Automated flaky test detection
URL: https://github.com/apache/incubator-mxnet/pull/11991
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/.gitattributes b/.gitattributes
index e577ab3c116..2c975abf8ea 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,2 +1,3 @@
 .gitattributes export-ignore
 R-package/* export-ignore
+*.py  diff=python
diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh
index 1c861beb916..134d016b790 100755
--- a/ci/docker/runtime_functions.sh
+++ b/ci/docker/runtime_functions.sh
@@ -924,6 +924,17 @@ nightly_straight_dope_python3_multi_gpu_tests() {
       test_notebooks_multi_gpu.py --nologcapture
 }
 
+# check commit for flaky tests
+flaky_check_select_tests(){
+    set -ex
+    tools/flaky_tests/test_selector.py -b HEAD~1 HEAD
+}
+flaky_check_run_flakiness_checker(){
+    set -ex
+    export PYTHONPATH=./python/
+    tools/flaky_tests/check_tests.py
+}
+
 # Deploy
 
 deploy_docs() {
diff --git a/tools/flaky_tests/Jenkinsfile b/tools/flaky_tests/Jenkinsfile
new file mode 100644
index 00000000000..debb55a84a1
--- /dev/null
+++ b/tools/flaky_tests/Jenkinsfile
@@ -0,0 +1,73 @@
+// -*- mode: groovy -*-
+
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+err = null
+mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
+
+node('mxnetlinux-cpu') {
+  // Loading the utilities requires a node context unfortunately
+  checkout scm
+  utils = load('ci/Jenkinsfile_utils.groovy')
+}
+utils.assign_node_labels(linux_gpu: 'mxnetlinux-gpu', linux_cpu: 'mxnetlinux-cpu')
+tests = false
+
+utils.main_wrapper(
+core_logic: {
+    stage('Preprocessing'){
+        node(NODE_LINUX_CPU){
+            ws('workspace/fc-preprocessing'){
+                utils.init_git()
+                utils.docker_run('ubuntu_cpu', 'flaky_check_select_tests', false)
+                tests = fileExists('tests.tmp')
+                stash name:'flaky_tests', includes:'tests.tmp' 
+            }
+        }
+    }
+
+    // only continue if some tests were selected
+    if (!tests) {
+        currentBuild.result = 'SUCCESS'
+    }
+    else {
+        stage('Compilation') {
+            node(NODE_LINUX_CPU) {
+                ws('workspace/fc-compilation') {
+                    utils.init_git()
+                    utils.docker_run('ubuntu_build_cuda', 'build_ubuntu_gpu_cuda91_cudnn7', false)
+                    utils.pack_lib('gpu', mx_lib)
+                }
+            }
+        }
+        stage('Flakiness Check') {
+            node(NODE_LINUX_GPU) {
+                ws('workspace/fc-execution') {
+                    utils.init_git()
+                    unstash 'flaky_tests'
+                    utils.unpack_lib('gpu', mx_lib)
+                    utils.docker_run('ubuntu_gpu', 'flaky_check_run_flakiness_checker', false)
+                }
+            }
+        }
+    }
+}
+,
+failure_handler: {
+}
+)
diff --git a/tools/flaky_tests/check_tests.py b/tools/flaky_tests/check_tests.py
new file mode 100755
index 00000000000..4aea8bba6fb
--- /dev/null
+++ b/tools/flaky_tests/check_tests.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Checks selected tests for flakiness
+
+This script is used for automated flaky test detection. It reads the
+tests listed in tests.tmp and runs them a large number of times.
+Currently, each test is given an equal number of runs
+such that all tests being checked can be run within the time budget.
+"""
+import logging
+import time
+import os
+import sys
+
+import flakiness_checker
+
+LOGGING_FILE = os.path.join(os.path.dirname(__file__), "results.log")
+TESTS_DIRECTORY = "tests/python"
+TESTS_FILE = "tests.tmp"
+PROFILING_TRIALS = 10
+TIME_BUDGET = 5400 
+
+logger = logging.getLogger(__name__)
+
+fh = logging.FileHandler(LOGGING_FILE)
+fh.setLevel(logging.INFO)
+fh.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))
+logger.addHandler(fh)
+
+
+def calculate_test_trials(tests):
+    """Calculate the number of times each test should be run
+    
+    Currently, each test is run the same number of times, where the
+    number is based on the time it takes to run each test once.
+    """
+    def time_test(test):
+        start = time.time()
+        flakiness_checker.run_test_trials(
+            test[0], test[1], PROFILING_TRIALS + 1)
+        end = time.time()
+        profile_time = end - start
+
+        start = time.time()
+        flakiness_checker.run_test_trials(test[0], test[1], 1)
+        end = time.time()
+        setup_time = end - start
+
+        actual_time = profile_time - setup_time
+        return actual_time / PROFILING_TRIALS
+
+    total_time = 0.0
+    for t in tests:
+        total_time += time_test(t)
+
+    try:
+        n = int(TIME_BUDGET / total_time)
+    except ZeroDivisionError:
+        logger.error("Total time for tests was 0")
+        return []
+    
+    logger.debug("total_time: %f | num_trials: %d", total_time, n)
+    return [(t, n) for t in tests]
+
+
+def check_tests(tests):
+    """Check given tests for flakiness"""
+    flaky, nonflaky = [], []
+    tests = calculate_test_trials(tests)
+
+    for t, n in tests:
+        res = flakiness_checker.run_test_trials(t[0], t[1], n)
+
+        if res != 0:
+            flaky.append(t)
+        else:
+            nonflaky.append(t)
+    
+    return flaky, nonflaky
+
+
+def output_results(flaky, nonflaky):
+    logger.info("Following tests failed flakiness checker:")
+    if not flaky:
+        logger.info("None")
+    for test in flaky:
+        logger.info("%s:%s", test[0], test[1])
+
+    logger.info("Following tests passed flakiness checker:")
+    if not nonflaky:
+        logger.info("None")
+    for test in nonflaky:
+        logger.info("%s:%s", test[0], test[1])
+
+    logger.info("[Results]\tTotal: %d\tFlaky: %d\tNon-flaky: %d",
+                len(flaky) + len(nonflaky), len(flaky), len(nonflaky))
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
+    tests = []
+    with open(TESTS_FILE) as f:
+        for line in f.readlines():
+            test = line.split(":")
+            tests.append((test[0], test[1]))
+    
+    os.remove(TESTS_FILE)
+    
+    flaky, nonflaky = check_tests(tests)
+    output_results(flaky, nonflaky)
+
+    if flaky:
+        sys.exit(1)
diff --git a/tools/flaky_tests/dependency_analyzer.py b/tools/flaky_tests/dependency_analyzer.py
new file mode 100755
index 00000000000..f1954528146
--- /dev/null
+++ b/tools/flaky_tests/dependency_analyzer.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Output dependent functions given a list of function dependnecies
+
+This module searches the given directory or file for functions that are
+dependent on the given list of functions. The current directory is used
+if none is provided. This script is designed only for python files--
+it uses python's ast module parse files and find function calls.
+The function calls are then compared to the list of dependencies
+and if there is a match, the top-level function name is added to
+the set of dependent functions. 
+
+Cross-file dependencies are handled by storing them in a json file, 
+called config.json. Each test file with outside dependecies 
+should be listed in this file along with a list of its dependncies.
+Currently this file is updated manually.
+"""
+import sys
+import os
+import argparse
+import ast
+import logging
+import re
+import itertools
+import json
+import io
+
+DEFAULT_CONFIG_FILE = os.path.join(
+    os.path.dirname(__file__), "test_dependencies.config")
+
+logger = logging.getLogger(__name__)
+
+
+def read_config(filename):
+    """Reads cross-file dependencies from json file"""
+    with open(filename) as f:
+        return json.load(f)
+
+def find_dependents(dependencies, top):
+    top = os.path.abspath(top)
+    dependents = {}
+
+    for filename in dependencies.keys():
+        funcs = dependencies[filename]
+        abs_path = os.path.join(top, filename)
+        deps = find_dependents_file(set(funcs), abs_path)
+        dependents[filename] = deps
+
+    try:
+        file_deps = read_config(DEFAULT_CONFIG_FILE)
+    except IOError:
+        file_deps = {}
+        logger.WARNING("No config file found, "
+            "continuing with no file dependencies")
+
+    for filename in list(dependents.keys()):
+        if filename in file_deps:
+            for dependent in file_deps[filename]:
+                dependents[dependent] = dependents[filename]
+
+    return dependents
+
+
+
+def find_dependents_file(dependencies, filename):
+    """Recursively search a file for dependent functions"""
+    class CallVisitor(ast.NodeVisitor):
+        def visit_Name(self, node):
+            return node.id
+
+        def visit_Attribute(self, node):
+            try:
+                return "{}.{}".format(node.value.id, node.attr)
+            except AttributeError:
+                return "{}.{}".format(self.generic_visit(node), node.attr)
+
+    if not dependencies:
+        return set()
+
+    if os.path.splitext(filename)[1] !=".py":
+        logger.debug("Skipping non-python file: %s", filename)
+        return set()
+
+    with io.open(filename, encoding="utf-8") as f:
+        tree = ast.parse(f.read())
+    logger.debug("seaching: %s", filename)
+
+    dependents = set()
+    cv = CallVisitor()
+
+    for t in tree.body:     # search for function calls matching dependencies
+        if isinstance(t, ast.FunctionDef):
+            name = t.name
+            if name in dependencies:
+                dependents.add(name)
+        else:
+            name = "top-level"
+
+        for n in ast.walk(t):
+            if isinstance(n, ast.Call):
+                func = cv.visit(n.func)
+                if func in dependencies:
+                    dependents.add(name)
+
+    try:
+        dependents |= find_dependents_file(dependents - dependencies, filename)
+    except RuntimeError as re:
+        logger.error("Encountered recursion error when seaching %s: %s",
+                     filename, re.args[0])
+
+    return dependents
+
+
+def output_results(dependents):
+    logger.info("Dependencies:")
+    for filename in dependents.keys():
+        logger.info(filename)
+        if not dependents[filename]:
+            logger.info("None")
+            continue
+        for func in dependents[filename]:
+            logger.info("\t%s", func)
+
+
+def parse_args():
+    class DependencyAction(argparse.Action):
+        def __call__(self, parser, namespace, values, option_string=None):
+            setattr(namespace, "dependencies", {})
+            for v in values:
+                dep = v.split(":")
+                if len(dep) != 2:
+                    raise ValueError("Invalid format for dependency " + v +
+                                     "Format: <file>:<func-name>.)")
+                try:
+                    namespace.dependencies[dep[0]].append(dep[1])
+                except KeyError:
+                    namespace.dependencies[dep[0]] = [dep[1]]
+
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument(
+        "dependencies", nargs="+", action=DependencyAction,
+        help="list of dependent functions, "
+        "in the format: <file>:<func_name>")
+
+    arg_parser.add_argument(
+        "--logging-level", "-l", dest="level", default="INFO",
+        help="logging level, defaults to INFO")
+
+    arg_parser.add_argument(
+        "--path", "-p", default=".",
+        help="directory in which given files are located")
+
+    args = arg_parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    try:
+        logging.basicConfig(level=getattr(logging, args.level))
+    except AttributeError:
+        logging.basicConfig(level=logging.INFO)
+        logging.warning("Invalid logging level: %s", args.level)
+    logger.debug("args: %s", args)
+
+    dependents = find_dependents(args.dependencies, args.path)
+    output_results(dependents)
diff --git a/tools/flaky_tests/diff_collator.py b/tools/flaky_tests/diff_collator.py
new file mode 100755
index 00000000000..1ae5a346d34
--- /dev/null
+++ b/tools/flaky_tests/diff_collator.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Output a list of differences between current branch and master
+
+Precondition: this script is run inside an existing git repository
+
+This script first retrieves the raw output from git diff. By default,
+the current and master branches are used as targets for git diff,
+but the user may specify their own targets. Then, the raw output is 
+parsed to retrieve info about each of the changes between the targets, 
+including file name, top-level funtion name, and line numbers. 
+Finally, the list of changes is outputted.
+"""
+
+import os
+import subprocess
+import sys
+import re
+import argparse
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def get_diff_output(args):
+    """Perform a git diff using provided args"""
+    diff_cmd = ["git", "diff", "--unified=0"]
+    if args.commits is not None:
+        diff_cmd.extend([args.commits[0], args.commits[1]])
+    else:
+        if args.branches is None:            
+            args.branches = ["master", "HEAD"]
+        diff_target = args.branches[0] + "..." + args.branches[1]
+        diff_cmd.append(diff_target)
+
+    if args.path:
+        diff_cmd.extend(["--", args.path])
+
+    logger.debug("Command: %s", diff_cmd)
+    try:
+        return subprocess.check_output(diff_cmd)
+    except subprocess.CalledProcessError as e:
+        logger.error("git diff returned a non zero exit code: %d",
+                      e.returncode)
+        sys.exit(1)
+
+
+def parser(diff_output):
+    """Split diff output into patches and parse each indiviudally"""
+    diff_output = diff_output.decode("utf-8")
+    top = subprocess.check_output(["git","rev-parse", "--top-level"])
+    top = top.decode("utf-8")
+    changes = {}
+
+    for patch in diff_output.split("diff --git")[1:]:
+        file_name, cs = parse_patch(patch)
+        if not cs:
+            continue
+        changes[file_name] = cs
+    
+    return changes
+
+
+def parse_patch(patch):
+    """ Parse changes in a single patch
+
+    Git diff outputs results as patches, each of which corresponds 
+    to a single file that has been changed. Each patch consists of 
+    a header and one or more hunks that show differing lines between 
+    files versions. Hunks themselves have headers, which include 
+    line numbers changed and function names.
+    """
+    lines = patch.splitlines()
+    file_name  = lines[0].split()[-1][2:]
+    changes = {}
+    
+    logger.debug("Parsing: %s", file_name)
+    for line in patch.splitlines():
+        # parse hunk header
+        if line.startswith("@"):
+            tokens = line.split()
+            to_range = []
+            start = 0
+            end = 0
+            
+            # Get line numbers
+            for t in tokens[1:]:
+                if t.startswith("@"):
+                    start = int(to_range[0])
+                    try:
+                        end = start + int(to_range[1])
+                    except IndexError:
+                        end = start
+                else:
+                    to_range = t[1:].split(",")
+
+            # Get function name
+            try:
+                hunk_name = tokens[tokens.index("def") + 1].split("(")[0]
+            except ValueError:
+                hunk_name = "top-level"
+            logger.debug("\tHunk: %s - (%d,%d)", hunk_name, start, end)
+
+            # Add hunk info to changes
+            if hunk_name not in changes:
+                changes[hunk_name] = []
+            changes[hunk_name].append((start, end))
+
+        # newly defined top-level function
+        if line.startswith("+def "):
+            func_name = line.split()[1].split("(")[0]
+            changes[func_name] = []
+            logger.debug("\tFound new top-level function: %s", func_name)
+
+    return file_name, changes
+
+
+def output_changes(changes, verbosity=2):
+    """ Output changes in an easy to understand format
+    
+    Three verbosity levels: 
+    1 - only file names, 
+    2- file and functions names,
+    3- file and function names and line numbers.
+
+    Example (verbosity 3):
+    file1
+        func_a
+            1:2
+            3:4
+        func_b
+            5:5
+        func_c
+    """
+    logger.debug("verbosity: %d", verbosity)
+
+    if not changes:
+        logger.info("No changes found")
+    else:
+        for file_name, chunks in changes.items():    
+            logger.info(file_name)
+            if verbosity < 2:
+                continue
+            for func_name, ranges in chunks.items():
+                logger.info("\t%s", func_name)
+                if verbosity < 3:
+                    continue
+                for (start, end) in ranges:
+                    logger.info("\t\t%s:%s", start, end)
+
+    
+
+def parse_args():
+    arg_parser = argparse.ArgumentParser()
+
+    arg_parser.add_argument(
+        "--verbosity", "-v", action="count", default=2,
+        help="verbosity level, repeat up to 3 times, defaults to 2")
+    arg_parser.add_argument(
+        "--logging-level", "-l", dest="level", default="INFO",
+        help="logging level, defaults to INFO")
+
+    targets = arg_parser.add_mutually_exclusive_group()
+    targets.add_argument(
+        "--commits", "-c", nargs=2, metavar=("HASH1 ","HASH2"),
+        help="specifies two commits to be compared")
+    targets.add_argument(
+        "--branches", "-b", nargs=2, metavar=("MASTER", "TOPIC"),
+        help="specifies two branches to be compared")
+
+    filters = arg_parser.add_argument_group(
+        "filters", "filter which files should be included in output")
+    filters.add_argument(
+        "--filter-path", "-p", dest="path", 
+        help="specify directory or file in which to search for changes")
+    filters.add_argument(
+        "--filter", "-f", dest="expr", metavar="REGEX", default=".*",
+        help="filter files with given python regular expression")
+    
+    args = arg_parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    
+    try:
+        logging.basicConfig(level=getattr(logging, args.level))
+    except AttributeError:
+        logging.basicConfig(level=logging.INFO)
+        logging.warning("Invalid logging level: %s", args.level)
+    logging.debug("args: %s", args)
+
+    diff_output = get_diff_output(args)
+
+    changes = parser(diff_output)
+    for file_name, chunks in changes.items():
+        if not re.match(args.expr, file_name):
+            del changes[file_name]
+
+    output_changes(changes, args.verbosity)
diff --git a/tools/flakiness_checker.py b/tools/flaky_tests/flakiness_checker.py
old mode 100644
new mode 100755
similarity index 50%
rename from tools/flakiness_checker.py
rename to tools/flaky_tests/flakiness_checker.py
index 79fa3b1854f..51e10b5e9c0
--- a/tools/flakiness_checker.py
+++ b/tools/flaky_tests/flakiness_checker.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,6 +17,7 @@
 # under the License.
 
 """ Checks a given test for flakiness
+
 Takes the file name and function name of a test, as well as, optionally,
 the number of trials to run and the random seed to use
 """
@@ -27,80 +29,88 @@
 import argparse
 import re
 import logging
+import nose
 
-logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 DEFAULT_NUM_TRIALS = 10000
-DEFAULT_VERBOSITY = 2
 
-def run_test_trials(args):
-    test_path = args.test_path + ":" + args.test_name
-    logging.info("Testing: %s", test_path)
+
+def run_test_trials(test_file, test_name, num_trials, seed=None, args=None):
+    test_path = "{}:{}".format(test_file, test_name)
+    logger.info("Testing-- %s", test_path)
+    logger.info("Setting MXNET_TEST_COUNT to %d", num_trials)
 
     new_env = os.environ.copy()
-    new_env["MXNET_TEST_COUNT"] = str(args.num_trials)
-    
-    if args.seed is None:
-        logging.info("No test seed provided, using random seed")
+    new_env["MXNET_TEST_COUNT"] = str(num_trials)
+    if seed is None:
+        logger.info("No test seed provided, using random seed")
     else:
-        new_env["MXNET_TEST_SEED"] = str(args.seed)
+        new_env["MXNET_TEST_SEED"] = str(seed)
 
-    verbosity = "--verbosity=" + str(args.verbosity)
+    command = ["nosetests", test_path]
+    if args:
+        command.extend(args)
+
+    logger.debug("Nosetests command: %s", command)
+    return subprocess.call(command, env = new_env)
 
-    code = subprocess.call(["nosetests", verbosity, test_path], 
-                           env = new_env)
-    
-    logging.info("Nosetests terminated with exit code %d", code)
-
-def find_test_path(test_file):
-    """Searches for the test file and returns the path if found
-    As a default, the currend working directory is the top of the search.
-    If a directory was provided as part of the argument, the directory will be
-    joined with cwd unless it was an absolute path, in which case, the
-    absolute path will be used instead. 
-    """
-    test_file += ".py"
-    test_path = os.path.split(test_file)
-    top = os.path.join(os.getcwd(), test_path[0])
-
-    for (path, dirs, files) in os.walk(top):
-        if test_path[1] in files:
-            return  os.path.join(path, test_path[1])
-    raise FileNotFoundError("Could not find " + test_path[1] + 
-                            "in directory: " + top)
 
 class NameAction(argparse.Action):
     """Parses command line argument to get test file and test name"""
+    def find_test_path(self, test_file):
+        """Searches for the test file and returns the path if found.
+
+        As a default, the current working directory is used as the top
+        of the search. If a directory was provided, the directory 
+        will be joined with cwd, unless it was an absolute path, 
+        in which case, the absolute path will be used instead. 
+        """
+        test_file += ".py"
+        test_path = os.path.split(test_file)
+        top = os.path.join(os.getcwd(), test_path[0])
+
+        for (path, _ , files) in os.walk(top):
+            if test_path[1] in files:
+                return  os.path.join(path, test_path[1])
+
+        raise IOError("Could not find {} in directory: {}".format(
+            test_path[1],  top))
+
     def __call__(self, parser, namespace, values, option_string=None):
         name = re.split("\.py:|\.", values)
         if len(name) != 2:
+            logger.error("Invalid test specifier: %s", name)
             raise ValueError("Invalid argument format for test. Format: "
                              "<file-name>.<test-name> or"
                              " <directory>/<file>:<test-name>")
-        setattr(namespace, "test_path", find_test_path(name[0]))
+
+        setattr(namespace, "test_path", self.find_test_path(name[0]))
         setattr(namespace, "test_name", name[1])
 
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Check test for flakiness")
     
-    parser.add_argument("test", action=NameAction,
-                        help="file name and and function name of test, "
-                        "provided in the format: <file-name>.<test-name> "
-                        "or <directory>/<file>:<test-name>")
-    
+    parser.add_argument("--logging-level", "-l", dest="level", default="INFO",
+                        help="set logging level, defaults to INFO")
+
     parser.add_argument("-n", "--num-trials", metavar="N",
                         default=DEFAULT_NUM_TRIALS, type=int,
                         help="number of test trials, passed as "
-                        "MXNET_TEST_COUNT, defaults to 500")
+                        "MXNET_TEST_COUNT, defaults to 10000")
 
-    parser.add_argument("-s", "--seed", type=int,
+    parser.add_argument("--seed", type=int,
                         help="test seed, passed as MXNET_TEST_SEED, "
                         "defaults to random seed") 
 
-    parser.add_argument("-v", "--verbosity",
-                        default=DEFAULT_VERBOSITY, type=int,
-                        help="logging level, passed to nosetests")
+    parser.add_argument("test", action=NameAction,
+                        help="file name and and function name of test, "
+                        "provided in the format: <file-name>.<test-name> "
+                        "or <directory>/<file>:<test-name>")
 
+    parser.add_argument("args", nargs=argparse.REMAINDER,
+                        help="args to pass to nosetests")
 
     args = parser.parse_args()
     return args
@@ -108,5 +118,13 @@ def parse_args():
 
 if __name__ == "__main__":
     args = parse_args()
-
-    run_test_trials(args)
+    try:
+        logging.basicConfig(level=getattr(logging, args.level))
+    except AttributeError:
+        logging.basicConfig(level=logging.INFO)
+        logging.warning("Invalid logging level: %s", args.level)
+    logger.debug("args: %s", args)
+
+    code = run_test_trials(args.test_path, args.test_name, args.num_trials,
+                           args.seed, args.args)
+    logger.info("Nosetests terminated with exit code %d", code)
diff --git a/tools/flaky_tests/test_dependencies.config b/tools/flaky_tests/test_dependencies.config
new file mode 100644
index 00000000000..58c0a87dcfb
--- /dev/null
+++ b/tools/flaky_tests/test_dependencies.config
@@ -0,0 +1,22 @@
+{
+    "tests/python/gpu/test_operator_gpu.py": [
+        "test_operator.py",
+        "test_optimizer.py",
+        "test_random.py",
+        "test_exc_handling.py",
+        "test_sparse_ndarray.py",
+        "test_sparse_operator.py",
+        "test_ndarray.py"
+    ],
+    "tests/python/gpu/test_gluon_gpu.py": [
+        "test_gluon.py",
+        "test_loss.py",
+        "test_gluon_rnn.py"
+    ],
+    "tests/python/mkl/test_quantization_mkldnn.py": [
+        "test_quantization.py"
+    ],
+    "tests/python/quantization_gpu/test_quantization_gpu.py": [
+        "test_quantization.py"
+    ]
+}
diff --git a/tools/flaky_tests/test_selector.py b/tools/flaky_tests/test_selector.py
new file mode 100755
index 00000000000..f04dd2115b6
--- /dev/null
+++ b/tools/flaky_tests/test_selector.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Checks the current branch for changes affecting tests
+
+This script is used for automated flaky test detection, when changes
+are detected that affect a test, this script will list all
+affected affected tests in a file called tests.tmp, which will
+be read by the check_tests.py script to check them for flakiness.
+"""
+
+import logging
+import subprocess
+import sys
+
+import diff_collator
+import dependency_analyzer
+
+logger = logging.getLogger(__name__)
+TEST_PREFIX = "test_"
+TESTS_FILE = "tests.tmp"
+
+def select_tests(changes):
+    """returns tests that are dependent on given changes
+
+    All python unit tests are top-level function with the prefix 
+    "test_" in the function name. To get all tests, we simply 
+    filter our changes by this prefix, stored in TEST_PREFIX.
+    """
+    top = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
+    top = top.decode("utf-8").splitlines()[0]
+    deps = dependency_analyzer.find_dependents(changes, top)
+
+    return [(filename, test) 
+            for filename in deps.keys() 
+            for test in deps[filename] 
+            if test.startswith(TEST_PREFIX)]
+
+def output_tests(tests):
+    if not tests:
+        return 1
+    
+    with open(TESTS_FILE, "w+") as f:
+        for filename, testcase in tests:
+            f.write("{}:{}\n".format(filename, testcase))
+    
+    return 0
+
+
+if __name__ == "__main__":
+    args = diff_collator.parse_args()
+    try:
+        logging.basicConfig(level=getattr(logging, args.level))
+    except AttributeError:
+        logging.basicConfig(level=logging.INFO)
+        logger.warning("Invalid logging level: %s", args.level)
+
+    diff_output = diff_collator.get_diff_output(args)
+    changes = diff_collator.parser(diff_output)
+    diff_collator.output_changes(changes)
+
+    changes = {k:set(v.keys()) for k, v in  changes.items()}
+    tests = select_tests(changes)
+    sys.exit(output_tests(tests))


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to