python 2 support
soumith committed Jun 8, 2016
1 parent 6954783 commit 5ee3358
Showing 15 changed files with 197 additions and 64 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -4,3 +4,4 @@ torch.egg-info/
*/**/__pycache__
torch/__init__.py
torch/csrc/generic/TensorMethods.cpp
*/**/*.pyc
20 changes: 18 additions & 2 deletions setup.py
@@ -1,5 +1,7 @@
from setuptools import setup, Extension
from os.path import expanduser
from tools.cwrap import cwrap
import platform

################################################################################
# Generate __init__.py from templates
@@ -49,6 +51,16 @@
################################################################################
# Declare the package
################################################################################
extra_link_args = []

# TODO: remove and properly submodule TH in the repo itself
th_path = expanduser("~/torch/install/")
th_header_path = th_path + "include"
th_lib_path = th_path + "lib"
if platform.system() == 'Darwin':
extra_link_args.append('-L' + th_lib_path)
extra_link_args.append('-Wl,-rpath,' + th_lib_path)

sources = [
"torch/csrc/Module.cpp",
"torch/csrc/Tensor.cpp",
@@ -59,9 +71,13 @@
libraries=['TH'],
sources=sources,
language='c++',
include_dirs=["torch/csrc"])
include_dirs=(["torch/csrc", th_header_path]),
extra_link_args = extra_link_args,
)



setup(name="torch", version="0.1",
ext_modules=[C],
packages=['torch'])
packages=['torch'],
)
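Note: the Darwin branch added above exists because macOS has no ldconfig-style global library path; `-L` resolves libTH at link time and `-Wl,-rpath` embeds the same directory so the dynamic loader can resolve it again at import time. A minimal standalone sketch of the pattern (the `demo` module name and install path are illustrative, not from this commit):

```python
import platform
from setuptools import Extension

lib_dir = "/usr/local/torch/install/lib"  # illustrative install prefix

extra_link_args = []
if platform.system() == "Darwin":
    # -L: find libTH.dylib at link time.
    # -Wl,-rpath: embed lib_dir so the loader finds it at import time too.
    extra_link_args.append("-L" + lib_dir)
    extra_link_args.append("-Wl,-rpath," + lib_dir)

ext = Extension("demo.C",
                sources=["demo.c"],
                libraries=["TH"],
                extra_link_args=extra_link_args)
```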
24 changes: 24 additions & 0 deletions test/smoke.py
@@ -0,0 +1,24 @@
import torch

a = torch.FloatTensor(4, 3)
b = torch.FloatTensor(3, 4)

a.add(b)

c = a.storage()

d = a.select(0, 1)

print(c)
print(a)
print(b)
print(d)


a.fill(0)

print(a[1])

print(a.ge(long(0)))
print(a.ge(0))
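Note: the `a.ge(long(0))` check above exercises Python 2's `long` builtin, which no longer exists on Python 3 (ints and longs were unified), so that line raises NameError there. A common compat shim looks roughly like this (illustrative, not part of this commit):

```python
import sys

if sys.version_info[0] >= 3:
    long = int  # Python 3 unified int/long; the alias keeps Py2-style calls working

print(long(0))                      # 0 on both 2.x and 3.x
print(isinstance(2 ** 100, long))   # True on both
```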

Empty file added tools/__init__.py
Empty file.
13 changes: 8 additions & 5 deletions tools/cwrap.py
@@ -151,7 +151,7 @@ def signature_hash(self):
'THStorage': Template('return THPStorage_(newObject)($expr)'),
'THLongStorage': Template('return THPLongStorage_newObject($expr)'),
'bool': Template('return PyBool_FromLong($expr)'),
'long': Template('return PyLong_FromLong($expr)'),
'long': Template('return PyInt_FromLong($expr)'),
'double': Template('return PyFloat_FromDouble($expr)'),
'self': Template('$expr; Py_INCREF(self); return (PyObject*)self'),
# TODO
@@ -397,16 +397,19 @@ def argfilter():
CONSTANT arguments are literals.
Repeated arguments do not need to be specified twice.
"""
provided = set()
# use class rather than nonlocal to maintain 2.7 compat
# see http://stackoverflow.com/questions/3190706/nonlocal-keyword-in-python-2-x
# TODO: check this works
class context:
provided = set()
def is_already_provided(arg):
nonlocal provided
ret = False
ret |= arg.name == 'self'
ret |= arg.name == '_res_new'
ret |= arg.type == 'CONSTANT'
ret |= arg.type == 'EXPRESSION'
ret |= arg.name in provided
provided.add(arg.name)
ret |= arg.name in context.provided
context.provided.add(arg.name)
return ret
return is_already_provided

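Note: the argfilter change above replaces a `nonlocal` rebinding with a class attribute because `nonlocal` is Python-3-only syntax (see the linked Stack Overflow question). A standalone sketch of the same trick, with illustrative names:

```python
def make_filter():
    # Python 3 can rebind a closed-over name with `nonlocal`; Python 2
    # cannot. Hanging the state off a class (or any mutable object)
    # sidesteps the rebinding and works on both versions.
    class context:
        seen = set()

    def is_dup(name):
        ret = name in context.seen
        context.seen.add(name)
        return ret

    return is_dup

is_dup = make_filter()
assert is_dup("self") is False
assert is_dup("self") is True
```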
2 changes: 1 addition & 1 deletion torch/Storage.py
@@ -7,5 +7,5 @@ def __repr__(self):
return str(self)

def __iter__(self):
return map(lambda i: self[i], range(self.size()))
return iter(map(lambda i: self[i], range(self.size())))

2 changes: 1 addition & 1 deletion torch/Tensor.py
@@ -42,4 +42,4 @@ def __str__(self):
return _printing.printTensor(self)

def __iter__(self):
return map(lambda i: self.select(0, i), range(self.size(0)))
return iter(map(lambda i: self.select(0, i), range(self.size(0))))
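Note: both `__iter__` fixes above (Storage.py and Tensor.py) wrap `map` in `iter()` because Python 2's `map` returns a plain list, which is iterable but not an iterator; returning it directly from `__iter__` raises `TypeError: iter() returned non-iterator`. On Python 3 `map` is already lazy, so the extra `iter()` is a no-op. For example:

```python
m = map(lambda i: i * 2, range(3))
# Python 2: m == [0, 2, 4], a list with no next()/__next__ method
# Python 3: m is a lazy map object, already an iterator

it = iter(m)        # a real iterator on both versions
assert next(it) == 0
assert list(it) == [2, 4]
```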
50 changes: 34 additions & 16 deletions torch/csrc/Module.cpp
@@ -5,7 +5,11 @@

#include "THP.h"

#if PY_MAJOR_VERSION == 2
#define ASSERT_TRUE(cmd) if (!(cmd)) {PyErr_SetString(PyExc_ImportError, "initialization error"); return;}
#else
#define ASSERT_TRUE(cmd) if (!(cmd)) return NULL
#endif

static PyObject* module;
static PyObject* tensor_classes;
@@ -34,21 +38,21 @@ static bool THPModule_loadClasses(PyObject *self)
PyObject *torch_module = PyImport_ImportModule("torch");
PyObject* module_dict = PyModule_GetDict(torch_module);

THPDoubleStorageClass = PyMapping_GetItemString(module_dict, "DoubleStorage");
THPFloatStorageClass = PyMapping_GetItemString(module_dict, "FloatStorage");
THPLongStorageClass = PyMapping_GetItemString(module_dict, "LongStorage");
THPIntStorageClass = PyMapping_GetItemString(module_dict, "IntStorage");
THPShortStorageClass = PyMapping_GetItemString(module_dict, "ShortStorage");
THPCharStorageClass = PyMapping_GetItemString(module_dict, "CharStorage");
THPByteStorageClass = PyMapping_GetItemString(module_dict, "ByteStorage");

THPDoubleTensorClass = PyMapping_GetItemString(module_dict, "DoubleTensor");
THPFloatTensorClass = PyMapping_GetItemString(module_dict, "FloatTensor");
THPLongTensorClass = PyMapping_GetItemString(module_dict, "LongTensor");
THPIntTensorClass = PyMapping_GetItemString(module_dict, "IntTensor");
THPShortTensorClass = PyMapping_GetItemString(module_dict, "ShortTensor");
THPCharTensorClass = PyMapping_GetItemString(module_dict, "CharTensor");
THPByteTensorClass = PyMapping_GetItemString(module_dict, "ByteTensor");
THPDoubleStorageClass = PyMapping_GetItemString(module_dict,(char*)"DoubleStorage");
THPFloatStorageClass = PyMapping_GetItemString(module_dict,(char*)"FloatStorage");
THPLongStorageClass = PyMapping_GetItemString(module_dict,(char*)"LongStorage");
THPIntStorageClass = PyMapping_GetItemString(module_dict,(char*)"IntStorage");
THPShortStorageClass = PyMapping_GetItemString(module_dict,(char*)"ShortStorage");
THPCharStorageClass = PyMapping_GetItemString(module_dict,(char*)"CharStorage");
THPByteStorageClass = PyMapping_GetItemString(module_dict,(char*)"ByteStorage");

THPDoubleTensorClass = PyMapping_GetItemString(module_dict,(char*)"DoubleTensor");
THPFloatTensorClass = PyMapping_GetItemString(module_dict,(char*)"FloatTensor");
THPLongTensorClass = PyMapping_GetItemString(module_dict,(char*)"LongTensor");
THPIntTensorClass = PyMapping_GetItemString(module_dict,(char*)"IntTensor");
THPShortTensorClass = PyMapping_GetItemString(module_dict,(char*)"ShortTensor");
THPCharTensorClass = PyMapping_GetItemString(module_dict,(char*)"CharTensor");
THPByteTensorClass = PyMapping_GetItemString(module_dict,(char*)"ByteTensor");
PySet_Add(tensor_classes, THPDoubleTensorClass);
PySet_Add(tensor_classes, THPFloatTensorClass);
PySet_Add(tensor_classes, THPLongTensorClass);
@@ -314,13 +318,15 @@ static PyMethodDef TorchMethods[] = {
{NULL, NULL, 0, NULL}
};

#if PY_MAJOR_VERSION != 2
static struct PyModuleDef torchmodule = {
PyModuleDef_HEAD_INIT,
"torch.C",
NULL,
-1,
TorchMethods
};
#endif

static void errorHandler(const char *msg, void *data)
{
@@ -338,10 +344,17 @@ static void updateErrorHandlers()
THSetArgErrorHandler(errorHandlerArg, NULL);
}

#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC initC()
#else
PyMODINIT_FUNC PyInit_C()
#endif
{
#if PY_MAJOR_VERSION == 2
ASSERT_TRUE(module = Py_InitModule("torch.C", TorchMethods));
#else
ASSERT_TRUE(module = PyModule_Create(&torchmodule));

#endif
ASSERT_TRUE(tensor_classes = PySet_New(NULL));
ASSERT_TRUE(PyObject_SetAttrString(module, "_tensorclasses", tensor_classes) == 0);

@@ -363,5 +376,10 @@ PyMODINIT_FUNC PyInit_C()

updateErrorHandlers();

#if PY_MAJOR_VERSION == 2
#else
return module;
#endif
}

#undef ASSERT_TRUE
9 changes: 9 additions & 0 deletions torch/csrc/THP.h
@@ -1,6 +1,15 @@
#include <stdbool.h>
#include <TH/TH.h>

// Back-compatibility macros, Thanks to http://cx-oracle.sourceforge.net/
// define PyInt_* macros for Python 3.x
#ifndef PyInt_Check
#define PyInt_Check PyLong_Check
#define PyInt_FromLong PyLong_FromLong
#define PyInt_AsLong PyLong_AsLong
#define PyInt_Type PyLong_Type
#endif

#include "Exceptions.h"
#include "utils.h"

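Note: the `PyInt_*` aliases above cover the C-API side of the int/long unification: Python 2 has two integer types (`PyInt_*` for machine-word ints, `PyLong_*` for arbitrary precision), while Python 3 keeps only `PyLong_*`. The same split is visible from pure Python; a short illustrative snippet (not part of the commit):

```python
import sys

small, big = 1, 2 ** 100

if sys.version_info[0] == 2:
    # Two types on 2.x: PyInt_* backs `int`, PyLong_* backs `long`
    assert type(small).__name__ == "int"
    assert type(big).__name__ == "long"
else:
    # One type on 3.x, backed by PyLong_* -- hence the aliases in THP.h
    assert type(small) is int and type(big) is int
```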
34 changes: 23 additions & 11 deletions torch/csrc/generic/Storage.cpp
@@ -9,6 +9,7 @@ PyObject * THPStorage_(newObject)(THStorage *ptr)
// TODO: error checking
PyObject *args = PyTuple_New(0);
PyObject *kwargs = Py_BuildValue("{s:N}", "cdata", PyLong_FromVoidPtr(ptr));

PyObject *instance = PyObject_Call(THPStorageClass, args, kwargs);
Py_DECREF(args);
Py_DECREF(kwargs);
@@ -30,17 +31,17 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
{
HANDLE_TH_ERRORS
static const char *keywords[] = {"cdata", NULL};
PyObject *number_arg = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!", (char **)keywords, &PyLong_Type, &number_arg))
void* number_arg = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O&", (char **)keywords,
THPUtils_getLong, &number_arg))
return NULL;

THPStorage *self = (THPStorage *)type->tp_alloc(type, 0);
if (self != NULL) {
if (kwargs) {
self->cdata = (THStorage*)PyLong_AsVoidPtr(number_arg);
self->cdata = (THStorage*)number_arg;
THStorage_(retain)(self->cdata);
} else if (/* !kwargs && */ number_arg) {
self->cdata = THStorage_(newWithSize)(PyLong_AsLong(number_arg));
self->cdata = THStorage_(newWithSize)((long) number_arg);
} else {
self->cdata = THStorage_(new)();
}
@@ -66,8 +67,9 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
{
HANDLE_TH_ERRORS
/* Integer index */
if (PyLong_Check(index)) {
long nindex = PyLong_AsLong(index);
long nindex;
if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1 ) {
if (nindex < 0)
nindex += THStorage_(size)(self->cdata);
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
@@ -89,7 +91,11 @@ static PyObject * THPStorage_(get)(THPStorage *self, PyObject *index)
THStorage *new_storage = THStorage_(newWithData)(new_data, slicelength);
return THPStorage_(newObject)(new_storage);
}
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported");
char err_string[512];
snprintf (err_string, 512,
"%s %s", "Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return NULL;
END_HANDLE_TH_ERRORS
}
@@ -101,8 +107,10 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
if (!THPUtils_(parseReal)(value, &rvalue))
return -1;

if (PyLong_Check(index)) {
THStorage_(set)(self->cdata, PyLong_AsSize_t(index), rvalue);
long nindex;
if ((PyLong_Check(index) || PyInt_Check(index))
&& THPUtils_getLong(index, &nindex) == 1) {
THStorage_(set)(self->cdata, nindex, rvalue);
return 0;
} else if (PySlice_Check(index)) {
Py_ssize_t start, stop, len;
@@ -114,7 +122,11 @@ static int THPStorage_(set)(THPStorage *self, PyObject *index, PyObject *value)
THStorage_(set)(self->cdata, start, rvalue);
return 0;
}
PyErr_SetString(PyExc_RuntimeError, "Only indexing with integers and slices supported at the moment");
char err_string[512];
snprintf (err_string, 512, "%s %s",
"Only indexing with integers and slices supported, but got type: ",
index->ob_type->tp_name);
PyErr_SetString(PyExc_RuntimeError, err_string);
return -1;
END_HANDLE_TH_ERRORS_RET(-1)
}
2 changes: 1 addition & 1 deletion torch/csrc/generic/StorageMethods.cpp
@@ -40,7 +40,7 @@ static PyObject * THPStorage_(resize)(THPStorage *self, PyObject *number_arg)
HANDLE_TH_ERRORS
if (!PyLong_Check(number_arg))
return NULL;
size_t newsize = PyLong_AsSize_t(number_arg);
long newsize = PyLong_AsLong(number_arg);
if (PyErr_Occurred())
return NULL;
THStorage_(resize)(self->cdata, newsize);
[diff for the remaining 4 changed files not loaded]
