From d54479bed00d3150f0cc3343df632299d7949e9c Mon Sep 17 00:00:00 2001 From: Jan Adler <91198858+adlerjan@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:47:43 +0100 Subject: [PATCH] Fix LDPC Initialization Bug (#23) Fix LDPC Initialization --- hermespy/fec/aff3ct/ldpc.cpp | 37 +++++++++++------- tests/unit_tests/fec/test_ldpc.py | 65 ++++++++++++++++++++----------- 2 files changed, 65 insertions(+), 37 deletions(-) diff --git a/hermespy/fec/aff3ct/ldpc.cpp b/hermespy/fec/aff3ct/ldpc.cpp index 2533e9f3..77817141 100644 --- a/hermespy/fec/aff3ct/ldpc.cpp +++ b/hermespy/fec/aff3ct/ldpc.cpp @@ -29,16 +29,15 @@ class LDPC syndromeChecking(syndromeChecking) { // Read the H matrix - this->infoBitPos = std::vector(dataBlockSize); - this->H = std::make_unique(LDPC_matrix_handler::read(hSourcePath, &infoBitPos)); - - // Infer parameters + this->H = std::make_unique(LDPC_matrix_handler::read(hSourcePath)); + this->codeBlockSize = this->H->get_n_rows(); // N + this->dataBlockSize = this->codeBlockSize - this->H->get_n_cols(); // K + + // Create the encoder and decoder this->updateRule = std::make_unique>((unsigned int)this->H->get_cols_max_degree()); - this->dataBlockSize = this->H->get_n_cols(); - this->codeBlockSize = this->H->get_n_rows(); this->encoder = std::make_unique>(this->dataBlockSize, this->codeBlockSize, *this->H, "IDENTITY", gSavePath, false); - if (this->infoBitPos.size() < 1) this->infoBitPos = this->encoder->get_info_bits_pos(); + this->infoBitPos = this->encoder->get_info_bits_pos(); this->decoder = std::make_unique>(this->dataBlockSize, this->codeBlockSize, numIterations, *this->H, infoBitPos, *this->updateRule, syndromeChecking, minNumIterations); } @@ -69,6 +68,11 @@ class LDPC return this->codeBlockSize; } + float getRate() const + { + return (float)this->dataBlockSize / (float)this->codeBlockSize; + } + int getNumIterations() const { return this->numIterations; @@ -174,6 +178,10 @@ PYBIND11_MODULE(ldpc, m) Number of bits within a code block to be decoded. )pbdoc") + .def_property_readonly("rate", &LDPC::getRate, R"pbdoc( + Coding rate of the LDPC code. + )pbdoc") + .def_property("num_iterations", &LDPC::getNumIterations, &LDPC::setNumIterations, R"pbdoc( Number of iterations during decoding. )pbdoc") @@ -182,12 +190,11 @@ PYBIND11_MODULE(ldpc, m) C++ bindings are always enabled. )pbdoc") - .def(py::pickle( - [](const LDPC& ldpc) { - return py::make_tuple(ldpc.getNumIterations(), ldpc.getHSourcePath(), ldpc.getGSavePath(), ldpc.getSyndromeChecking(), ldpc.getMinNumIterations()); - }, - [](py::tuple t) { - return LDPC(t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast()); - } - )); + .def("__getstate__", [](const LDPC& ldpc) { + return py::make_tuple(ldpc.getNumIterations(), ldpc.getHSourcePath(), ldpc.getGSavePath(), ldpc.getSyndromeChecking(), ldpc.getMinNumIterations()); + }) + + .def("__setstate__", [](LDPC& ldpc, py::tuple t) { + new (&ldpc) LDPC{t[0].cast(), t[1].cast(), t[2].cast(), t[3].cast(), t[4].cast()}; + }); } \ No newline at end of file diff --git a/tests/unit_tests/fec/test_ldpc.py b/tests/unit_tests/fec/test_ldpc.py index 4faac5c4..9702d94b 100644 --- a/tests/unit_tests/fec/test_ldpc.py +++ b/tests/unit_tests/fec/test_ldpc.py @@ -28,40 +28,61 @@ class TestLDPCCoding(TestCase): def setUp(self) -> None: # Infer the aff3ct folder location of matrix files relative to this test - h_directory = path.join(path.dirname(path.realpath(__file__)), "..", "..", "..", "submodules", "affect", "conf", "dec", "LDPC") - self.h_path = path.join(h_directory, "CCSDS_64_128.alist") + self.h_directory = path.join(path.dirname(path.realpath(__file__)), "..", "..", "..", "submodules", "affect", "conf", "dec", "LDPC") + + self.h_candidates = [ + 'DEBUG_6_3', + 'CCSDS_64_128', + 'MACKAY_504_1008', + 'WIMAX_288_576', + ] self.g_directory = mkdtemp() self.g_path = path.join(self.g_directory, "test.alist") self.rng = default_rng(42) - self.num_attempts = 10 - self.num_iterations = 10 - - self.coding = LDPCCoding(self.num_iterations, self.h_path, self.g_path, False, 10) + self.num_attempts = 20 + self.num_iterations = 100 def tearDown(self) -> None: rmtree(self.g_directory) - def _test_encode_decode(self) -> None: + def test_encode_decode(self) -> None: """Encoding a data block should yield a valid code.""" - - for i in range(self.num_attempts): - data_block = self.rng.integers(0, 2, self.coding.bit_block_size, dtype=np.int32) - flip_index = self.rng.integers(0, self.coding.bit_block_size) - - code_block = self.coding.encode(data_block) - code_block[flip_index] = not bool(code_block[flip_index]) - - decoded_block = self.coding.decode(code_block) - assert_array_equal(data_block, decoded_block) - - def _test_pickle(self) -> None: + + for h_candidate in self.h_candidates: + with self.subTest(h_candidate=h_candidate): + h_path = path.join(self.h_directory, h_candidate + ".alist") + coding = LDPCCoding(self.num_iterations, h_path, self.g_path, True, 10) + + errors = 0 + for _ in range(self.num_attempts): + data_block = self.rng.integers(0, 2, coding.bit_block_size, dtype=np.int32) + flip_index = self.rng.integers(0, coding.bit_block_size, dtype=np.int32) + + code_block = coding.encode(data_block) + code_block[flip_index] = not bool(code_block[flip_index]) + + decoded_block = coding.decode(code_block) + errors += np.sum(data_block != decoded_block) + + self.assertGreater(1, errors, msg=f"Too many errors: {errors}") + + def test_pickle(self) -> None: """Pickeling and unpickeling the C++ wrapper""" + coding = LDPCCoding(self.num_iterations, path.join(self.h_directory, self.h_candidates[0] + '.alist'), "", False, 10) + with NamedTemporaryFile() as file: - dump(self.coding, file) + dump(coding, file) file.seek(0) - coding = load(file) - self.assertEqual(self.num_iterations, coding.num_iterations) + deserialized_coding = load(file) + self.assertEqual(self.num_iterations, deserialized_coding.num_iterations) + + # Actuall run a full encoding and decoding with the unpickled object + data_block = self.rng.integers(0, 2, deserialized_coding.bit_block_size, dtype=np.int32) + code_block = deserialized_coding.encode(data_block) + decoded_block = deserialized_coding.decode(code_block) + + assert_array_equal(data_block, decoded_block)