Skip to content

Commit

Permalink
Finish test_torch_tensor_ones
Browse files Browse the repository at this point in the history
  • Loading branch information
jwallwork23 committed Nov 15, 2024
1 parent 0d9b25a commit 826702a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
16 changes: 8 additions & 8 deletions run_test_suite.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ cd ${BUILD_DIR}/test/unit
ctest ${CTEST_ARGS}
cd -

# Integration tests
EXAMPLES="1_SimpleNet 2_ResNet18 4_MultiIO 6_Autograd"
for EXAMPLE in ${EXAMPLES}; do
pip -q install -r examples/${EXAMPLE}/requirements.txt
cd ${BUILD_DIR}/test/examples/${EXAMPLE}
ctest ${CTEST_ARGS}
cd -
done
# # Integration tests
# EXAMPLES="1_SimpleNet 2_ResNet18 4_MultiIO 6_Autograd"
# for EXAMPLE in ${EXAMPLES}; do
# pip -q install -r examples/${EXAMPLE}/requirements.txt
# cd ${BUILD_DIR}/test/examples/${EXAMPLE}
# ctest ${CTEST_ARGS}
# cd -
# done
18 changes: 16 additions & 2 deletions src/test/unit/test_constructors.pf
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ end subroutine test_torch_tensor_zeros
@test
subroutine test_torch_tensor_ones()
use pFUnit
use ftorch
use ftorch, only: torch_kFloat32, torch_kCPU, torch_tensor, torch_tensor_delete, torch_tensor_ones, torch_tensor_to_array
use ftorch_test_utils, only: assert_allclose
use, intrinsic :: iso_fortran_env, only: real32
use iso_c_binding, only: c_bool, c_int, c_int64_t, c_null_ptr

implicit none
Expand All @@ -61,6 +63,9 @@ subroutine test_torch_tensor_ones()
integer(c_int) :: device_type
integer(c_int) :: device_index
logical(c_bool) :: requires_grad
real(kind=real32), dimension(:,:), pointer :: out_data
real(kind=real32), dimension(2,3) :: expected
logical :: test_pass

ndims = 2
tensor_shape = [2, 3]
Expand All @@ -69,13 +74,22 @@ subroutine test_torch_tensor_ones()
device_index = -1
requires_grad = .false.

! Create tensor of ones
call torch_tensor_ones(tensor, ndims, tensor_shape, dtype, device_type, &
device_index, requires_grad)

! Check if tensor is not null
! @assertNotEqual(tensor%p, c_null_ptr) ! FIXME: compiler not happy with this
! @assertTrue(tensor%p /= c_null_ptr) ! FIXME: compiler not happy with this

call torch_tensor_to_array(tensor, out_data, shape(expected))

! Check that the tensor values are all one
expected(:,:) = 1.0
test_pass = assert_allclose(out_data, expected, test_name="test_torch_tensor_ones")
@assertTrue(test_pass)

! Cleanup
nullify(out_data)
call torch_tensor_delete(tensor)

end subroutine test_torch_tensor_ones
Expand Down

0 comments on commit 826702a

Please sign in to comment.