Specify integer size for tensor_layout
jwallwork23 committed Nov 25, 2024
1 parent 60720ea commit d8613fe
Showing 8 changed files with 51 additions and 46 deletions.
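All eight files make the same fix: tensor_layout arrays passed to the FTorch tensor routines are now declared with an explicit integer(c_int64_t) kind rather than the default integer, presumably so the Fortran declarations match the 64-bit integer expected on the C side of the binding. A minimal standalone sketch of the before/after pattern (the program name and print statement are illustrative, not part of the commit):

    program layout_pattern
      use, intrinsic :: iso_c_binding, only : c_int64_t
      implicit none

      ! Before: integer, parameter :: tensor_layout(1) = [1]
      ! After this commit, the kind is pinned to the 64-bit C integer:
      integer(c_int64_t), parameter :: tensor_layout(1) = [1]

      ! The array is then passed unchanged to the tensor constructors.
      print *, tensor_layout
    end program layout_pattern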
3 changes: 2 additions & 1 deletion examples/1_SimpleNet/simplenet_infer_fortran.f90
@@ -2,6 +2,7 @@ program inference

! Import precision info from iso
use, intrinsic :: iso_fortran_env, only : sp => real32
+use, intrinsic :: iso_c_binding, only : c_int64_t

! Import our library for interfacing with PyTorch
use ftorch, only : torch_model, torch_tensor, torch_kCPU, torch_delete, &
@@ -22,7 +23,7 @@ program inference
real(wp), dimension(5), target :: in_data
real(wp), dimension(5), target :: out_data
real(wp), dimension(5), target :: expected
-integer, parameter :: tensor_layout(1) = [1]
+integer(c_int64_t), parameter :: tensor_layout(1) = [1]

! Set up Torch data structures
! The net, a vector of input tensors (in this case we only have one), and the output tensor
5 changes: 3 additions & 2 deletions examples/2_ResNet18/resnet_infer_fortran.f90
@@ -1,6 +1,7 @@
program inference

use, intrinsic :: iso_fortran_env, only : sp => real32
+use, intrinsic :: iso_c_binding, only : c_int64_t

! Import our library for interfacing with PyTorch
use ftorch, only : torch_model, torch_tensor, torch_kCPU, torch_delete, &
@@ -32,10 +33,10 @@ subroutine main()

integer, parameter :: in_dims = 4
integer, parameter :: in_shape(in_dims) = [1, 3, 224, 224]
-integer, parameter :: in_layout(in_dims) = [1, 2, 3, 4]
+integer(c_int64_t), parameter :: in_layout(in_dims) = [1, 2, 3, 4]
integer, parameter :: out_dims = 2
integer, parameter :: out_shape(out_dims) = [1, 1000]
-integer, parameter :: out_layout(out_dims) = [1, 2]
+integer(c_int64_t), parameter :: out_layout(out_dims) = [1, 2]

! Path to input data
character(len=100) :: data_dir
3 changes: 2 additions & 1 deletion examples/3_MultiGPU/simplenet_infer_fortran.f90
@@ -2,6 +2,7 @@ program inference

! Import precision info from iso
use, intrinsic :: iso_fortran_env, only : sp => real32
+use, intrinsic :: iso_c_binding, only : c_int64_t

! Import our library for interfacing with PyTorch
use ftorch, only : torch_model, torch_tensor, torch_kCUDA, torch_kCPU, &
@@ -22,7 +23,7 @@ program inference
! Set up Fortran data structures
real(wp), dimension(5), target :: in_data
real(wp), dimension(5), target :: out_data
-integer, parameter :: tensor_layout(1) = [1]
+integer(c_int64_t), parameter :: tensor_layout(1) = [1]

! Set up Torch data structures
type(torch_model) :: model
3 changes: 2 additions & 1 deletion examples/4_MultiIO/multiionet_infer_fortran.f90
@@ -2,6 +2,7 @@ program inference

! Import precision info from iso
use, intrinsic :: iso_fortran_env, only : sp => real32
+use, intrinsic :: iso_c_binding, only : c_int64_t

! Import our library for interfacing with PyTorch
use ftorch, only : torch_model, torch_tensor, torch_kCPU, &
@@ -25,7 +26,7 @@ program inference
real(wp), dimension(4), target :: out_data1
real(wp), dimension(4), target :: out_data2
real(wp), dimension(4) :: expected
-integer, parameter :: tensor_layout(1) = [1]
+integer(c_int64_t), parameter :: tensor_layout(1) = [1]

! Set up Torch data structures
! The net, a vector of input tensors (in this case we only have one), and the output tensor
3 changes: 2 additions & 1 deletion examples/6_Autograd/autograd.f90
@@ -2,6 +2,7 @@ program example

! Import precision info from iso
use, intrinsic :: iso_fortran_env, only : sp => real32
+use, intrinsic :: iso_c_binding, only : c_int64_t

! Import our library for interfacing with PyTorch's Autograd module
use ftorch, only : torch_tensor, torch_kCPU, &
@@ -20,7 +21,7 @@ program example
real(wp), dimension(n,m), target :: in_data
real(wp), dimension(:,:), pointer :: out_data
real(wp), dimension(n,m) :: expected
-integer :: tensor_layout(2) = [1, 2]
+integer(c_int64_t) :: tensor_layout(2) = [1, 2]
integer :: i, j

! Flag for testing
60 changes: 30 additions & 30 deletions src/ftorch.f90

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/ftorch.fypp
@@ -574,7 +574,7 @@ contains

! inputs
${f_type(PREC)}$(kind=${PREC}$), intent(in), target :: data_in${ranksuffix(RANK)}$ !! Input data that tensor will point at
-integer, intent(in) :: layout(${RANK}$) !! Control order of indices
+integer(c_int64_t), intent(in) :: layout(${RANK}$) !! Control order of indices
integer(c_int), intent(in) :: c_device_type !! Device type the tensor will live on (`torch_kCPU` or `torch_kCUDA`)
integer(c_int), optional, intent(in) :: device_index !! device index to use for `torch_kCUDA` case
logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor
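src/ftorch.f90 above is generated from this template, so its 30 paired additions and deletions presumably repeat this one-line change across every rank and precision instantiation. As a quick standalone check of why the explicit kind matters (illustrative only, not part of the commit): default integer is commonly 32-bit, while c_int64_t is always 64-bit.

    program kind_check
      use, intrinsic :: iso_c_binding, only : c_int64_t
      implicit none
      integer :: default_layout(1) = [1]
      integer(c_int64_t) :: c_layout(1) = [1_c_int64_t]

      ! On most compilers this prints 4 then 8: a default-integer layout
      ! would not match the int64_t expected across the C binding.
      print *, "default integer bytes:", storage_size(default_layout) / 8
      print *, "c_int64_t bytes:      ", storage_size(c_layout) / 8
    end program kind_check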
18 changes: 9 additions & 9 deletions src/test/unit/test_constructors.pf
@@ -16,20 +16,20 @@ subroutine test_torch_tensor_empty()

type(torch_tensor) :: tensor
integer(c_int), parameter :: ndims = 2
-integer(c_int64_t), dimension(2) :: tensor_shape
+integer(c_int64_t), dimension(2) :: tensor_layout
integer(c_int), parameter :: dtype = torch_kFloat32
integer(c_int), parameter :: device_type = torch_kCPU
integer(c_int), parameter :: device_index = -1
logical(c_bool), parameter :: requires_grad = .false.
logical :: test_pass

-tensor_shape = [2, 3]
+tensor_layout = [2, 3]

! Check the tensor pointer is not associated
@assertFalse(c_associated(tensor%p))

! Create a tensor of zeros
-call torch_tensor_empty(tensor, ndims, tensor_shape, dtype, device_type, &
+call torch_tensor_empty(tensor, ndims, tensor_layout, dtype, device_type, &
device_index, requires_grad)

! Check the tensor pointer is associated
@@ -52,7 +52,7 @@ subroutine test_torch_tensor_zeros()

type(torch_tensor) :: tensor
integer(c_int), parameter :: ndims = 2
-integer(c_int64_t), dimension(2) :: tensor_shape
+integer(c_int64_t), dimension(2) :: tensor_layout
integer(c_int), parameter :: dtype = torch_kFloat32
integer(c_int), parameter :: device_type = torch_kCPU
integer(c_int), parameter :: device_index = -1
@@ -61,13 +61,13 @@
real(kind=real32), dimension(2,3) :: expected
logical :: test_pass

-tensor_shape = [2, 3]
+tensor_layout = [2, 3]

! Check the tensor pointer is not associated
@assertFalse(c_associated(tensor%p))

! Create a tensor of zeros
-call torch_tensor_zeros(tensor, ndims, tensor_shape, dtype, device_type, &
+call torch_tensor_zeros(tensor, ndims, tensor_layout, dtype, device_type, &
device_index, requires_grad)

! Check the tensor pointer is associated
@@ -99,7 +99,7 @@ subroutine test_torch_tensor_ones()

type(torch_tensor) :: tensor
integer(c_int), parameter :: ndims = 2
-integer(c_int64_t), dimension(2) :: tensor_shape
+integer(c_int64_t), dimension(2) :: tensor_layout
integer(c_int), parameter :: dtype = torch_kFloat32
integer(c_int), parameter :: device_type = torch_kCPU
integer(c_int), parameter :: device_index = -1
@@ -108,13 +108,13 @@ subroutine test_torch_tensor_ones()
real(kind=real32), dimension(2,3) :: expected
logical :: test_pass

-tensor_shape = [2, 3]
+tensor_layout = [2, 3]

! Check the tensor pointer is not associated
@assertFalse(c_associated(tensor%p))

! Create tensor of ones
-call torch_tensor_ones(tensor, ndims, tensor_shape, dtype, device_type, &
+call torch_tensor_ones(tensor, ndims, tensor_layout, dtype, device_type, &
device_index, requires_grad)

! Check the tensor pointer is associated
