This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 751f83b  Auto-discover C/C++ compiler instead of hardcoding g++ (#10007)
751f83b is described below

commit 751f83b56576b04cc4c3957311f693ef2c2119f2
Author: Krzysztof Parzyszek <[email protected]>
AuthorDate: Fri Jan 21 15:25:08 2022 -0600

    Auto-discover C/C++ compiler instead of hardcoding g++ (#10007)
    
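    Some platforms (e.g. FreeBSD) use clang as the default OS compiler
    and do not provide g++. Instead of hardcoding g++, take the compiler
    from the CXX or CC environment variable, or otherwise search PATH for
    g++, gcc, clang++, clang, c++ or cc.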
---
 python/tvm/contrib/cc.py     | 57 +++++++++++++++++++++++++++++++++-----------
 python/tvm/rpc/server.py     |  8 ++-----
 python/tvm/runtime/module.py |  7 ++----
 3 files changed, 47 insertions(+), 25 deletions(-)

diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py
index 64cbbd2..867cbd6 100644
--- a/python/tvm/contrib/cc.py
+++ b/python/tvm/contrib/cc.py
@@ -23,7 +23,40 @@ import subprocess
 from .._ffi.base import py_str
 
 
-def create_shared(output, objects, options=None, cc="g++"):
+def _is_linux_like():
+    return (
+        sys.platform == "darwin"
+        or sys.platform.startswith("linux")
+        or sys.platform.startswith("freebsd")
+    )
+
+
+def get_cc():
+    """Return the path to the default C/C++ compiler.
+
+    Returns
+    -------
+    out: Optional[str]
+        The path to the default C/C++ compiler, or None if none was found.
+    """
+
+    if not _is_linux_like():
+        return None
+
+    env_cxx = os.environ.get("CXX") or os.environ.get("CC")
+    if env_cxx:
+        return env_cxx
+    cc_names = ["g++", "gcc", "clang++", "clang", "c++", "cc"]
+    dirs_in_path = os.get_exec_path()
+    for cc in cc_names:
+        for d in dirs_in_path:
+            cc_path = os.path.join(d, cc)
+            if os.path.isfile(cc_path) and os.access(cc_path, os.X_OK):
+                return cc_path
+    return None
+
+
+def create_shared(output, objects, options=None, cc=None):
     """Create shared library.
 
     Parameters
@@ -40,11 +73,9 @@ def create_shared(output, objects, options=None, cc="g++"):
     cc : Optional[str]
         The compiler command.
     """
-    if (
-        sys.platform == "darwin"
-        or sys.platform.startswith("linux")
-        or sys.platform.startswith("freebsd")
-    ):
+    cc = cc or get_cc()
+
+    if _is_linux_like():
         _linux_compile(output, objects, options, cc, compile_shared=True)
     elif sys.platform == "win32":
         _windows_compile(output, objects, options)
@@ -52,7 +83,7 @@ def create_shared(output, objects, options=None, cc="g++"):
         raise ValueError("Unsupported platform")
 
 
-def create_executable(output, objects, options=None, cc="g++"):
+def create_executable(output, objects, options=None, cc=None):
     """Create executable binary.
 
     Parameters
@@ -69,7 +100,9 @@ def create_executable(output, objects, options=None, cc="g++"):
     cc : Optional[str]
         The compiler command.
     """
-    if sys.platform == "darwin" or sys.platform.startswith("linux"):
+    cc = cc or get_cc()
+
+    if _is_linux_like():
         _linux_compile(output, objects, options, cc)
     elif sys.platform == "win32":
         _windows_compile(output, objects, options)
@@ -109,11 +142,7 @@ def get_target_by_dump_machine(compiler):
 
 # assign so as default output format
 create_shared.output_format = "so" if sys.platform != "win32" else "dll"
-create_shared.get_target_triple = get_target_by_dump_machine(
-    os.environ.get(
-        "CXX", "g++" if sys.platform == "darwin" or sys.platform.startswith("linux") else None
-    )
-)
+create_shared.get_target_triple = get_target_by_dump_machine(os.environ.get("CXX", get_cc()))
 
 
 def cross_compiler(
@@ -190,7 +219,7 @@ def cross_compiler(
     return _fcompile
 
 
-def _linux_compile(output, objects, options, compile_cmd="g++", compile_shared=False):
+def _linux_compile(output, objects, options, compile_cmd, compile_shared=False):
     cmd = [compile_cmd]
     if compile_cmd != "nvcc":
         if compile_shared or output.endswith(".so") or output.endswith(".dylib"):
diff --git a/python/tvm/rpc/server.py b/python/tvm/rpc/server.py
index 52a7a89..aa8e042 100644
--- a/python/tvm/rpc/server.py
+++ b/python/tvm/rpc/server.py
@@ -25,7 +25,6 @@ Server is TCP based with the following protocol:
    - {server|client}:device-type[:random-key] [-timeout=timeout]
 """
 # pylint: disable=invalid-name
-import os
 import ctypes
 import socket
 import select
@@ -75,9 +74,6 @@ def _server_env(load_library, work_path=None):
     @tvm._ffi.register_func("tvm.rpc.server.download_linked_module", override=True)
     def download_linked_module(file_name):
         """Load module from remote side."""
-        # c++ compiler/linker
-        cc = os.environ.get("CXX", "g++")
-
         # pylint: disable=import-outside-toplevel
         path = temp.relpath(file_name)
 
@@ -85,7 +81,7 @@ def _server_env(load_library, work_path=None):
             # Extra dependencies during runtime.
             from tvm.contrib import cc as _cc
 
-            _cc.create_shared(path + ".so", path, cc=cc)
+            _cc.create_shared(path + ".so", path)
             path += ".so"
         elif path.endswith(".tar"):
             # Extra dependencies during runtime.
@@ -94,7 +90,7 @@ def _server_env(load_library, work_path=None):
             tar_temp = utils.tempdir(custom_path=path.replace(".tar", ""))
             _tar.untar(path, tar_temp.temp_dir)
             files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
-            _cc.create_shared(path + ".so", files, cc=cc)
+            _cc.create_shared(path + ".so", files)
             path += ".so"
         elif path.endswith(".dylib") or path.endswith(".so"):
             pass
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 37bab4a..da7c52a 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -529,16 +529,13 @@ def load_module(path, fmt=""):
     else:
         raise ValueError("cannot find file %s" % path)
 
-    # c++ compiler/linker
-    cc = os.environ.get("CXX", "g++")
-
     # High level handling for .o and .tar file.
     # We support this to be consistent with RPC module load.
     if path.endswith(".o"):
         # Extra dependencies during runtime.
         from tvm.contrib import cc as _cc
 
-        _cc.create_shared(path + ".so", path, cc=cc)
+        _cc.create_shared(path + ".so", path)
         path += ".so"
     elif path.endswith(".tar"):
         # Extra dependencies during runtime.
@@ -547,7 +544,7 @@ def load_module(path, fmt=""):
         tar_temp = _utils.tempdir(custom_path=path.replace(".tar", ""))
         _tar.untar(path, tar_temp.temp_dir)
         files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
-        _cc.create_shared(path + ".so", files, cc=cc)
+        _cc.create_shared(path + ".so", files)
         path += ".so"
     # Redirect to the load API
     return _ffi_api.ModuleLoadFromFile(path, fmt)

Reply via email to