diff --git a/scripts/lint.sh b/scripts/lint.sh index 5285a44..684be83 100755 --- a/scripts/lint.sh +++ b/scripts/lint.sh @@ -1,4 +1,4 @@ isort -l 120 -m 3 --length-sort . black --line-length 120 . -pylint $(git status -s | grep -E '\.py$' | cut -c 4-) --max-line-length 120 --disable=missing-docstring,empty-docstring,redefined-builtin +pylint $(git status -s | grep -E '\.py$' | cut -c 4-) --max-line-length 120 --disable=missing-docstring,empty-docstring,redefined-builtin,too-many-arguments,too-many-instance-attributes,too-many-locals mypy . --ignore-missing-imports diff --git a/scripts/lint_all.sh b/scripts/lint_all.sh index 624f04d..b45c3f3 100755 --- a/scripts/lint_all.sh +++ b/scripts/lint_all.sh @@ -1,4 +1,4 @@ isort -l 120 -m 3 --length-sort . black --line-length 120 . -pylint --recursive=y . --max-line-length 120 --disable=missing-docstring,empty-docstring,redefined-builtin +pylint --recursive=y . --max-line-length 120 --disable=missing-docstring,empty-docstring,redefined-builtin,too-many-arguments,too-many-instance-attributes,too-many-locals mypy . --ignore-missing-imports diff --git a/xgit/cli.py b/xgit/cli.py index 36f5732..eafc16f 100644 --- a/xgit/cli.py +++ b/xgit/cli.py @@ -1,6 +1,6 @@ import typer -from xgit.commands import init, cat_file, ls_files, show_index, hash_object +from xgit.commands import init, cat_file, ls_files, show_index, hash_object, update_index app = typer.Typer(add_completion=False, rich_markup_mode="markdown") @@ -9,6 +9,8 @@ app.command()(init.init) app.command()(cat_file.cat_file) app.command()(ls_files.ls_files) +app.command()(update_index.update_index) + app.command(hidden=True)(show_index.show_index) diff --git a/xgit/commands/update_index.py b/xgit/commands/update_index.py new file mode 100644 index 0000000..0c8e169 --- /dev/null +++ b/xgit/commands/update_index.py @@ -0,0 +1,227 @@ +import sys +import contextlib +from pathlib import Path + +import typer +from typer import Option, Argument +from typing_extensions import Optional, Annotated + +from xgit.types.index import Index, get_index +from xgit.utils.utils import find_repo + + +def update_index( + files: Annotated[Optional[list[str]], Argument(help="要更新的文件")] = None, + add: Annotated[bool, Option("--add", help="如果文件在暂存区中不存在,则加入暂存区")] = False, + remove: Annotated[bool, Option("--remove", help="如果文件在暂存区存在,但在本地不存在,则从暂存区移除")] = False, + force_remove: Annotated[bool, Option("--force-remove", help="从暂存区移除文件,即使文件在本地存在")] = False, + refresh: Annotated[bool, Option("--refresh", help="检查当前 index 中的文件是否需要 merge 或 update")] = False, + verbose: Annotated[bool, Option("--verbose", help="详细输出添加和删除的文件")] = False, + cacheinfo: Annotated[ + Optional[list[str]], Option("--cacheinfo", help="与 `--add` 一同使用,用 `,,` 指定一个 blob 加入暂存区") + ] = None, +): + """ + 更新暂存区 + """ + files = files or [] + + # 如果没有 git repo,会报错 + find_repo() + + # === 检查参数正确性 === + # add, remove, force_remove, refresh 互斥 + if add + remove + force_remove + refresh > 1: + typer.echo("fatal: only one of the options can be used", err=True) + sys.exit(1) + + # refresh 不能与 files 同时使用 + if refresh and files: + typer.echo("fatal: --refresh cannot be used with files", err=True) + sys.exit(1) + + # cacheinfo 必须与 add 同时使用 + if cacheinfo and not add: + typer.echo("fatal: --cacheinfo can only be used with --add", err=True) + sys.exit(1) + + # 如果有不在当前仓库中的文件,报错 + repo = find_repo().resolve() + for f in files: + if not Path(f).absolute().is_relative_to(repo): + typer.echo(f"fatal: '{f}' is outside repository at '{repo}'", err=True) + sys.exit(128) + + # files 中的目录需要被忽略 + dir_in_files = [f for f in files if Path(f).is_dir()] + for f in dir_in_files: + typer.echo(f"Ignoring path '{f}'", err=True) + + # 实际需要处理的 files + files = [f for f in files if f not in dir_in_files] + + # === 执行操作 === + sys.exit( + _update_index( + files=files, + add=add, + remove=remove, + force_remove=force_remove, + refresh=refresh, + verbose=verbose, + cacheinfo=cacheinfo, + ) + ) + + +def _update_index( + files: list[str], + add: bool, + remove: bool, + force_remove: bool, + refresh: bool, + verbose: bool, + cacheinfo: Optional[list[str]], +) -> int: + with working_index() as index: + if refresh: + return refresh_index(index=index) + + if not (add or remove or force_remove): + return update(index=index, files=files, add_if_absent=False, verbose=verbose) + + if add: + exit_code = update(index=index, files=files, add_if_absent=True, verbose=verbose) + if exit_code != 0: + return exit_code + return add_cacheinfo(index=index, cacheinfo=cacheinfo, verbose=verbose) + + if remove: + return remove_files(index=index, files=files, force=False, verbose=verbose) + + if force_remove: + return remove_files(index=index, files=files, force=True, verbose=verbose) + + assert False, "unreachable" + + +@contextlib.contextmanager +def working_index(): + """ + 获取 index,执行完毕后写回 + xgit 预期不会抛出异常。因此如果抛出异常,不写回 + """ + index = get_index() + yield index + index.write() + + +def refresh_index(index: Index) -> int: + """ + update-index --refresh: 刷新 index 中的 metadata,报告需要 update 的文件 + + 对于 index 中的每个 entry: + 检查其 metadata 是否与本地一致 + 如果不一致,检查文件 sha 是否一致:如果一致,更新 metadata;否则报告 needs update + + 我们暂时忽略 merge 相关的状态 + + 如果有 needs update 的文件,返回 1;否则返回 0 + """ + needs_update = index.refresh() + for f in needs_update: + typer.echo(f"{f}: needs update") + return 0 if len(needs_update) == 0 else 1 + + +def update(index: Index, files: list[str], add_if_absent: bool, verbose: bool) -> int: + for f in files: + ret_code = do_update(index=index, f=f, add_if_absent=add_if_absent) + if ret_code != 0: + typer.echo(f"fatal: Unable to process path {f}", err=True) + return ret_code + if verbose: + typer.echo(f"add '{f}'") + return 0 + + +def do_update(index: Index, f: str, add_if_absent: bool = False) -> int: + file_path = Path(f) + + if not file_path.exists(): + typer.echo(f"error: {f}: does not exist and --remove not passed", err=True) + return 128 + + if index.update(file_path, add_if_absent): + return 0 + if not add_if_absent: + typer.echo(f"error: {f}: cannot add to the index - missing --add option?", err=True) + return 128 + + assert False, "update-index --add should not fail here" + + +def add_cacheinfo(index: Index, cacheinfo: Optional[list[str]], verbose: bool) -> int: + cacheinfo = cacheinfo or [] + for cache_info in cacheinfo: + mode, obj, path = cache_info.split(",") + ret_code = do_add_cacheinfo(index=index, mode=mode, obj=obj, path=path) + + if ret_code != 0: + typer.echo(f"fatal: git update-index: --cacheinfo cannot add {path}") + return ret_code + if verbose: + typer.echo(f"add '{path}'") + + return 0 + + +def do_add_cacheinfo(index: Index, mode: str, obj: str, path: str) -> int: + """ + update-index --add --cacheinfo ,, + + 将一个 blob 加入暂存区 + """ + + def is_path_valid(path: str) -> bool: + """ + path 不应以 / 开头或结尾,不应包含连续的 //,或者单独的 . 和 .. + """ + parts = path.split("/") + return not ("" in parts or "." in parts or ".." in parts) + + def find_path_conflict(path: str) -> Optional[str]: + """ + path 不应与现有 index 产生冲突 + 具体来说,如果 index 中有 a,那么 a/b 非法,因为 a 是一个已知文件,但 a/b 暗示 a 是一个目录 + """ + parts = path.split("/") + + for i in range(1, len(parts)): + prefix = "/".join(parts[:i]) + if prefix in index.entry_paths(): + return prefix + + return None + + if not is_path_valid(path): + typer.echo(f"error: Invalid path '{path}'", err=True) + return 128 + + conflict = find_path_conflict(path) + if conflict is not None: + typer.echo(f"error: '{conflict}' appears as both a file and as a directory", err=True) + return 128 + + index.add_cacheinfo(mode, obj, path) + return 0 + + +def remove_files(index: Index, files: list[str], force: bool, verbose: bool) -> int: + for f in files: + file_path = Path(f) + if not file_path.exists() or force: + removed = index.remove(file_path) + if removed and verbose: + typer.echo(f"remove '{f}'") + return 0 diff --git a/xgit/test/test_update_index.py b/xgit/test/test_update_index.py new file mode 100644 index 0000000..e69de29 diff --git a/xgit/types/index.py b/xgit/types/index.py index d66033a..c9896a8 100644 --- a/xgit/types/index.py +++ b/xgit/types/index.py @@ -1,7 +1,10 @@ import hashlib +from enum import Enum from typing import Optional +from pathlib import Path -from xgit.utils.utils import find_repo, get_repo_file, timestamp_to_str +from xgit.utils.sha import hash_file +from xgit.utils.utils import find_repo, get_repo_file, timestamp_to_str, get_file_path_in_repo from xgit.types.metadata import Metadata from xgit.utils.constants import GIT_DIR @@ -86,6 +89,25 @@ def __init__( self.extended_flags = extended_flags self.file_name = file_name + @staticmethod + def from_file(f: Path) -> "IndexEntry": + metadata = Metadata.get_metadata(f) + sha = hash_file(str(f)) + flags = IndexEntry.Flag(False, False, 0, len(f.name)) + extended_flags = None + file_name = str(get_file_path_in_repo(f)) + return IndexEntry(metadata, sha, flags, extended_flags, file_name) + + @staticmethod + def from_cache_info(mode: str, obj: str, path: str) -> "IndexEntry": + # example: 100644,e69de29bb2d1d6434b8b29ae775ad8c2e48c5391,a + metadata = Metadata.from_cache_info(get_repo_file(path), int(mode, base=8)) + sha = obj + flags = IndexEntry.Flag(False, False, 0, len(path)) + extended_flags = None + file_name = path + return IndexEntry(metadata, sha, flags, extended_flags, file_name) + @staticmethod def parse(data: bytes) -> tuple["IndexEntry", bytes]: ctime_s = int.from_bytes(data[:4], "big") @@ -171,6 +193,30 @@ def to_bytes(self) -> bytes: return entry + class IndexEntryStatus(Enum): + UP_TO_DATE = 0 + NEEDS_UPDATE = 1 + + def refresh(self) -> IndexEntryStatus: + """ + 执行 update-index --refresh + """ + f = get_repo_file(self.file_name) + + if not f.exists(): + return IndexEntry.IndexEntryStatus.NEEDS_UPDATE + + metadata = Metadata.get_metadata(f) + if metadata == self.metadata: + return IndexEntry.IndexEntryStatus.UP_TO_DATE + + sha = hash_file(str(f)) + if sha != self.sha: + return IndexEntry.IndexEntryStatus.NEEDS_UPDATE + + self.metadata = metadata + return IndexEntry.IndexEntryStatus.UP_TO_DATE + # 以下用于 show-index 输出 verbose: bool = False @@ -220,12 +266,69 @@ def to_bytes(self) -> bytes: index += hashlib.sha1(index).digest() return index + def write(self): + with open(find_repo() / GIT_DIR / "index", "wb") as f: + print(f"write index to {f.name}") + f.write(self.to_bytes()) + def __rich_repr__(self): yield "version", self.version yield "entry_count", self.entry_count yield "entries", self.entries yield "extensions", self.extensions + def entry_paths(self) -> list[str]: + return [e.file_name for e in self.entries] + + def refresh(self) -> list[str]: + """ + 执行 update-index --refresh,返回需要 update 的文件名 + """ + needs_update = [] + for entry in self.entries: + if entry.refresh() == IndexEntry.IndexEntryStatus.NEEDS_UPDATE: + needs_update.append(entry.file_name) + return needs_update + + def update_or_add(self, entry: IndexEntry, add_if_absent: bool) -> bool: + """ + 假设 entries 有序 + 返回是否做了 update / add + """ + for i, e in enumerate(self.entries): + if e.file_name == entry.file_name: + self.entries[i] = entry + return True + if e.file_name > entry.file_name: + if add_if_absent: + self.entries.insert(i, entry) + return True + return False + + if add_if_absent: + self.entries.append(entry) + return True + return False + + def update(self, file: Path, add_if_absent: bool) -> bool: + entry = IndexEntry.from_file(file) + return self.update_or_add(entry, add_if_absent) + + def add_cacheinfo(self, mode: str, obj: str, path: str): + entry = IndexEntry.from_cache_info(mode, obj, path) + assert self.update_or_add(entry, True), "--add --cacheinfo should not fail here" + + def remove(self, file: Path) -> bool: + """ + 返回是否成功删除 + """ + file_name = get_file_path_in_repo(file) + for i, e in enumerate(self.entries): + if e.file_name == file_name: + self.entries.pop(i) + return True + return False + def get_index() -> Index: """ diff --git a/xgit/types/metadata.py b/xgit/types/metadata.py index 674577d..66c3212 100644 --- a/xgit/types/metadata.py +++ b/xgit/types/metadata.py @@ -41,6 +41,22 @@ def __init__( self.gid = gid self.file_size = file_size + @staticmethod + def from_cache_info(path, mode): + return Metadata( + path=path, + mode=mode, + ctime_s=0, + ctime_ns=0, + mtime_s=0, + mtime_ns=0, + dev=0, + inode=0, + uid=0, + gid=0, + file_size=0, + ) + def __rich_repr__(self): yield "ctime_s", self.ctime_s yield "ctime_ns", self.ctime_ns @@ -52,3 +68,30 @@ def __rich_repr__(self): yield "uid", self.uid yield "gid", self.gid yield "file_size", self.file_size + + @staticmethod + def get_metadata(path: Path): + assert path.exists() + stat = path.stat() + return Metadata( + path, + int(stat.st_ctime), + int(stat.st_ctime_ns), + int(stat.st_mtime), + int(stat.st_mtime_ns), + int(stat.st_dev), + int(stat.st_ino), + int(stat.st_mode), + int(stat.st_uid), + int(stat.st_gid), + int(stat.st_size), + ) + + def __eq__(self, __value: object) -> bool: + assert isinstance(__value, Metadata) + return ( + self.ctime_s == __value.ctime_s + and self.ctime_ns == __value.ctime_ns + and self.mtime_s == __value.mtime_s + and self.mtime_ns == __value.mtime_ns + ) diff --git a/xgit/utils/utils.py b/xgit/utils/utils.py index 5c6ee48..2198138 100644 --- a/xgit/utils/utils.py +++ b/xgit/utils/utils.py @@ -29,6 +29,14 @@ def get_repo_file(f: str) -> Path: return (find_repo() / f).resolve() +def get_file_path_in_repo(f: Path) -> str: + """ + `f` 是本地实际路径,返回相对于 repo 的路径 + """ + repo = find_repo() + return str(f.resolve().relative_to(repo)) + + def get_object(obj: str): """ 给定一个 object 的 ID (sha),返回它在 objects 中的路径