Improve Windows Compatibility (for csrc/scripts) (pytorch#2941)
peterjc123 authored and apaszke committed Nov 8, 2017
1 parent 1d57a2d commit aa91193
Showing 60 changed files with 842 additions and 379 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,6 +6,8 @@ torch/version.py
 torch/csrc/generic/TensorMethods.cpp
 torch/lib/*.so*
 torch/lib/*.a*
+torch/lib/*.dll*
+torch/lib/*.lib
 torch/lib/*.dylib*
 torch/lib/*.h
 torch/lib/build
2 changes: 1 addition & 1 deletion cmake/FindCUDA/FindCUDA.cmake
@@ -488,7 +488,7 @@ option(CUDA_HOST_COMPILATION_CPP "Generated file extension" ON)
 set(CUDA_NVCC_FLAGS "" CACHE STRING "Semi-colon delimit multiple arguments.")
 
 if(CMAKE_GENERATOR MATCHES "Visual Studio")
-  set(CUDA_HOST_COMPILER "$(VCInstallDir)bin" CACHE FILEPATH "Host side compiler used by NVCC")
+  set(CUDA_HOST_COMPILER "${CMAKE_C_COMPILER}" CACHE FILEPATH "Host side compiler used by NVCC")
 else()
   if(APPLE
     AND "${CMAKE_C_COMPILER_ID}" MATCHES "Clang"
143 changes: 109 additions & 34 deletions setup.py
@@ -19,10 +19,16 @@
     NCCL_INCLUDE_DIR, NCCL_ROOT_DIR, NCCL_SYSTEM_LIB
 from tools.setup_helpers.nnpack import WITH_NNPACK, NNPACK_LIB_PATHS, \
     NNPACK_INCLUDE_DIRS
+from tools.setup_helpers.nvtoolext import NVTOOLEXT_HOME
 from tools.setup_helpers.split_types import split_types
 
 DEBUG = check_env_flag('DEBUG')
-WITH_DISTRIBUTED = not check_env_flag('NO_DISTRIBUTED')
+
+IS_WINDOWS = (platform.system() == 'Windows')
+IS_DARWIN = (platform.system() == 'Darwin')
+IS_LINUX = (platform.system() == 'Linux')
+
+WITH_DISTRIBUTED = not check_env_flag('NO_DISTRIBUTED') and not IS_WINDOWS
 WITH_DISTRIBUTED_MW = WITH_DISTRIBUTED and check_env_flag('WITH_DISTRIBUTED_MW')
 
@@ -82,14 +88,17 @@ def patched_link(self, *args, **kwargs):
 
 dep_libs = [
     'nccl', 'ATen',
-    'libshm', 'gloo', 'THD', 'nanopb',
+    'libshm', 'libshm_windows', 'gloo', 'THD', 'nanopb',
 ]
 
 
 def build_libs(libs):
     for lib in libs:
         assert lib in dep_libs, 'invalid lib: {}'.format(lib)
-    build_libs_cmd = ['bash', 'torch/lib/build_libs.sh']
+    if IS_WINDOWS:
+        build_libs_cmd = ['torch\\lib\\build_libs.bat']
+    else:
+        build_libs_cmd = ['bash', 'torch/lib/build_libs.sh']
     my_env = os.environ.copy()
     my_env["PYTORCH_PYTHON"] = sys.executable
     if WITH_SYSTEM_NCCL:
@@ -119,7 +128,11 @@ def run(self):
         libs = []
         if WITH_NCCL and not WITH_SYSTEM_NCCL:
             libs += ['nccl']
-        libs += ['ATen', 'libshm', 'nanopb']
+        libs += ['ATen', 'nanopb']
+        if IS_WINDOWS:
+            libs += ['libshm_windows']
+        else:
+            libs += ['libshm']
         if WITH_DISTRIBUTED:
             if sys.platform.startswith('linux'):
                 libs += ['gloo']
@@ -201,6 +214,7 @@ def monkey_patch_THD_link_flags():
 class build_ext(setuptools.command.build_ext.build_ext):
 
     def run(self):
+
         # Print build options
         if WITH_NUMPY:
             print('-- Building with NumPy bindings')
@@ -273,6 +287,25 @@ def run(self):
                 'torch/lib/tmp_install/share/ATen/Declarations.yaml',
                 jit_gen_dir)
 
+        if IS_WINDOWS:
+            build_temp = self.build_temp
+            build_dir = 'torch/csrc'
+
+            ext_filename = self.get_ext_filename('_C')
+            lib_filename = '.'.join(ext_filename.split('.')[:-1]) + '.lib'
+
+            _C_LIB = os.path.join(build_temp, build_dir, lib_filename).replace('\\', '/')
+
+            THNN.extra_link_args += [_C_LIB]
+            if WITH_CUDA:
+                THCUNN.extra_link_args += [_C_LIB]
+            else:
+                # AutoGPU still has to be compiled into .obj files for the
+                # exported class in CPU-only builds, but a header cannot be
+                # built on its own, so copy it somewhere it can be compiled
+                # as a source file
+                if os.path.exists("torch/csrc/generated/AutoGPU_cpu_win.cpp"):
+                    os.remove("torch/csrc/generated/AutoGPU_cpu_win.cpp")
+                shutil.copyfile("torch/csrc/cuda/AutoGPU.h", "torch/csrc/generated/AutoGPU_cpu_win.cpp")
+
         # It's an old-style class in Python 2.7...
         setuptools.command.build_ext.build_ext.run(self)
 
@@ -315,14 +348,23 @@ def run(self):
 include_dirs = []
 library_dirs = []
 extra_link_args = []
-extra_compile_args = ['-std=c++11', '-Wno-write-strings',
-                      # Python 2.6 requires -fno-strict-aliasing, see
-                      # http://legacy.python.org/dev/peps/pep-3123/
-                      '-fno-strict-aliasing',
-                      # Clang has an unfixed bug leading to spurious missing
-                      # braces warnings, see
-                      # https://bugs.llvm.org/show_bug.cgi?id=21629
-                      '-Wno-missing-braces']
+
+if IS_WINDOWS:
+    extra_compile_args = ['/Z7', '/EHa', '/DNOMINMAX'
+                          # /Z7 puts symbolic debugging information into the .obj files
+                          # /EHa enables C++ exception handling that also catches
+                          # asynchronous structured exceptions (SEH)
+                          # /DNOMINMAX keeps windows.h from defining min/max macros
+                          ]
+else:
+    extra_compile_args = ['-std=c++11', '-Wno-write-strings',
+                          # Python 2.6 requires -fno-strict-aliasing, see
+                          # http://legacy.python.org/dev/peps/pep-3123/
+                          '-fno-strict-aliasing',
+                          # Clang has an unfixed bug leading to spurious missing
+                          # braces warnings, see
+                          # https://bugs.llvm.org/show_bug.cgi?id=21629
+                          '-Wno-missing-braces']
 
 cwd = os.path.dirname(os.path.abspath(__file__))
 lib_path = os.path.join(cwd, "torch", "lib")
@@ -355,13 +397,18 @@ def check_file(f):
 ATEN_LIB = os.path.join(lib_path, 'libATen.so.1')
 THD_LIB = os.path.join(lib_path, 'libTHD.a')
 NCCL_LIB = os.path.join(lib_path, 'libnccl.so.1')
-if platform.system() == 'Darwin':
-    ATEN_LIB = os.path.join(lib_path, 'libATen.1.dylib')
-    NCCL_LIB = os.path.join(lib_path, 'libnccl.1.dylib')
 
 # static library only
 NANOPB_STATIC_LIB = os.path.join(lib_path, 'libprotobuf-nanopb.a')
 
+if IS_DARWIN:
+    ATEN_LIB = os.path.join(lib_path, 'libATen.1.dylib')
+    NCCL_LIB = os.path.join(lib_path, 'libnccl.1.dylib')
+
+if IS_WINDOWS:
+    ATEN_LIB = os.path.join(lib_path, 'ATen.lib')
+    NANOPB_STATIC_LIB = os.path.join(lib_path, 'protobuf-nanopb.lib')
+
 main_compile_args = ['-D_THP_CORE']
 main_libraries = ['shm']
 main_link_args = [ATEN_LIB, NANOPB_STATIC_LIB]
@@ -457,20 +504,41 @@ def check_file(f):
     include_dirs += [tmp_install_path + "/include/THD"]
     main_link_args += [THD_LIB]
 
+if IS_WINDOWS and not WITH_CUDA:
+    main_sources += ["torch/csrc/generated/AutoGPU_cpu_win.cpp"]
+
 if WITH_CUDA:
-    cuda_lib_dirs = ['lib64', 'lib']
+    nvtoolext_lib_name = None
+    if IS_WINDOWS:
+        cuda_lib_path = CUDA_HOME + '/lib/x64/'
+        nvtoolext_lib_path = NVTOOLEXT_HOME + '/lib/x64/'
+        nvtoolext_include_path = os.path.join(NVTOOLEXT_HOME, 'include')
+
+        library_dirs.append(nvtoolext_lib_path)
+        include_dirs.append(nvtoolext_include_path)
+
+        nvtoolext_lib_name = 'nvToolsExt64_1'
+
+        # MSVC cannot resolve symbols at runtime, so `nvrtc` and `cuda`
+        # must be linked explicitly
+        main_libraries += ['nvrtc', 'cuda']
+    else:
+        cuda_lib_dirs = ['lib64', 'lib']
+
+        for lib_dir in cuda_lib_dirs:
+            cuda_lib_path = os.path.join(CUDA_HOME, lib_dir)
+            if os.path.exists(cuda_lib_path):
+                break
+        extra_link_args.append('-Wl,-rpath,' + cuda_lib_path)
+
+        nvtoolext_lib_name = 'nvToolsExt'
+
+    library_dirs.append(cuda_lib_path)
     cuda_include_path = os.path.join(CUDA_HOME, 'include')
-    for lib_dir in cuda_lib_dirs:
-        cuda_lib_path = os.path.join(CUDA_HOME, lib_dir)
-        if os.path.exists(cuda_lib_path):
-            break
     include_dirs.append(cuda_include_path)
     include_dirs.append(tmp_install_path + "/include/THCUNN")
-    library_dirs.append(cuda_lib_path)
-    extra_link_args.append('-Wl,-rpath,' + cuda_lib_path)
     extra_compile_args += ['-DWITH_CUDA']
     extra_compile_args += ['-DCUDA_LIB_PATH=' + cuda_lib_path]
-    main_libraries += ['cudart', 'nvToolsExt']
+    main_libraries += ['cudart', nvtoolext_lib_name]
     main_sources += [
         "torch/csrc/cuda/Module.cpp",
         "torch/csrc/cuda/Storage.cpp",
@@ -498,7 +566,8 @@ def check_file(f):
     library_dirs.append(CUDNN_LIB_DIR)
     # NOTE: these are at the front, in case there's another cuDNN in CUDA path
     include_dirs.insert(0, CUDNN_INCLUDE_DIR)
-    extra_link_args.insert(0, '-Wl,-rpath,' + CUDNN_LIB_DIR)
+    if not IS_WINDOWS:
+        extra_link_args.insert(0, '-Wl,-rpath,' + CUDNN_LIB_DIR)
     main_sources += [
         "torch/csrc/cudnn/BatchNorm.cpp",
         "torch/csrc/cudnn/Conv.cpp",
@@ -519,8 +588,11 @@ def check_file(f):
     extra_compile_args += ['-DWITH_NNPACK']
 
 if DEBUG:
-    extra_compile_args += ['-O0', '-g']
-    extra_link_args += ['-O0', '-g']
+    if IS_WINDOWS:
+        extra_link_args.append('/DEBUG:FULL')
+    else:
+        extra_compile_args += ['-O0', '-g']
+        extra_link_args += ['-O0', '-g']
 
 if os.getenv('PYTORCH_BINARY_BUILD') and platform.system() == 'Linux':
     print('PYTORCH_BINARY_BUILD found. Static linking libstdc++ on Linux')
@@ ... @@
 
 
 def make_relative_rpath(path):
-    if platform.system() == 'Darwin':
+    if IS_DARWIN:
         return '-Wl,-rpath,@loader_path/' + path
+    elif IS_WINDOWS:
+        return ''
     else:
         return '-Wl,-rpath,$ORIGIN/' + path

@@ ... @@
 )
 extensions.append(C)
 
-DL = Extension("torch._dl",
-               sources=["torch/csrc/dl.c"],
-               language='c',
-               )
-extensions.append(DL)
+if not IS_WINDOWS:
+    DL = Extension("torch._dl",
+                   sources=["torch/csrc/dl.c"],
+                   language='c',
+                   )
+    extensions.append(DL)
 
 THNN = Extension("torch._thnn._THNN",
                  sources=['torch/csrc/nn/THNN.cpp'],
@@ ... @@
 
 if WITH_CUDA:
     thnvrtc_link_flags = extra_link_args + [make_relative_rpath('lib')]
-    if platform.system() == 'Linux':
+    if IS_LINUX:
         thnvrtc_link_flags = ['-Wl,--no-as-needed'] + thnvrtc_link_flags
     THNVRTC = Extension("torch._nvrtc",
                         libraries=['nvrtc', 'cuda'],
@@ -634,7 +709,7 @@ def make_relative_rpath(path):
     cmdclass=cmdclass,
     packages=packages,
     package_data={'torch': [
-        'lib/*.so*', 'lib/*.dylib*',
+        'lib/*.so*', 'lib/*.dylib*', 'lib/*.dll',
         'lib/torch_shm_manager',
         'lib/*.h',
         'lib/include/TH/*.h', 'lib/include/TH/generic/*.h',
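
Note: on Windows, an extension DLL that exports symbols also produces an import library (.lib), and THNN/THCUNN link against _C through that file, since PE has no ELF-style resolution of undefined symbols at load time. A standalone sketch of the path computation used in build_ext above (the function name and sample arguments are illustrative):

    import os

    def import_lib_for(ext_filename, build_temp, build_dir='torch/csrc'):
        # e.g. '_C.cp36-win_amd64.pyd' -> '<build_temp>/torch/csrc/_C.cp36-win_amd64.lib'
        lib_filename = '.'.join(ext_filename.split('.')[:-1]) + '.lib'
        return os.path.join(build_temp, build_dir, lib_filename).replace('\\', '/')

    print(import_lib_for('_C.cp36-win_amd64.pyd', 'build/temp.win-amd64-3.6'))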
7 changes: 7 additions & 0 deletions test/common.py
@@ -1,5 +1,6 @@
 import sys
 import os
+import re
 import argparse
 import unittest
 import warnings
@@ -320,6 +321,12 @@ def accept_output(update_type):
                 ("I got this output for {}:\n\n{}\n\n"
                  "No expect file exists; to accept the current output, run:\n"
                  "python {} {} --accept").format(munged_id, s, __main__.__file__, munged_id))
 
+        # a hack for the JIT tests: the name inside CppOp[...] is
+        # platform-dependent, so blank it out before comparing
+        if sys.platform == 'win32':
+            expected = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', expected)
+            s = re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', s)
+
         if ACCEPT:
             if expected != s:
                 return accept_output("updated output")
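
Note: the win32 branch above exists because the mangled operator name inside CppOp[...] differs across platforms, so the expect-file comparison blanks it out. A minimal illustration (the sample trace line is made up):

    import re

    line = "%3 = CppOp[N5torch8autograd3AddE](%1, %2)"
    print(re.sub(r'CppOp\[(.+?)\]', 'CppOp[]', line))
    # prints: %3 = CppOp[](%1, %2)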
40 changes: 40 additions & 0 deletions test/run_test.bat
@@ -0,0 +1,40 @@
+@echo off
+
+set PYCMD=python
+
+echo Running JIT tests
+%PYCMD% test_jit.py
+
+echo Running torch tests
+%PYCMD% test_torch.py
+
+echo Running autograd tests
+%PYCMD% test_autograd.py
+%PYCMD% test_potrf.py
+
+echo Running sparse tests
+%PYCMD% test_sparse.py
+
+echo Running nn tests
+%PYCMD% test_nn.py
+
+echo Running legacy nn tests
+%PYCMD% test_legacy_nn.py
+
+echo Running optim tests
+%PYCMD% test_optim.py
+
+echo Running multiprocessing tests
+%PYCMD% test_multiprocessing.py
+
+echo Running util tests
+%PYCMD% test_utils.py
+
+echo Running dataloader tests
+%PYCMD% test_dataloader.py
+
+echo Running cuda tests
+%PYCMD% test_cuda.py
+
+echo Running NCCL tests
+%PYCMD% test_nccl.py
19 changes: 10 additions & 9 deletions test/test_multiprocessing.py
@@ -22,6 +22,16 @@
 TEST_MULTIGPU = TEST_CUDA_IPC and torch.cuda.device_count() > 1
 
 
+class SubProcess(mp.Process):
+    def __init__(self, tensor):
+        super(SubProcess, self).__init__()
+        self.tensor = tensor
+        self.daemon = True
+
+    def run(self):
+        self.tensor.add_(3)
+
+
 def simple_fill(queue, event):
     data = queue.get()
     data[0][:] = 4
@@ -269,15 +279,6 @@ def queue_put():
         queue_put()
 
     def test_inherit_tensor(self):
-        class SubProcess(mp.Process):
-            def __init__(self, tensor):
-                super(SubProcess, self).__init__()
-                self.tensor = tensor
-                self.daemon = True
-
-            def run(self):
-                self.tensor.add_(3)
-
         t = torch.zeros(5, 5)
         p = SubProcess(t.share_memory_())
         p.start()
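
Note: SubProcess moves to module scope because Windows has no fork(); multiprocessing starts children with the 'spawn' method, which pickles the Process object in the parent and unpickles it in a fresh interpreter, and a class defined inside a test method cannot be pickled. A minimal sketch of the constraint (Worker is a hypothetical name):

    import multiprocessing as mp

    class Worker(mp.Process):  # must live at module level so it can be pickled
        def run(self):
            print('hello from the child process')

    if __name__ == '__main__':  # required under the 'spawn' start method
        p = Worker()
        p.start()
        p.join()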
3 changes: 2 additions & 1 deletion test/test_utils.py
@@ -342,7 +342,8 @@ def init(cls):
             path = download_file('https://download.pytorch.org/test_data/legacy_modules.t7')
         except unittest.SkipTest:
             return
-        tests = load_lua(path)
+        long_size = 8 if sys.platform == 'win32' else None
+        tests = load_lua(path, long_size=long_size)
         for name, test in tests['modules'].items():
             test_name = 'test_' + name.replace('nn.', '')
             setattr(cls, test_name, cls._module_test(name, test))
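
Note: the .t7 test file stores 8-byte longs, but a C long is only 4 bytes on Windows even in 64-bit builds, so the reader is told the on-disk width explicitly; passing None keeps the platform default elsewhere. Usage mirrors the test above (the local path is illustrative):

    import sys
    from torch.utils.serialization import load_lua

    long_size = 8 if sys.platform == 'win32' else None
    tests = load_lua('legacy_modules.t7', long_size=long_size)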
5 changes: 5 additions & 0 deletions tools/autograd/templates/Functions.cpp
@@ -1,6 +1,11 @@
 #include "Functions.h"
 #include <ATen/WrapDimUtils.h>
 
+// for MSVC: expose math constants like M_PI and the C++ operator keywords
+#ifdef _MSC_VER
+#define _USE_MATH_DEFINES
+#include <ciso646>
+#endif
 #include <math.h>
 
 // ${generated_comment}
(Diffs for the remaining changed files are not shown.)