Skip to content

Commit

Permalink
**BREAKING** Return the number of dtype entries from `memory::size(…
Browse files Browse the repository at this point in the history
…)`. (#711)
  • Loading branch information
kris-rowe committed Jan 24, 2024
1 parent 6084f64 commit 678d41e
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 29 deletions.
25 changes: 11 additions & 14 deletions include/occa/core/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,38 +196,35 @@ namespace occa {
* @startDoc{size}
*
* Description:
* Get the byte size of the allocated memory
* Returns the number of elements of size [[dtype_t]] held in this memory.
* If no type was given during [[allocation|device.malloc]], returns the
* size of the storage in bytes.
*
* @endDoc
*/
udim_t size() const;

/**
* @startDoc{length[0]}
* @startDoc{length}
*
* Description:
* Get the length of the memory object, using its underlying [[dtype_t]].
* This [[dtype_t]] can be fetched through the [[memory.dtype]] method
*
* If no type was given during [[allocation|device.malloc]] or was ever set
* through [[casting it|memory.cast]], it will return the bytes just like [[memory.size]].
* Returns the number of elements of size [[dtype_t]] held in this memory.
* If no type was given during [[allocation|device.malloc]], returns the
* size of the storage in bytes.
*
* @endDoc
*/
udim_t length() const;

/**
* @startDoc{length[1]}
* @startDoc{byte_size}
*
* Overloaded Description:
* Same as above but explicitly chose the type (`T`)
* Description:
* Get the size of the allocated memory in bytes.
*
* @endDoc
*/
template <class T>
udim_t length() const {
return size() / sizeof(T);
}
udim_t byte_size() const;

/**
* @startDoc{operator_equals[0]}
Expand Down
4 changes: 2 additions & 2 deletions include/occa/functional/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,8 @@ namespace occa {
}

array concat(const array &other) const {
const udim_t entries = length();
const udim_t other_entries = other.length();
const udim_t entries = memory_.length();
const udim_t other_entries = other.memory_.length();

occa::memory ret = getDevice().template malloc<T>(entries + other_entries);
ret.copyFrom(memory_, entries, 0);
Expand Down
20 changes: 11 additions & 9 deletions src/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ namespace occa {
}

udim_t memory::size() const {
if (modeMemory == NULL) {
return 0;
}
return modeMemory->size;
return length();
}

udim_t memory::length() const {
Expand All @@ -145,6 +142,11 @@ namespace occa {
return modeMemory->size / modeMemory->dtype_->bytes();
}

udim_t memory::byte_size() const {
if (modeMemory == NULL) return 0;
else return modeMemory->size;
}

bool memory::operator == (const occa::memory &other) const {
return (modeMemory == other.modeMemory);
}
Expand Down Expand Up @@ -172,12 +174,12 @@ namespace occa {
? (length() - offset)
: count);

OCCA_ERROR("Trying to allocate negative bytes (" << bytes << ")",
OCCA_ERROR("Trying to allocate negative elements (" << count << ")",
bytes >= 0);

OCCA_ERROR("Cannot have offset and bytes greater than the memory size ("
<< offset_ << " + " << bytes << " > " << size() << ")",
(offset_ + (dim_t) bytes) <= (dim_t) size());
OCCA_ERROR("Memory size is less than offset + count ("
<< size() << " <" << offset << " + " << count << ")",
(offset + (dim_t) count) <= (dim_t) size());

occa::memory m(modeMemory->slice(offset_, bytes));
m.setDtype(dtype());
Expand Down Expand Up @@ -330,7 +332,7 @@ namespace occa {

occa::memory mem = (
occa::device(modeMemory->getModeDevice())
.malloc(size(), *this, properties())
.malloc(byte_size(), *this, properties())
);
mem.setDtype(dtype());

Expand Down
9 changes: 6 additions & 3 deletions tests/src/core/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,22 @@ void testWrapMemory() {
int *hostPtr = &value;

occa::memory mem = device.wrapMemory((void*) hostPtr, bytes);
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);

mem = device.wrapMemory(hostPtr, 1);
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);

mem = device.wrapMemory(hostPtr, 1, {{"use_host_pointer", false}});
mem.setDtype(occa::dtype::int_);
ASSERT_EQ(mem.ptr<int>()[0], value);
ASSERT_EQ(mem.ptr<int>(), hostPtr);
ASSERT_EQ((int) mem.length<int>(), 1);
ASSERT_EQ((int) mem.length(), 1);
}

void testUnwrap() {
Expand Down
2 changes: 1 addition & 1 deletion tests/src/core/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void testCast() {
ASSERT_TRUE(occa::dtype::double_ == occa_memory.dtype());
ASSERT_TRUE(occa::dtype::byte == casted_memory.dtype());

ASSERT_EQ(occa_memory.size(), casted_memory.size());
ASSERT_EQ(occa_memory.byte_size(), casted_memory.byte_size());
}

void testCopy() {
Expand Down

0 comments on commit 678d41e

Please sign in to comment.