Skip to content

Commit

Permalink
Merge pull request #57 from beomki-yeo/feat-set-block
Browse files Browse the repository at this point in the history
Add set block function in matrix
  • Loading branch information
beomki-yeo authored Mar 31, 2022
2 parents 6a07475 + 7a596f1 commit 3a89a8a
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 0 deletions.
12 changes: 12 additions & 0 deletions math/cmath/include/algebra/math/impl/cmath_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ struct actor {
return block_getter().template operator()<ROWS, COLS>(m, row, col);
}

/// Operator setting a block
template <size_type ROWS, size_type COLS, class input_matrix_type>
ALGEBRA_HOST_DEVICE void set_block(input_matrix_type &m,
const matrix_type<ROWS, COLS> &b, int row,
int col) {
for (size_type i = 0; i < ROWS; ++i) {
for (size_type j = 0; j < COLS; ++j) {
element_getter()(m, i + row, j + col) = element_getter()(b, i, j);
}
}
}

// Create zero matrix
template <size_type ROWS, size_type COLS>
ALGEBRA_HOST_DEVICE inline matrix_type<ROWS, COLS> zero() const {
Expand Down
8 changes: 8 additions & 0 deletions math/eigen/include/algebra/math/impl/eigen_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ struct actor {
return m.template block<ROWS, COLS>(row, col);
}

/// Operator setting a block
template <int ROWS, int COLS, class input_matrix_type>
ALGEBRA_HOST_DEVICE void set_block(input_matrix_type &m,
const matrix_type<ROWS, COLS> &b, int row,
int col) {
m.template block<ROWS, COLS>(row, col) = b;
}

// Create zero matrix
template <int ROWS, int COLS>
ALGEBRA_HOST_DEVICE inline matrix_type<ROWS, COLS> zero() {
Expand Down
12 changes: 12 additions & 0 deletions math/smatrix/include/algebra/math/impl/smatrix_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ struct actor {
return m.template Sub<matrix_type<ROWS, COLS> >(row, col);
}

/// Operator setting a block
template <unsigned int ROWS, unsigned int COLS, class input_matrix_type>
ALGEBRA_HOST_DEVICE void set_block(input_matrix_type &m,
const matrix_type<ROWS, COLS> &b,
unsigned int row, unsigned int col) {
for (unsigned int i = 0; i < ROWS; ++i) {
for (unsigned int j = 0; j < COLS; ++j) {
m(i + row, j + col) = b(i, j);
}
}
}

// Create zero matrix
template <unsigned int ROWS, unsigned int COLS>
ALGEBRA_HOST_DEVICE inline matrix_type<ROWS, COLS> zero() {
Expand Down
18 changes: 18 additions & 0 deletions tests/common/test_device_basics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,38 @@ class test_device_basics : public test_base<T> {
}
}

// Test set_zero
matrix_actor().set_zero(m2);
for (size_type i = 0; i < 6; ++i) {
for (size_type j = 0; j < 4; ++j) {
result += 0.4f * algebra::getter::element(m2, i, j);
}
}

// Test set_identity
matrix_actor().set_identity(m2);
for (size_type i = 0; i < 6; ++i) {
for (size_type j = 0; j < 4; ++j) {
result += 0.3f * algebra::getter::element(m2, i, j);
}
}

// Test block operations
auto b32 = matrix_actor().template block<3, 2>(m2, 2, 2);
algebra::getter::element(b32, 0, 0) = 4;
algebra::getter::element(b32, 0, 1) = 3;
algebra::getter::element(b32, 1, 0) = 12;
algebra::getter::element(b32, 1, 1) = 13;
algebra::getter::element(b32, 2, 0) = 5;
algebra::getter::element(b32, 2, 1) = 6;

matrix_actor().set_block(m2, b32, 2, 2);
for (size_type i = 0; i < 6; ++i) {
for (size_type j = 0; j < 4; ++j) {
result += 0.57f * algebra::getter::element(m2, i, j);
}
}

return result;
}

Expand Down
24 changes: 24 additions & 0 deletions tests/common/test_host_basics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,30 @@ TYPED_TEST_P(test_host_basics, matrix64) {
}
}
}

// Test block operations
auto b32 = typename TypeParam::matrix_actor().template block<3, 2>(m, 2, 2);
ASSERT_NEAR(algebra::getter::element(b32, 0, 0), 1., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(b32, 0, 1), 0., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(b32, 1, 0), 0., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(b32, 1, 1), 1., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(b32, 2, 0), 0., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(b32, 2, 1), 0., this->m_epsilon);

algebra::getter::element(b32, 0, 0) = 4;
algebra::getter::element(b32, 0, 1) = 3;
algebra::getter::element(b32, 1, 0) = 12;
algebra::getter::element(b32, 1, 1) = 13;
algebra::getter::element(b32, 2, 0) = 5;
algebra::getter::element(b32, 2, 1) = 6;

typename TypeParam::matrix_actor().set_block(m, b32, 2, 2);
ASSERT_NEAR(algebra::getter::element(m, 2, 2), 4., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(m, 2, 3), 3., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(m, 3, 2), 12., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(m, 3, 3), 13., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(m, 4, 2), 5., this->m_epsilon);
ASSERT_NEAR(algebra::getter::element(m, 4, 3), 6., this->m_epsilon);
}

// Test matrix operations with 3x3 matrix
Expand Down

0 comments on commit 3a89a8a

Please sign in to comment.