from pathlib import Path

import typer
from typer import Option
from typing_extensions import Annotated

from xgit.types.index import get_index
from xgit.utils.utils import find_repo


def ls_files(
    full_name: Annotated[bool, Option("--full-name", help="输出相对于项目根目录,而非当前目录")] = False,
):
    """
    输出 index 中在当前目录下的所有文件
    """
    # The docstring above is left untouched on purpose: Typer displays a
    # command's docstring as its help text, so it is user-facing output.
    #
    # Prints every index entry that lives under the current working directory,
    # one path per line.  With ``--full-name`` the path stored in the index is
    # printed as-is; otherwise the path is printed relative to the current
    # directory.
    index = get_index()

    # Neither value changes while the entries are listed, so resolve them once
    # instead of once per entry.
    repo_root = find_repo()
    cwd = Path.cwd()

    for entry in index.entries:
        path = repo_root / entry.file_name

        # Entries outside the current directory are not listed.
        if not path.is_relative_to(cwd):
            continue

        typer.echo(entry.file_name if full_name else path.relative_to(cwd))
from pathlib import Path
from typing import Optional

import typer
from typer import Argument, Option
from typing_extensions import Annotated

from xgit.types.index import get_index
from xgit.utils.utils import find_repo


# ``Optional`` is imported from ``typing``: ``typing_extensions`` only
# re-exports the plain ``typing`` names in recent releases, so importing it
# from there fails on older installations.
def update_index(
    files: Annotated[Optional[list[str]], Argument(help="要更新的文件")] = None,
    add: Annotated[bool, Option("--add", help="如果文件在暂存区中不存在,则加入暂存区")] = False,
    remove: Annotated[bool, Option("--remove", help="如果文件在暂存区存在,但在本地不存在,则从暂存区移除")] = False,
    force_remove: Annotated[bool, Option("--force-remove", help="从暂存区移除文件,即使文件在本地存在")] = False,
    refresh: Annotated[bool, Option("--refresh", help="刷新暂存区中的文件状态")] = False,
    verbose: Annotated[bool, Option("-v", help="详细输出添加和删除的文件")] = False,
    cacheinfo: Annotated[str, Option("--cacheinfo", help="与 `--add` 一同使用,用 `,,` 指定一个 blob 加入暂存区")] = "",
):
    """
    更新暂存区
    """
    # The docstring above is left untouched on purpose: Typer displays a
    # command's docstring as its help text, so it is user-facing output.
    #
    # NOTE(review): the command has no implementation yet -- it accepts its
    # arguments and options and returns without touching the index.
import hashlib
from dataclasses import dataclass
from typing import Optional

# Width of the name-length field inside the 16-bit entry flags; a stored value
# of 0xFFF means "the name is at least this long, scan for the NUL instead".
_NAME_LENGTH_MASK = 0x0FFF

# Bytes of an entry before the optional extended flags and the file name:
# ten 32-bit integers, a 20-byte object id and the 2-byte flags.
_FIXED_ENTRY_SIZE = 62


def print_bytes(data, group_size=4, group_each_line=6):
    """Dump *data* to stdout as grouped hex with a character row underneath.

    Debugging helper.  Every block of output covers up to
    ``group_size * group_each_line`` bytes: first a row of ``0x``-prefixed
    upper-case hex groups, then a row in which each byte in the range 32..126
    is shown as its character and every other byte as ``*``.
    """

    def is_printable(byte):
        return 32 <= byte <= 126

    byte_each_line = group_size * group_each_line

    for start in range(0, len(data), byte_each_line):
        line = data[start : start + byte_each_line]
        groups = [line[j : j + group_size] for j in range(0, len(line), group_size)]

        for group in groups:
            print("0x" + "".join(f"{byte:02x}" for byte in group).upper(), end=" ")
        print()

        for group in groups:
            print(" ", end="")
            print(" ".join(chr(byte) if is_printable(byte) else "*" for byte in group), end=" ")
        print("\n")


@dataclass(eq=False)
class IndexEntry:
    """One file record of the index, convertible to and from its binary form.

    The constructor takes the fields positionally in declaration order.
    ``eq=False`` keeps the default identity-based comparison and hashing.

    NOTE(review): the layout handled here (fixed-width fields, NUL-terminated
    name, padding to a multiple of 8 bytes) appears to match index versions 2
    and 3 only; nothing in this class checks the index version.
    """

    @dataclass(eq=False)
    class Flag:
        """The 16-bit flag word of an entry, split into its bit fields."""

        assume_valid: bool  # bit 0x8000
        extended: bool  # bit 0x4000; when set, 2 bytes of extended flags follow
        stage: int  # bits 0x3000
        name_length: int  # bits 0x0FFF

        @staticmethod
        def from_bytes(data: bytes) -> "IndexEntry.Flag":
            """Decode exactly two big-endian bytes into a ``Flag``."""
            assert len(data) == 2
            word = int.from_bytes(data, "big")
            return IndexEntry.Flag(
                word & 0x8000 != 0,
                word & 0x4000 != 0,
                (word & 0x3000) >> 12,
                word & _NAME_LENGTH_MASK,
            )

        def to_bytes(self) -> bytes:
            """Encode the flags as two big-endian bytes.

            A ``name_length`` larger than the 12-bit field is stored as 0xFFF
            so that it cannot spill into the stage/extended bits or overflow
            the two bytes; ``IndexEntry.parse`` then finds the end of the name
            by looking for the terminating NUL.
            """
            word = (self.stage << 12) & 0x3000
            if self.assume_valid:
                word |= 0x8000
            if self.extended:
                word |= 0x4000
            word |= min(self.name_length, _NAME_LENGTH_MASK)
            return word.to_bytes(2, "big")

    ctime_s: int
    ctime_ns: int
    mtime_s: int
    mtime_ns: int
    dev: int
    inode: int
    mode: int
    uid: int
    gid: int
    file_size: int
    sha: str  # object id as a hex string
    flags: Flag
    extended_flags: Optional[bytes]  # raw 2 bytes, or None when flags.extended is unset
    file_name: str

    @staticmethod
    def parse(data: bytes) -> tuple["IndexEntry", bytes]:
        """Parse the entry at the start of *data*.

        Returns the entry and the bytes that follow it (padding removed), so
        the caller can feed the remainder straight back in for the next entry.

        Raises:
            AssertionError: if the flag bytes are incomplete or the name is
                not followed by a NUL byte.
            ValueError: if a long name has no terminating NUL.
        """
        # Ten consecutive big-endian 32-bit integers: ctime_s .. file_size.
        stat_fields = [int.from_bytes(data[off : off + 4], "big") for off in range(0, 40, 4)]
        sha = data[40:60].hex()
        flags = IndexEntry.Flag.from_bytes(data[60:_FIXED_ENTRY_SIZE])

        name_start = _FIXED_ENTRY_SIZE
        extended_flags = None
        if flags.extended:
            # A 16-bit extended flag word sits between the flags and the name.
            extended_flags = data[name_start : name_start + 2]
            name_start += 2

        if flags.name_length < _NAME_LENGTH_MASK:
            name_end = name_start + flags.name_length
            assert data[name_end] == 0
        else:
            # The length field is saturated: the real length is only known
            # from the position of the terminating NUL.
            name_end = data.index(b"\x00", name_start)
        file_name = data[name_start:name_end]

        # Name plus its NUL, then rounded up to the next multiple of 8 bytes.
        entry_len = (name_end + 1 + 7) // 8 * 8
        rest = data[entry_len:]

        entry = IndexEntry(*stat_fields, sha, flags, extended_flags, file_name.decode())
        return entry, rest

    def to_bytes(self) -> bytes:
        """Serialise the entry, NUL-terminated and padded to 8 bytes."""
        stat_fields = (
            self.ctime_s,
            self.ctime_ns,
            self.mtime_s,
            self.mtime_ns,
            self.dev,
            self.inode,
            self.mode,
            self.uid,
            self.gid,
            self.file_size,
        )
        parts = [value.to_bytes(4, "big") for value in stat_fields]
        parts.append(bytes.fromhex(self.sha))
        parts.append(self.flags.to_bytes())
        if self.extended_flags is not None:
            parts.append(self.extended_flags)
        parts.append(self.file_name.encode())
        parts.append(b"\x00")

        entry = b"".join(parts)
        return entry + b"\x00" * (-len(entry) % 8)


class Index:
    """In-memory form of an index file: header, entries and raw extensions."""

    version: int
    entry_count: int  # count read from the header; to_bytes() uses len(entries)
    entries: list[IndexEntry]
    extensions: bytes  # everything between the last entry and the checksum

    def __init__(self, data: Optional[bytes] = None):
        """Parse *data* (a whole index file), or build an empty version-2 index.

        The first four bytes (the signature) and the trailing 20-byte checksum
        are not validated here; ``get_index`` checks the checksum before
        calling this.
        """
        self.entries = []
        if data is None:
            self.version = 2
            self.entry_count = 0
            # Must be set, otherwise a fresh index cannot be serialised.
            self.extensions = b""
            return

        self.version = int.from_bytes(data[4:8], "big")
        self.entry_count = int.from_bytes(data[8:12], "big")
        data = data[12:]
        for _ in range(self.entry_count):
            entry, data = IndexEntry.parse(data)
            self.entries.append(entry)
        # What is left is the extensions followed by the 20-byte checksum.
        self.extensions = data[:-20]

    def to_bytes(self) -> bytes:
        """Serialise the index, including the trailing SHA-1 of the content."""
        body = b"".join(
            [
                b"DIRC",
                self.version.to_bytes(4, "big"),
                # Count the entries actually present so the header cannot go
                # stale after ``entries`` has been modified.
                len(self.entries).to_bytes(4, "big"),
                *(entry.to_bytes() for entry in self.entries),
                self.extensions,
            ]
        )
        return body + hashlib.sha1(body).digest()


def get_index() -> Index:
    """Load the index of the repository found by ``find_repo()``.

    Returns an empty ``Index`` when the repository has no index file yet.

    Raises:
        AssertionError: if the trailing checksum does not match the content.
    """
    # Imported here rather than at module level so the parsing classes above
    # carry no import-time dependency on the repository helpers.
    from xgit.utils.constants import GIT_DIR
    from xgit.utils.utils import find_repo

    index_path = find_repo() / GIT_DIR / "index"
    if not index_path.exists():
        return Index()

    data = index_path.read_bytes()
    assert data[-20:] == hashlib.sha1(data[:-20]).digest()
    return Index(data)