Skip to content

Commit

Permalink
Introduce 2 tweaks to DeviceVector to improve safety of the class
Browse files Browse the repository at this point in the history
1. Delete implicitly declared constructors/assignment operations. We could definitely define these, but currently they can lead to dereferencing null pointers or double-freeing pointers

2. While I was here, I also added an explicit check that the elements of the vector are trivially copyable. This is the formal requirement for being able to copy an object with a variant of memcpy.
  • Loading branch information
mabruzzo committed Nov 1, 2023
1 parent d60e086 commit 72aa119
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/utils/DeviceVector.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <algorithm>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// External Includes
Expand All @@ -35,12 +36,15 @@ namespace cuda_utilities
* `data()` method. This class works for any device side pointer, scalar or
* array valued.
*
* \tparam T Any serialized type where `sizeof(T)` returns correct results
* should work but non-primitive types have not been tested.
* \tparam T Any trivially copyable type where `sizeof(T)` returns correct
* results should work, but non-primitive types have not been tested.
*/
template <typename T>
class DeviceVector
{
static_assert(std::is_trivially_copyable_v<T>,
"DeviceVector can only be used with trivially_copyable types "
"due to the internal usage of memcpy");
public:
/*!
* \brief Construct a new Device Vector object by calling the
Expand All @@ -60,6 +64,15 @@ class DeviceVector
*/
~DeviceVector() { _deAllocate(); }

/* The following are deleted because they currently lead to invalid state.
* (But they can all easily be implemented in the future).
*/
DeviceVector() = delete;
DeviceVector(const DeviceVector<T>&) = delete;
DeviceVector(DeviceVector<T>&&) = delete;
DeviceVector<T>& operator=(const DeviceVector<T>& other) = delete;
DeviceVector<T>& operator=(DeviceVector<T>&& other) = delete;

/*!
* \brief Get the raw device pointer
*
Expand Down

0 comments on commit 72aa119

Please sign in to comment.