Skip to content

Commit

Permalink
[Tensor] Support Read/Save Functionalities for Quantized Type
Browse files Browse the repository at this point in the history
This PR enhances the existing functionality by supporting reading and saving operations on quantized type tensors.
This change can ensure that quantized tensors can accurately retrieve and store their data from binary files.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
  • Loading branch information
djeong20 authored and jijoongmoon committed Nov 18, 2024
1 parent 6366ddc commit 3ccea3d
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 4 deletions.
12 changes: 12 additions & 0 deletions nntrainer/tensor/char_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ void CharTensor::print(std::ostream &out) const {
}
out.copyfmt(init);
}

/// @todo print quantization information
}

void CharTensor::copy(const void *buf) {
Expand All @@ -364,4 +366,14 @@ void CharTensor::copy(const void *buf) {
}
}

void CharTensor::save_quantization_info(std::ostream &file) {
checkedWrite(file, (char *)&axis, sizeof(uint8_t),
"[CharTensor::save] failed to write quantization information");
}

void CharTensor::read_quantization_info(std::ifstream &file) {
checkedRead(file, (char *)&axis, sizeof(uint8_t),
"[CharTensor::read] failed to read quantization information");
}

} // namespace nntrainer
15 changes: 15 additions & 0 deletions nntrainer/tensor/char_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,22 @@ class CharTensor : public TensorBase {
*/
void print(std::ostream &out) const override;

/**
* @copydoc TensorBase::save_quantization_info()
*/
void save_quantization_info(std::ostream &file) override;

/**
* @copydoc TensorBase::read_quantization_info()
*/
void read_quantization_info(std::ifstream &file) override;

private:
/**
* @brief quantization axis
*/
uint8_t axis;

/**
* @brief copy a buffer to @a this, the caller has to ensure that @a this is
* initialized otherwise undefined behavior
Expand Down
14 changes: 10 additions & 4 deletions nntrainer/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,9 +1036,12 @@ void Tensor::save(std::ostream &file) {
NNTR_THROW_IF(!getContiguous(), std::invalid_argument)
<< getName() << " is not contiguous, cannot save.";

std::streamsize sz = static_cast<std::streamsize>(bytes());
/// @note Save quantization information which only works on Quantized Tensor
itensor->save_quantization_info(file);

std::streamsize sz = static_cast<std::streamsize>(bytes() + scale_size());
NNTR_THROW_IF(sz < 0, std::invalid_argument)
<< "save size: " << bytes()
<< "save size: " << bytes() + scale_size()
<< " is too big. It cannot be represented by std::streamsize";

checkedWrite(file, getData<char>(), sz, "[Tensor::save] operation failed");
Expand All @@ -1049,10 +1052,13 @@ void Tensor::read(std::ifstream &file) {
NNTR_THROW_IF(!getContiguous(), std::invalid_argument)
<< getName() << " is not contiguous, cannot read.";

std::streamsize sz = static_cast<std::streamsize>(bytes());
/// @note Read quantization information which only works on Quantized Tensor
itensor->read_quantization_info(file);

std::streamsize sz = static_cast<std::streamsize>(bytes() + scale_size());

NNTR_THROW_IF(sz < 0, std::invalid_argument)
<< "read size: " << bytes()
<< "read size: " << bytes() + scale_size()
<< " is too big. It cannot be represented by std::streamsize";

checkedRead(file, getData<char>(), sz, "[Tensor::read] operation failed");
Expand Down
10 changes: 10 additions & 0 deletions nntrainer/tensor/tensor_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,16 @@ class TensorBase {
size_t getIndex(unsigned int b, unsigned int c, unsigned int h,
unsigned int w) const noexcept;

/**
* @brief Save quantization information
*/
virtual void save_quantization_info(std::ostream &file) {}

/**
* @brief Read quantization information
*/
virtual void read_quantization_info(std::ifstream &file) {}

/**
* @brief Get size of current tensor
* @retval unsigned int size of the current tensor
Expand Down

0 comments on commit 3ccea3d

Please sign in to comment.