Skip to content

Commit

Permalink
Return the number of dtype entries from memory::size().
Browse files Browse the repository at this point in the history
  • Loading branch information
kris-rowe committed Oct 20, 2023
1 parent ca23135 commit 0f2f339
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 28 deletions.
25 changes: 11 additions & 14 deletions include/occa/core/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,35 @@ namespace occa {
* @startDoc{size}
*
* Description:
* Get the byte size of the allocated memory
* Returns the number of elements of size [[dtype_t]] held in this memory.
* If no type was given during [[allocation|device.malloc]], returns the
* size of the storage in bytes.
*
* @endDoc
*/
udim_t size() const;

/**
* @startDoc{length[0]}
* @startDoc{length}
*
* Description:
* Get the length of the memory object, using its underlying [[dtype_t]].
* This [[dtype_t]] can be fetched through the [[memory.dtype]] method
*
* If no type was given during [[allocation|device.malloc]] or was ever set
* through [[casting it|memory.cast]], it will return the bytes just like [[memory.size]].
* Returns the number of elements of size [[dtype_t]] held in this memory.
* If no type was given during [[allocation|device.malloc]], returns the
* size of the storage in bytes.
*
* @endDoc
*/
udim_t length() const;

/**
* @startDoc{length[1]}
* @startDoc{byte_size}
*
* Overloaded Description:
* Same as above but explicitly chose the type (`T`)
* Description:
* Get the size of the allocated memory in bytes.
*
* @endDoc
*/
template <class T>
udim_t length() const {
return size() / sizeof(T);
}
udim_t byte_size() const;

/**
* @startDoc{operator_equals[0]}
Expand Down
4 changes: 2 additions & 2 deletions include/occa/functional/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,8 @@ namespace occa {
}

array concat(const array &other) const {
const udim_t bytes1 = memory_.size();
const udim_t bytes2 = other.memory_.size();
const udim_t bytes1 = memory_.byte_size();
const udim_t bytes2 = other.memory_.byte_size();

occa::memory ret = getDevice().template malloc<T>(length() + other.length());
ret.copyFrom(memory_, bytes1, 0);
Expand Down
20 changes: 11 additions & 9 deletions src/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ namespace occa {
}

udim_t memory::size() const {
if (modeMemory == NULL) {
return 0;
}
return modeMemory->size;
return length();
}

udim_t memory::length() const {
Expand All @@ -145,6 +142,11 @@ namespace occa {
return modeMemory->size / modeMemory->dtype_->bytes();
}

udim_t memory::byte_size() const {
if (modeMemory == NULL) return 0;
else return modeMemory->size;
}

bool memory::operator == (const occa::memory &other) const {
return (modeMemory == other.modeMemory);
}
Expand Down Expand Up @@ -172,12 +174,12 @@ namespace occa {
? (length() - offset)
: count);

OCCA_ERROR("Trying to allocate negative bytes (" << bytes << ")",
OCCA_ERROR("Trying to allocate negative elements (" << count << ")",
bytes >= 0);

OCCA_ERROR("Cannot have offset and bytes greater than the memory size ("
<< offset_ << " + " << bytes << " > " << size() << ")",
(offset_ + (dim_t) bytes) <= (dim_t) size());
OCCA_ERROR("Memory size is less than offset + count ("
<< size() << " <" << offset << " + " << count << ")",
(offset + (dim_t) count) <= (dim_t) size());

occa::memory m(modeMemory->slice(offset_, bytes));
m.setDtype(dtype());
Expand Down Expand Up @@ -320,7 +322,7 @@ namespace occa {

occa::memory mem = (
occa::device(modeMemory->getModeDevice())
.malloc(size(), *this, properties())
.malloc(byte_size(), *this, properties())
);
mem.setDtype(dtype());

Expand Down
9 changes: 6 additions & 3 deletions tests/src/core/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,22 @@ void testWrapMemory() {
int *hostPtr = &value;

occa::memory mem = device.wrapMemory((void*) hostPtr, bytes);
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);

mem = device.wrapMemory(hostPtr, 1);
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);

mem = device.wrapMemory(hostPtr, 1, {{"use_host_pointer", false}});
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);
}

void testUnwrap() {
Expand Down

0 comments on commit 0f2f339

Please sign in to comment.