diff --git a/nvcc4jupyter/parsers.py b/nvcc4jupyter/parsers.py index e94afce..76a0392 100644 --- a/nvcc4jupyter/parsers.py +++ b/nvcc4jupyter/parsers.py @@ -19,6 +19,7 @@ def get_parser_cuda() -> argparse.ArgumentParser: parser.add_argument("-t", "--timeit", action="store_true") parser.add_argument("-p", "--profile", action="store_true") parser.add_argument("-a", "--profiler-args", type=str, default="") + parser.add_argument("-c", "--compiler-args", type=str, default="") return parser diff --git a/nvcc4jupyter/plugin.py b/nvcc4jupyter/plugin.py index 269a2cc..413f836 100644 --- a/nvcc4jupyter/plugin.py +++ b/nvcc4jupyter/plugin.py @@ -87,7 +87,10 @@ class NVCCPlugin(Magics): shutil.rmtree(group_dirpath) def _compile( - self, group_name: str, executable_fname: str = DEFAULT_EXEC_FNAME + self, + group_name: str, + executable_fname: str = DEFAULT_EXEC_FNAME, + compiler_args: str = "", ) -> str: """ Compiles all source files in a given group together with all source @@ -97,6 +100,7 @@ class NVCCPlugin(Magics): group_name: The name of the source file group to be compiled. executable_fname: The output executable file name. Defaults to "cuda_exec.out". + compiler_args: The optional "nvcc" compiler arguments. Raises: RuntimeError: If the group does not exist or if does not have any @@ -121,18 +125,12 @@ class NVCCPlugin(Magics): executable_fpath = os.path.join(group_dirpath, executable_fname) - args = [ - "nvcc", - "-I" + shared_dirpath + "," + group_dirpath, - ] + args = ["nvcc"] + args.extend(compiler_args.split()) + args.append("-I" + shared_dirpath + "," + group_dirpath) args.extend(source_files) - args.extend( - [ - "-o", - executable_fpath, - "-Wno-deprecated-gpu-targets", - ] - ) + args.extend(["-o", executable_fpath, "-Wno-deprecated-gpu-targets"]) + subprocess.check_output(args, stderr=subprocess.STDOUT) return executable_fpath @@ -188,7 +186,10 @@ class NVCCPlugin(Magics): self, group_name: str, args: argparse.Namespace ) -> str: try: - exec_fpath = self._compile(group_name) + exec_fpath = self._compile( + group_name=group_name, + compiler_args=args.compiler_args, + ) output = self._run( exec_fpath=exec_fpath, timeit=args.timeit,