diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..01dbb89 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,14 @@ +version: 2 +updates: + - package-ecosystem: cargo + directory: / + schedule: + interval: weekly + - package-ecosystem: pip + directory: / + schedule: + interval: weekly + - package-ecosystem: github-actions + directory: / + schedule: + interval: weekly diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..3c7eb28 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,120 @@ +# This file is autogenerated by maturin v1.3.0 +# To update, run +# +#    maturin generate-ci github
#
+name: CI + +on: + push: + branches: + - main + - master + tags: + - '*' + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ubuntu-latest + strategy: + matrix: + target: [x86_64, x86, aarch64, armv7, s390x, ppc64le] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + windows: + runs-on: windows-latest + strategy: + matrix: + target: [x64, x86] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + architecture: ${{ matrix.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + macos: + runs-on: macos-latest + strategy: + matrix: + target: [x86_64, aarch64] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + - name: Build wheels + uses: 
PyO3/maturin-action@v1 + with: + target: ${{ matrix.target }} + args: --release --out dist + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v3 + with: + name: wheels + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: "startsWith(github.ref, 'refs/tags/')" + needs: [linux, windows, macos, sdist] + steps: + - uses: actions/download-artifact@v3 + with: + name: wheels + - name: Publish to PyPI + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing * diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..af3ca5e --- /dev/null +++ b/.gitignore @@ -0,0 +1,72 @@ +/target + +# Byte-compiled / optimized / DLL files +__pycache__/ +.pytest_cache/ +*.py[cod] + +# C extensions +*.so + +# Distribution / packaging +.Python +.venv/ +env/ +bin/ +build/ +develop-eggs/ +dist/ +eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +include/ +man/ +venv/ +*.egg-info/ +.installed.cfg +*.egg + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt +pip-selfcheck.json + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.cache +nosetests.xml +coverage.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Rope +.ropeproject + +# Django stuff: +*.log +*.pot + +.DS_Store + +# Sphinx documentation +docs/_build/ + +# PyCharm +.idea/ + +# VSCode +.vscode/ + +# Pyenv +.python-version \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..8c87fa0 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,313 @@ +# This file is automatically @generated by Cargo. 
+# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "general-sam" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6c977936300f6022f245ccdde355c3768c4144239db83ea0d2992747bde46bc" + +[[package]] +name = "general-sam-py" +version = "0.1.1" +dependencies = [ + "either", + "general-sam", + "pyo3", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + +[[package]] +name = "libc" +version = "0.2.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" + +[[package]] +name = 
"lock_api" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +dependencies = [ + "autocfg", + "scopeguard", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "parking_lot" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-targets", +] + +[[package]] +name = "proc-macro2" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134c189feb4956b20f6f547d2cf727d4c0fe06722b20a0eec87ed445a97f92da" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +dependencies = [ + "once_cell", + "python3-dll-a", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "python3-dll-a" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f07cd4412be8fa09a721d40007c483981bbe072cd6a21f2e83e04ec8f8343f" +dependencies = [ + "cc", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "smallvec" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" + +[[package]] +name = "syn" +version = "2.0.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = 
"windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..f1f3ff4 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "general-sam-py" +version = "0.1.1" +edition = "2021" +license = "MIT OR Apache-2.0" +description = "Python bindings for general-sam and some utilities" +repository = "https://github.com/ModelTC/general-sam" +homepage = "https://github.com/ModelTC/general-sam/tree/main/pybind" +readme = "README.md" +authors = ["Chielo Newctle "] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "general_sam" +crate-type = ["cdylib"] + +[dependencies] +general-sam = "0.1.1" +pyo3 = { version = "0.20.0", features = [ + "extension-module", + "abi3-py38", + "generate-import-lib", +] } +either = "1.9.0" diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..c98d27d --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + https://www.apache.org/licenses/LICENSE-2.0 + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the 
following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/LICENSE-MIT b/LICENSE-MIT new file mode 100644 index 0000000..dd66111 --- /dev/null +++ b/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Chielo Newctle +Copyright (c) 2023 ModelTC Team + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is furnished +to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS +OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF +OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..7e1167b --- /dev/null +++ b/README.md @@ -0,0 +1,164 @@ +# general-sam-py + +![License](https://img.shields.io/badge/license-MIT%2FApache--2.0-informational?style=flat-square) + +Python bindings for [`general-sam`](https://github.com/ModelTC/general-sam) +and some utilities. + +| [![the suffix automaton of abcbc][sam-of-abcbc]][sam-oi-wiki] | +| :----------------------------------------------------------------------------: | +| The suffix automaton of abcbc, image from [后缀自动机 - OI Wiki][sam-oi-wiki]. 
| + +[sam-of-abcbc]: https://oi-wiki.org/string/images/SAM/SA_suffix_links.svg +[sam-oi-wiki]: https://oi-wiki.org/string/sam/ + +## Usage + +### `GeneralSAM` + +```python +from general_sam import GeneralSAM + + +sam = GeneralSAM.construct_from_bytes(b'abcbc') + +state = sam.get_root_state() +state.feed_bytes(b'cbc') +assert state.is_accepting() + +state = sam.get_root_state() +state.feed_bytes(b'bcb') +assert not state.is_accepting() +``` + +```python +from general_sam import GeneralSAM + + +sam = GeneralSAM.construct_from_chars('abcbc') +state = sam.get_root_state() + +state.feed_chars('b') +assert not state.is_accepting() +state.feed_chars('c') +assert state.is_accepting() +state.feed_chars('bc') +assert state.is_accepting() +state.feed_chars('bc') +assert not state.is_accepting() and state.is_nil() +``` + +```python +from general_sam import GeneralSAM, GeneralSAMState, construct_trie_from_chars + + +trie, _ = construct_trie_from_chars(['hello', 'Chielo']) +sam = GeneralSAM.construct_from_trie(trie) + +def fetch_state(s: str) -> GeneralSAMState: + state = sam.get_root_state() + state.feed_chars(s) + return state + +assert fetch_state('lo').is_accepting() +assert fetch_state('ello').is_accepting() +assert fetch_state('elo').is_accepting() + +state = fetch_state('el') +assert not state.is_accepting() and not state.is_nil() + +state = fetch_state('bye') +assert not state.is_accepting() and state.is_nil() +``` + +### `VocabPrefixAutomaton` + +```python +from general_sam import VocabPrefixAutomaton, CountInfo + + +vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'] +automaton = VocabPrefixAutomaton(vocab, bytes_or_chars='chars') + +# NOTE: CountInfo is related to the sorted vocab: +_ = ['播放歌曲', '查看歌词', '歌曲', '歌词', '聆听歌曲'] + +# 一起 | 聆 | 听 | 歌 +state = automaton.get_root_state() + +# feed 歌 +cnt_info = automaton.prepend_feed(state, '歌') +assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=2, tot_cnt_lower=2, tot_cnt_upper=4 +) + +selected_idx = 
automaton.get_order_slice(cnt_info) +assert frozenset(selected_idx) == {0, 3} +selected_vocab = [vocab[i] for i in selected_idx] +assert frozenset(selected_vocab) == {'歌曲', '歌词'} + +# feed 听 +cnt_info = automaton.prepend_feed(state, '听') +assert cnt_info is None +assert not state.is_nil() + +# feed 聆 +cnt_info = automaton.prepend_feed(state, '聆') +assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=4, tot_cnt_upper=5 +) + +selected_idx = automaton.get_order_slice(cnt_info) +assert frozenset(selected_idx) == {1} +selected_vocab = [vocab[i] for i in selected_idx] +assert frozenset(selected_vocab) == {'聆听歌曲'} + +# feed 一起 +assert not state.is_nil() +cnt_info = automaton.prepend_feed(state, '一起') +assert state.is_nil() + +# 来 | 查看 | 歌词 +state = automaton.get_root_state() + +# feed 歌词 +cnt_info = automaton.prepend_feed(state, '歌词') +assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=3, tot_cnt_upper=4 +) + +selected_idx = automaton.get_order_slice(cnt_info) +assert frozenset(selected_idx) == {3} +selected_vocab = [vocab[i] for i in selected_idx] +assert frozenset(selected_vocab) == {'歌词'} + +# feed 查看 +cnt_info = automaton.prepend_feed(state, '查看') +assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=1, tot_cnt_upper=2 +) + +selected_idx = automaton.get_order_slice(cnt_info) +assert frozenset(selected_idx) == {4} +selected_vocab = [vocab[i] for i in selected_idx] +assert frozenset(selected_vocab) == {'查看歌词'} + +# feed 来 +assert not state.is_nil() +cnt_info = automaton.prepend_feed(state, '来') +assert state.is_nil() +``` + +## License + +- © 2023 Chielo Newctle \ +- © 2023 ModelTC Team + +This project is licensed under either of + +- [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0) ([`LICENSE-APACHE`](LICENSE-APACHE)) +- [MIT license](https://opensource.org/licenses/MIT) ([`LICENSE-MIT`](LICENSE-MIT)) + +at your option. 
+ +The [SPDX](https://spdx.dev) license identifier for this project is `MIT OR Apache-2.0`. diff --git a/general_sam/__init__.py b/general_sam/__init__.py new file mode 100644 index 0000000..f4259ff --- /dev/null +++ b/general_sam/__init__.py @@ -0,0 +1,35 @@ +from .general_sam import ( + GeneralSAM, + GeneralSAMState, + Trie, + TrieNode, +) +from .trie_utils import ( + CountInfo, + SortResult, + construct_trie_from_bytes, + construct_trie_from_chars, + sort_bytes, + sort_chars, + sort_seq_via_trie, +) +from .vocab_prefix import ( + VocabPrefixAutomaton, + VocabPrefixBytesOrChars, +) + +__all__ = [ + 'GeneralSAM', + 'GeneralSAMState', + 'Trie', + 'TrieNode', + 'CountInfo', + 'SortResult', + 'construct_trie_from_chars', + 'construct_trie_from_bytes', + 'sort_chars', + 'sort_bytes', + 'sort_seq_via_trie', + 'VocabPrefixAutomaton', + 'VocabPrefixBytesOrChars', +] diff --git a/general_sam/general_sam.pyi b/general_sam/general_sam.pyi new file mode 100644 index 0000000..7181913 --- /dev/null +++ b/general_sam/general_sam.pyi @@ -0,0 +1,165 @@ +from typing import Callable, Mapping, Optional, Sequence, Union + + +class TrieNode: + def is_in_chars(self) -> bool: + ... + + def is_in_bytes(self) -> bool: + ... + + def get_node_id(self) -> int: + ... + + def is_accepting(self) -> bool: + ... + + def get_trans(self) -> Mapping[Union[str, int], int]: + ... + + def get_parent(self) -> int: + ... + + +class Trie: + @staticmethod + def in_chars() -> 'Trie': + ... + + @staticmethod + def in_bytes() -> 'Trie': + ... + + def is_in_chars(self) -> bool: + ... + + def is_in_bytes(self) -> bool: + ... + + def num_of_nodes(self) -> int: + ... + + def insert_chars(self, s: str) -> int: + ... + + def insert_bytes(self, s: bytes) -> int: + ... + + def get_bfs_order(self) -> Sequence[int]: + ... + + def get_root(self) -> TrieNode: + ... + + def get_node(self, node_id: int) -> Optional[TrieNode]: + ... 
+ + def dfs_travel( + self, + in_stack_callback: Callable[[int, Optional[str]], None], + out_stack_callback: Callable[[int], None], + root_node_id: Optional[int] = None, + ) -> TrieNode: + ... + + def bfs_travel( + self, + in_queue_callback: Callable[[int, Optional[str]], None], + out_queue_callback: Callable[[int], None], + root_node_id: Optional[int] = None, + ) -> TrieNode: + ... + + +class GeneralSAMState: + def is_in_str(self) -> bool: + ... + + def is_in_bytes(self) -> bool: + ... + + def get_node_id(self) -> int: + ... + + def is_nil(self) -> bool: + ... + + def is_root(self) -> bool: + ... + + def is_accepting(self) -> bool: + ... + + def get_trans(self) -> Mapping[Union[str, int], int]: + ... + + def get_suffix_parent_id(self) -> int: + ... + + def copy(self) -> 'GeneralSAMState': + ... + + def goto_suffix_parent(self): + ... + + def goto_char(self, t: str): + ... + + def goto_byte(self, t: int): + ... + + def feed_chars(self, s: str): + ... + + def feed_bytes(self, s: bytes): + ... + + def dfs_along( + self, + trie: Trie, + in_stack_callback: Callable[['GeneralSAMState', int, Optional[str]], None], + out_stack_callback: Callable[['GeneralSAMState', int], None], + trie_node_id: Optional[int] = None, + ) -> TrieNode: + ... + + def bfs_along( + self, + trie: Trie, + in_queue_callback: Callable[['GeneralSAMState', int, Optional[str]], None], + out_queue_callback: Callable[['GeneralSAMState', int], None], + trie_node_id: Optional[int] = None, + ) -> TrieNode: + ... + + +class GeneralSAM: + @staticmethod + def construct_from_chars(s: str) -> 'GeneralSAM': + ... + + @staticmethod + def construct_from_bytes(s: bytes) -> 'GeneralSAM': + ... + + @staticmethod + def construct_from_trie(trie: Trie) -> 'GeneralSAM': + ... + + def is_in_str(self) -> bool: + ... + + def is_in_bytes(self) -> bool: + ... + + def num_of_nodes(self) -> int: + ... + + def get_root_state(self) -> GeneralSAMState: + ... + + def get_state(self, node_id: int) -> GeneralSAMState: + ... 
+ + def get_topo_order(self) -> Sequence[GeneralSAMState]: + ... diff --git a/general_sam/trie_utils.py b/general_sam/trie_utils.py new file mode 100644 index 0000000..1c3469d --- /dev/null +++ b/general_sam/trie_utils.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +from typing import Collection, Sequence, Tuple + +from .general_sam import Trie + + +def construct_trie_from_chars( + strings: Collection[str], +) -> Tuple[Trie, Sequence[int]]: + trie = Trie.in_chars() + node_ids = [trie.insert_chars(s) for s in strings] + return trie, node_ids + + +def construct_trie_from_bytes( + strings: Collection[bytes], +) -> Tuple[Trie, Sequence[int]]: + trie = Trie.in_bytes() + node_ids = [trie.insert_bytes(s) for s in strings] + return trie, node_ids + + +@dataclass +class CountInfo: + str_cnt: int + tot_cnt_lower: int + tot_cnt_upper: int + + +@dataclass +class SortResult: + trie: Trie + node_ids: Sequence[int] + cnt_info_on_nodes: Sequence[CountInfo] + cnt_info_on_strings: Sequence[CountInfo] + order: Sequence[int] + rank: Sequence[int] + + +def sort_chars(strings: Collection[str]) -> SortResult: + trie, node_ids = construct_trie_from_chars(strings) + return sort_seq_via_trie(trie, node_ids) + + +def sort_bytes(strings: Collection[bytes]) -> SortResult: + trie, node_ids = construct_trie_from_bytes(strings) + return sort_seq_via_trie(trie, node_ids) + + +def sort_seq_via_trie(trie: Trie, node_ids: Sequence[int]) -> SortResult: + num_of_seq = len(node_ids) + + cnt_info_on_nodes = [CountInfo(0, 0, 0) for _ in range(trie.num_of_nodes())] + for k in node_ids: + cnt_info_on_nodes[k].str_cnt += 1 + + tot_str_cnt = 0 + + def in_stack(node_id: int, _): + nonlocal tot_str_cnt + node_info = cnt_info_on_nodes[node_id] + node_info.tot_cnt_lower = tot_str_cnt + tot_str_cnt += node_info.str_cnt + + def out_stack(node_id: int): + nonlocal tot_str_cnt + node_info = cnt_info_on_nodes[node_id] + node_info.tot_cnt_upper = tot_str_cnt + + trie.dfs_travel(in_stack, out_stack) + + 
cnt_info_on_strings = [cnt_info_on_nodes[node_ids[i]] for i in range(num_of_seq)] + + order = sorted( + range(num_of_seq), + key=lambda i: cnt_info_on_strings[i].tot_cnt_lower, + ) + rank = [0] * num_of_seq + for k, i in enumerate(order): + rank[i] = k + + return SortResult( + trie=trie, + node_ids=node_ids, + cnt_info_on_nodes=cnt_info_on_nodes, + cnt_info_on_strings=cnt_info_on_strings, + order=order, + rank=rank, + ) diff --git a/general_sam/vocab_prefix.py b/general_sam/vocab_prefix.py new file mode 100644 index 0000000..dcf98af --- /dev/null +++ b/general_sam/vocab_prefix.py @@ -0,0 +1,148 @@ +import enum +from dataclasses import replace +from typing import ( + Callable, + Iterable, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) + +from .general_sam import GeneralSAM, GeneralSAMState, Trie +from .trie_utils import ( + CountInfo, + SortResult, + construct_trie_from_bytes, + construct_trie_from_chars, + sort_bytes, + sort_chars, +) + + +class VocabPrefixBytesOrChars(enum.Enum): + BYTES = enum.auto() + CHARS = enum.auto() + + +class VocabPrefixAutomaton(object): + def __init__( + self, + vocab: Iterable[Union[str, bytes]], + bytes_or_chars: Union[ + str, VocabPrefixBytesOrChars + ] = VocabPrefixBytesOrChars.CHARS, + ) -> None: + if isinstance(bytes_or_chars, str): + bytes_or_chars = getattr(VocabPrefixBytesOrChars, bytes_or_chars.upper()) + + self.bytes_or_chars = cast(VocabPrefixBytesOrChars, bytes_or_chars) + + self.vocab: Sequence[Union[str, bytes]] = list(vocab) + + if self.bytes_or_chars == VocabPrefixBytesOrChars.BYTES and isinstance( + self.vocab[0], str + ): + self.vocab = list(cast(str, i).encode() for i in self.vocab) + if self.bytes_or_chars == VocabPrefixBytesOrChars.CHARS and isinstance( + self.vocab[0], bytes + ): + self.vocab = list(cast(bytes, i).decode() for i in self.vocab) + + self.vocab_rev: Sequence[Union[str, bytes]] = list(s[::-1] for s in vocab) + + sort_seq, construct_trie = { + VocabPrefixBytesOrChars.BYTES: (sort_bytes, 
construct_trie_from_bytes), + VocabPrefixBytesOrChars.CHARS: (sort_chars, construct_trie_from_chars), + }[self.bytes_or_chars] + self.vocab_sort_res = cast(SortResult, sort_seq(self.vocab)) + self.trie_rev, self.trie_rev_node_ids = cast( + Tuple[Trie, Sequence[int]], + construct_trie(self.vocab_rev), + ) + + self.sam_rev = GeneralSAM.construct_from_trie(self.trie_rev) + self._gen_cnt_info_in_sam() + + @property + def _state_feed_fn(self) -> Callable[[GeneralSAMState, Union[bytes, str]], None]: + return { + VocabPrefixBytesOrChars.BYTES: GeneralSAMState.feed_bytes, + VocabPrefixBytesOrChars.CHARS: GeneralSAMState.feed_chars, + }[self.bytes_or_chars] + + def _gen_cnt_info_in_sam(self): + self.cnt_info_in_sam: List[Optional[CountInfo]] = [ + None for _ in range(self.sam_rev.num_of_nodes()) + ] + + for token_rev, cnt_info in zip( + self.vocab_rev, self.vocab_sort_res.cnt_info_on_strings + ): + state = self.sam_rev.get_root_state() + self._state_feed_fn(state, token_rev) + state_id = state.get_node_id() + self.cnt_info_in_sam[state_id] = replace(cnt_info, str_cnt=1) + + for sam_state in reversed(self.sam_rev.get_topo_order()): + assert not sam_state.is_nil() + if sam_state.is_root(): + continue + + state_id = sam_state.get_node_id() + state_cnt_info = self.cnt_info_in_sam[state_id] + if state_cnt_info is None: + continue + + link_id = sam_state.get_suffix_parent_id() + link_cnt_info = self.cnt_info_in_sam[link_id] + + if link_cnt_info is None: + self.cnt_info_in_sam[link_id] = replace(state_cnt_info) + continue + + link_cnt_info.str_cnt += state_cnt_info.str_cnt + link_cnt_info.tot_cnt_lower = min( + link_cnt_info.tot_cnt_lower, + state_cnt_info.tot_cnt_lower, + ) + link_cnt_info.tot_cnt_upper = max( + link_cnt_info.tot_cnt_upper, + state_cnt_info.tot_cnt_upper, + ) + + for state_id in range(self.sam_rev.num_of_nodes()): + sam_state = self.sam_rev.get_state(state_id) + state_cnt_info = self.cnt_info_in_sam[state_id] + if sam_state.is_nil() or sam_state.is_root() or 
state_cnt_info is None: + continue + + link_id = sam_state.get_suffix_parent_id() + link_cnt_info = self.cnt_info_in_sam[link_id] + + assert link_cnt_info is not None + assert link_cnt_info.tot_cnt_lower <= state_cnt_info.tot_cnt_lower + assert link_cnt_info.tot_cnt_upper >= state_cnt_info.tot_cnt_upper + + def get_root_state(self) -> GeneralSAMState: + return self.sam_rev.get_root_state() + + def prepend_feed( + self, state: GeneralSAMState, token: Union[str, bytes] + ) -> Optional[CountInfo]: + if self.bytes_or_chars == VocabPrefixBytesOrChars.BYTES and isinstance( + token, str + ): + token = token.encode() + self._state_feed_fn(state, token[::-1]) + return self.cnt_info_in_sam[state.get_node_id()] + + def get_order(self) -> Sequence[int]: + return self.vocab_sort_res.order + + def get_order_slice(self, cnt_info: CountInfo) -> Sequence[int]: + return self.vocab_sort_res.order[ + cnt_info.tot_cnt_lower : cnt_info.tot_cnt_upper + ] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..d33302d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.3,<2.0"] +build-backend = "maturin" + +[project] +name = "general_sam" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dynamic = ["version"] + +[tool.maturin] +features = ["pyo3/extension-module"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..5e8d8f9 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,429 @@ +extern crate general_sam as general_sam_rs; + +use std::{convert::Infallible, str::from_utf8, sync::Arc}; + +use either::{for_both, Either}; +use pyo3::{prelude::*, types::PyDict}; + +use general_sam_rs::{ + sam, trie, + trie_alike::{TravelEvent, TrieNodeAlike}, +}; + +#[pyclass] +struct Trie(Either, trie::Trie>); + +#[pyclass] +struct TrieNode(usize, Either, trie::Node>); + 
+#[pymethods] +impl TrieNode { + fn is_in_chars(&self) -> bool { + self.1.is_left() + } + + fn is_in_bytes(&self) -> bool { + self.1.is_right() + } + + fn get_node_id(&self) -> usize { + self.0 + } + + fn is_accepting(&self) -> bool { + for_both!(self.1.as_ref(), x => x.accept) + } + + fn get_trans(&self) -> PyObject { + Python::with_gil(|py| { + for_both!(self.1.as_ref(), x => { + x.get_trans().clone().into_py(py) + }) + }) + } + + fn get_parent(&self) -> usize { + for_both!(self.1.as_ref(), x => x.get_parent()) + } +} + +#[pymethods] +impl Trie { + #[staticmethod] + fn in_chars() -> Self { + Trie(Either::Left(trie::Trie::default())) + } + + #[staticmethod] + fn in_bytes() -> Self { + Trie(Either::Right(trie::Trie::default())) + } + + fn is_in_chars(&self) -> bool { + self.0.is_left() + } + + fn is_in_bytes(&self) -> bool { + self.0.is_right() + } + + fn num_of_nodes(&self) -> usize { + for_both!(self.0.as_ref(), x => x.num_of_nodes()) + } + + fn insert_chars(&mut self, s: &str) -> usize { + match self.0.as_mut() { + Either::Left(trie_chars) => trie_chars.insert_iter(s.chars()), + Either::Right(trie_bytes) => trie_bytes.insert_ref_iter(s.as_bytes().iter()), + } + } + + fn insert_bytes(&mut self, b: &[u8]) -> usize { + match self.0.as_mut() { + Either::Left(trie_chars) => trie_chars.insert_iter(from_utf8(b).unwrap().chars()), + Either::Right(trie_bytes) => trie_bytes.insert_ref_iter(b.iter()), + } + } + + fn get_bfs_order(&self) -> Vec { + for_both!(self.0.as_ref(), trie => { + let state = trie.get_root_state(); + let mut res = Vec::new(); + state + .bfs_travel(|event| -> Result<(), Infallible> { + if let TravelEvent::Push(s, _) = event { + res.push(s.node_id); + } + Ok(()) + }) + .unwrap(); + res + }) + } + + fn get_root(&self) -> TrieNode { + self.get_node(trie::TRIE_ROOT_NODE_ID).unwrap() + } + + fn get_node(&self, node_id: usize) -> Option { + match self.0.as_ref() { + Either::Left(trie) => trie + .get_node(node_id) + .map(|node| TrieNode(node_id, 
Either::Left(node.clone()))), + Either::Right(trie) => trie + .get_node(node_id) + .map(|node| TrieNode(node_id, Either::Right(node.clone()))), + } + } + + #[pyo3(signature = (in_stack_callback, out_stack_callback, root_node_id=None))] + fn dfs_travel( + &self, + in_stack_callback: PyObject, + out_stack_callback: PyObject, + root_node_id: Option, + ) -> Result<(), PyErr> { + for_both!(self.0.as_ref(), trie => { + let root_state = trie.get_state(root_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); + if root_state.is_nil() { + return Ok(()); + } + root_state.dfs_travel(|event| match event { + TravelEvent::Push(tn, key_opt) => Python::with_gil(|py| { + in_stack_callback.call1(py, (tn.node_id, key_opt.copied())) + }) + .map(|_| ()), + TravelEvent::Pop(tn) => { + Python::with_gil(|py| out_stack_callback.call1(py, (tn.node_id,))).map(|_| ()) + } + }) + }) + } + + #[pyo3(signature = (in_stack_callback, out_stack_callback, root_node_id=None))] + fn bfs_travel( + &self, + in_stack_callback: PyObject, + out_stack_callback: PyObject, + root_node_id: Option, + ) -> Result<(), PyErr> { + for_both!(self.0.as_ref(), trie => { + let root_state = trie.get_state(root_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); + if root_state.is_nil() { + return Ok(()); + } + root_state.bfs_travel(|event| match event { + TravelEvent::Push(tn, key_opt) => Python::with_gil(|py| { + in_stack_callback.call1(py, (tn.node_id, key_opt.copied())) + }) + .map(|_| ()), + TravelEvent::Pop(tn) => { + Python::with_gil(|py| out_stack_callback.call1(py, (tn.node_id,))).map(|_| ()) + } + }) + }) + } +} + +#[pyclass] +struct GeneralSAM(Arc, sam::GeneralSAM>>); + +#[pyclass] +#[derive(Clone)] +struct GeneralSAMState( + Arc, sam::GeneralSAM>>, + usize, +); + +impl GeneralSAMState { + fn get_state(&self) -> Either, sam::State> { + self.0 + .as_ref() + .as_ref() + .map_either(|x| x.get_state(self.1), |x| x.get_state(self.1)) + } +} + +#[pymethods] +impl GeneralSAMState { + fn is_in_chars(&self) -> bool { + 
self.0.is_left() + } + + fn is_in_bytes(&self) -> bool { + self.0.is_right() + } + + fn get_node_id(&self) -> usize { + self.1 + } + + fn is_nil(&self) -> bool { + for_both!(self.get_state().as_ref(), x => x.is_nil()) + } + + fn is_root(&self) -> bool { + for_both!(self.get_state().as_ref(), x => x.is_root()) + } + + fn is_accepting(&self) -> bool { + for_both!(self.get_state().as_ref(), x => x.is_accepting()) + } + + fn get_trans(&self) -> PyObject { + Python::with_gil(|py| { + for_both!(self.get_state().as_ref(), state => { + if let Some(node) = state.get_node() { + node.get_trans().clone().into_py(py) + } else { + PyDict::new(py).into_py(py) + } + }) + }) + } + + fn get_suffix_parent_id(&self) -> usize { + for_both!(self.get_state().as_ref() , x => { + x.get_node() + .map(|node| node.get_suffix_parent_id()) + .unwrap_or(sam::SAM_NIL_NODE_ID) + }) + } + + fn copy(&self) -> Self { + self.clone() + } + + fn goto_suffix_parent(&mut self) { + for_both!(self.get_state(), mut state => { + state.goto_suffix_parent(); + self.1 = state.node_id; + }) + } + + fn goto_char(&mut self, t: char) { + let mut state = self.get_state().left().unwrap(); + state.goto(&t); + self.1 = state.node_id; + } + + fn goto_byte(&mut self, t: u8) { + let mut state = self.get_state().right().unwrap(); + state.goto(&t); + self.1 = state.node_id; + } + + fn feed_chars(&mut self, s: &str) { + match self.get_state() { + Either::Left(state_chars) => { + let state_chars = state_chars.feed_chars(s); + self.1 = state_chars.node_id; + } + Either::Right(state_bytes) => { + let state_bytes = state_bytes.feed_ref_iter(s.as_bytes().iter()); + self.1 = state_bytes.node_id; + } + } + } + + fn feed_bytes(&mut self, s: &[u8]) { + match self.get_state() { + Either::Left(state_chars) => { + let state_chars = state_chars.feed_iter(from_utf8(s).unwrap().chars()); + self.1 = state_chars.node_id; + } + Either::Right(state_bytes) => { + let state_bytes = state_bytes.feed_ref_iter(s.iter()); + self.1 = 
state_bytes.node_id; + } + } + } + + #[pyo3(signature = (trie, in_stack_callback, out_stack_callback, trie_node_id=None))] + fn dfs_along( + &self, + trie: &Trie, + in_stack_callback: PyObject, + out_stack_callback: PyObject, + trie_node_id: Option, + ) -> Result<(), PyErr> { + assert!(trie.is_in_chars() == self.is_in_chars()); + let sam_and_trie = self.0.as_ref().as_ref().map_either( + |sam_chars| (sam_chars, trie.0.as_ref().left().unwrap()), + |sam_bytes| (sam_bytes, trie.0.as_ref().right().unwrap()), + ); + for_both!(sam_and_trie, (sam, trie) => { + let tn = trie.get_state(trie_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); + sam.dfs_along(tn, self.1, |event| match event { + TravelEvent::Push((st, tn), key_opt) => Python::with_gil(|py| { + in_stack_callback + .call1( + py, + ( + GeneralSAMState(self.0.clone(), st.node_id), + tn.node_id, + key_opt.copied(), + ), + ) + .map(|_| ()) + }) + .map(|_| ()), + TravelEvent::Pop((st, tn)) => Python::with_gil(|py| { + out_stack_callback + .call1( + py, + (GeneralSAMState(self.0.clone(), st.node_id), tn.node_id), + ) + .map(|_| ()) + }), + }) + }) + } + + #[pyo3(signature = (trie, in_stack_callback, out_stack_callback, trie_node_id=None))] + fn bfs_along( + &self, + trie: &Trie, + in_stack_callback: PyObject, + out_stack_callback: PyObject, + trie_node_id: Option, + ) -> Result<(), PyErr> { + assert!(trie.is_in_chars() == self.is_in_chars()); + let sam_and_trie = self.0.as_ref().as_ref().map_either( + |sam_chars| (sam_chars, trie.0.as_ref().left().unwrap()), + |sam_bytes| (sam_bytes, trie.0.as_ref().right().unwrap()), + ); + for_both!(sam_and_trie, (sam, trie) => { + let tn = trie.get_state(trie_node_id.unwrap_or(trie::TRIE_ROOT_NODE_ID)); + sam.bfs_along(tn, self.1, |event| match event { + TravelEvent::Push((st, tn), key_opt) => Python::with_gil(|py| { + in_stack_callback + .call1( + py, + ( + GeneralSAMState(self.0.clone(), st.node_id), + tn.node_id, + key_opt.copied(), + ), + ) + .map(|_| ()) + }) + .map(|_| ()), + 
TravelEvent::Pop((st, tn)) => Python::with_gil(|py| { + out_stack_callback + .call1( + py, + (GeneralSAMState(self.0.clone(), st.node_id), tn.node_id), + ) + .map(|_| ()) + }), + }) + }) + } +} + +#[pymethods] +impl GeneralSAM { + #[staticmethod] + fn construct_from_chars(s: &str) -> Self { + GeneralSAM(Arc::new(Either::Left( + sam::GeneralSAM::construct_from_chars(s.chars()), + ))) + } + + #[staticmethod] + fn construct_from_bytes(s: &[u8]) -> Self { + GeneralSAM(Arc::new(Either::Right( + sam::GeneralSAM::construct_from_bytes(s), + ))) + } + + #[staticmethod] + fn construct_from_trie(trie: &Trie) -> Self { + match trie.0.as_ref() { + Either::Left(trie_chars) => GeneralSAM(Arc::new(Either::Left( + sam::GeneralSAM::construct_from_trie(trie_chars.get_root_state()), + ))), + Either::Right(trie_bytes) => GeneralSAM(Arc::new(Either::Right( + sam::GeneralSAM::construct_from_trie(trie_bytes.get_root_state()), + ))), + } + } + + fn is_in_chars(&self) -> bool { + self.0.is_left() + } + + fn is_in_bytes(&self) -> bool { + self.0.is_right() + } + + fn num_of_nodes(&self) -> usize { + for_both!(self.0.as_ref(), x => x.num_of_nodes()) + } + + fn get_root_state(&self) -> GeneralSAMState { + GeneralSAMState(self.0.clone(), sam::SAM_ROOT_NODE_ID) + } + + fn get_state(&self, node_id: usize) -> GeneralSAMState { + GeneralSAMState(self.0.clone(), node_id) + } + + fn get_topo_order(&self) -> Vec { + for_both!(self.0.as_ref(), x => { + x.get_topo_order() + .map(|s| self.get_state(s.node_id)) + .collect() + }) + } +} + +#[pymodule] +fn general_sam(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} diff --git a/tests/test_general_sam.py b/tests/test_general_sam.py new file mode 100644 index 0000000..f3db3af --- /dev/null +++ b/tests/test_general_sam.py @@ -0,0 +1,47 @@ +from general_sam import GeneralSAM, GeneralSAMState, construct_trie_from_chars + + +def test_bytes_abcbc(): + sam = 
GeneralSAM.construct_from_bytes(b'abcbc') + + state = sam.get_root_state() + state.feed_bytes(b'cbc') + assert state.is_accepting() + + state = sam.get_root_state() + state.feed_bytes(b'bcb') + assert not state.is_accepting() + + +def test_chars_abcbc(): + sam = GeneralSAM.construct_from_chars('abcbc') + state = sam.get_root_state() + + state.feed_chars('b') + assert not state.is_accepting() + state.feed_chars('c') + assert state.is_accepting() + state.feed_chars('bc') + assert state.is_accepting() + state.feed_chars('bc') + assert not state.is_accepting() and state.is_nil() + + +def test_simple_sam_from_trie(): + trie, _ = construct_trie_from_chars(['hello', 'Chielo']) + sam = GeneralSAM.construct_from_trie(trie) + + def fetch_state(s: str) -> GeneralSAMState: + state = sam.get_root_state() + state.feed_chars(s) + return state + + assert fetch_state('lo').is_accepting() + assert fetch_state('ello').is_accepting() + assert fetch_state('elo').is_accepting() + + state = fetch_state('el') + assert not state.is_accepting() and not state.is_nil() + + state = fetch_state('bye') + assert not state.is_accepting() and state.is_nil() diff --git a/tests/test_token_healing.py b/tests/test_token_healing.py new file mode 100644 index 0000000..86c5ec4 --- /dev/null +++ b/tests/test_token_healing.py @@ -0,0 +1,134 @@ +from typing import Collection, Iterable, Optional, Sequence, Union + +from general_sam import ( + CountInfo, + GeneralSAMState, + VocabPrefixAutomaton, + VocabPrefixBytesOrChars, +) + + +def _test_token_healing_batch( + vocab: Collection[Union[str, bytes]], + token_sequences: Iterable[Union[Sequence[str], Sequence[bytes]]], + bytes_or_chars: VocabPrefixBytesOrChars, +): + automaton = VocabPrefixAutomaton(vocab, bytes_or_chars=bytes_or_chars) + + vocab_sorted = sorted(vocab) + + def validate( + query: Union[str, bytes], state: GeneralSAMState, cnt_info: Optional[CountInfo] + ): + import bisect + + expected_l = bisect.bisect_left( + vocab_sorted, query, key=lambda x: 
x[: len(query)] + ) + expected_r = bisect.bisect_right( + vocab_sorted, query, key=lambda x: x[: len(query)] + ) + + if expected_l < expected_r: + expected_cnt_info = CountInfo( + str_cnt=expected_r - expected_l, + tot_cnt_lower=expected_l, + tot_cnt_upper=expected_r, + ) + else: + expected_cnt_info = None + + assert cnt_info == expected_cnt_info, (query, cnt_info, expected_cnt_info) + + assert state.is_nil() ^ any(query in i for i in vocab) # pyright: ignore + + def check(tokens: Sequence[Union[str, bytes]]): + state = automaton.get_root_state() + query = '' if isinstance(tokens[0], str) else b'' + + # NOTE: tokens are prepended in the reverse order + for token in reversed(tokens): + query = token + query # pyright: ignore + cnt_info = automaton.prepend_feed(state, token) + validate(query, state, cnt_info) + + for tokens in token_sequences: + check(tokens) + + +def _test_batch( + vocab: Collection[str], + token_sequences: Iterable[Union[Sequence[str], Sequence[bytes]]], +): + _test_token_healing_batch( + vocab, + tuple(filter(lambda x: isinstance(x[0], str), token_sequences)), + VocabPrefixBytesOrChars.CHARS, + ) + _test_token_healing_batch( + tuple(i.encode() for i in vocab), + tuple( + tuple(i.encode() if isinstance(i, str) else i for i in s) + for s in token_sequences + ), + VocabPrefixBytesOrChars.BYTES, + ) + + +def test_simple_token_healing(): + _test_batch( + ['bb', 'ca', 'ab', 'c', 'aa', 'bbaa', 'a', 'cc', 'b'], + [ + ['bb', 'a'], + ['b', 'b', 'b'], + ['b', 'b', 'a'], + ['b', 'ba'], + ['ca', 'c', 'ab'], + ['c', 'c', 'c'], + ], + ) + + +def test_simple_chinese_token_healing(): + _test_batch( + ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'], + [ + ['歌曲'], + ['聆听歌曲'], + ['聆听', '歌曲'], + ['聆', '听', '歌曲'], + ['播放歌曲'], + ['播', '放歌曲'], + ['播放', '歌曲'], + ['歌词'], + ['查看歌词'], + ['查看', '歌词'], + ['听歌曲'], + ['听', '歌曲'], + ['放歌曲'], + ['听歌'], + ['放歌'], + ['词'], + ['查看'], + ['bb', 'a'], + ['b', 'b', 'b'], + ['b', 'b', 'a'], + ['b', 'ba'], + ['ca', 'c', 'ab'], + ['c', 'c', 'c'], + ], 
+ ) + + +def test_simple_utf8_token_healing(): + # '䨻'.encode('utf8') == b'\xe4\xa8\xbb' + _test_batch( + ['䨻'], + [ + ['䨻'], + [b'\xe4', b'\xa8', b'\xbb'], + [b'\xe4', b'\xa8\xbb'], + [b'\xe4\xa8', b'\xbb'], + [b'\xe4\xa8\xbb'], + ], + ) diff --git a/tests/test_vocab_prefix.py b/tests/test_vocab_prefix.py new file mode 100644 index 0000000..f741008 --- /dev/null +++ b/tests/test_vocab_prefix.py @@ -0,0 +1,74 @@ +from general_sam import VocabPrefixAutomaton, CountInfo + + +def test_chinese_chars_vocab_prefix(): + vocab = ['歌曲', '聆听歌曲', '播放歌曲', '歌词', '查看歌词'] + automaton = VocabPrefixAutomaton(vocab, bytes_or_chars='chars') + + # NOTE: CountInfo is related to the sorted vocab: + _ = ['播放歌曲', '查看歌词', '歌曲', '歌词', '聆听歌曲'] + + # 一起 | 聆 | 听 | 歌 + state = automaton.get_root_state() + + # feed 歌 + cnt_info = automaton.prepend_feed(state, '歌') + assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=2, tot_cnt_lower=2, tot_cnt_upper=4 + ) + + selected_idx = automaton.get_order_slice(cnt_info) + assert frozenset(selected_idx) == {0, 3} + selected_vocab = [vocab[i] for i in selected_idx] + assert frozenset(selected_vocab) == {'歌曲', '歌词'} + + # feed 听 + cnt_info = automaton.prepend_feed(state, '听') + assert cnt_info is None + assert not state.is_nil() + + # feed 聆 + cnt_info = automaton.prepend_feed(state, '聆') + assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=4, tot_cnt_upper=5 + ) + + selected_idx = automaton.get_order_slice(cnt_info) + assert frozenset(selected_idx) == {1} + selected_vocab = [vocab[i] for i in selected_idx] + assert frozenset(selected_vocab) == {'聆听歌曲'} + + # feed 一起 + assert not state.is_nil() + cnt_info = automaton.prepend_feed(state, '一起') + assert state.is_nil() + + # 来 | 查看 | 歌词 + state = automaton.get_root_state() + + # feed 歌词 + cnt_info = automaton.prepend_feed(state, '歌词') + assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=3, tot_cnt_upper=4 + ) + + selected_idx = 
automaton.get_order_slice(cnt_info) + assert frozenset(selected_idx) == {3} + selected_vocab = [vocab[i] for i in selected_idx] + assert frozenset(selected_vocab) == {'歌词'} + + # feed 查看 + cnt_info = automaton.prepend_feed(state, '查看') + assert cnt_info is not None and cnt_info == CountInfo( + str_cnt=1, tot_cnt_lower=1, tot_cnt_upper=2 + ) + + selected_idx = automaton.get_order_slice(cnt_info) + assert frozenset(selected_idx) == {4} + selected_vocab = [vocab[i] for i in selected_idx] + assert frozenset(selected_vocab) == {'查看歌词'} + + # feed 来 + assert not state.is_nil() + cnt_info = automaton.prepend_feed(state, '来') + assert state.is_nil()