diff --git a/pyproject.toml b/pyproject.toml index ac6ba41..a7dede5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "blitz-models" -version = "1.2.1" +version = "1.2.2" authors = [{ name = "Jylpah", email = "jylpah@gmail.com" }] description = "Pydantic models for Wargaming's World of Tanks Blitz game " readme = "README.md" diff --git a/src/blitzmodels/__init__.py b/src/blitzmodels/__init__.py index 791ed91..01144b1 100644 --- a/src/blitzmodels/__init__.py +++ b/src/blitzmodels/__init__.py @@ -1,5 +1,5 @@ from .config import get_config_file as get_config_file - +from .types import AccountId as AccountId, TankId as TankId from .region import Region as Region from .release import Release as Release from .account import Account as Account @@ -38,6 +38,7 @@ __all__ = [ + "types", "account", "config", "map", diff --git a/src/blitzmodels/account.py b/src/blitzmodels/account.py index a671b94..b27416b 100644 --- a/src/blitzmodels/account.py +++ b/src/blitzmodels/account.py @@ -19,6 +19,7 @@ from .region import Region from .wg_api import AccountInfo +from .types import AccountId logger = logging.getLogger() error = logger.error @@ -37,7 +38,6 @@ TypeAccountDict = dict[str, int | bool | Region | None] - # def lateinit_region() -> Region: # """Required for initializing a model w/o a 'region' field""" # raise RuntimeError("lateinit_region(): should never be called") @@ -45,7 +45,7 @@ class Account(JSONExportable, CSVExportable, TXTExportable, TXTImportable, Importable): # fmt: off - id : int = Field(alias="_id") + id : AccountId = Field(alias="_id") # lateinit is a trick to fool mypy since region is set in root_validator region : Region = Field(default=Region.bot, alias="r") last_battle_time: int = Field(default=0, alias="l") diff --git a/src/blitzmodels/tank.py b/src/blitzmodels/tank.py index 1084164..3baf3c6 100644 --- a/src/blitzmodels/tank.py +++ b/src/blitzmodels/tank.py @@ -14,6 +14,8 @@ TEXT, ) +from .types import TankId + logger = logging.getLogger() error = logger.error message = logger.warning @@ -119,7 +121,7 @@ def __str__(self) -> str: class Tank(JSONExportable, CSVExportable, TXTExportable): # fmt: off - tank_id : int = Field(default=..., alias = '_id') + tank_id : TankId = Field(default=..., alias = '_id') name : str = Field(default="") code : str | None = Field(default=None) nation : EnumNation = Field(default=EnumNation.european) diff --git a/src/blitzmodels/types.py b/src/blitzmodels/types.py new file mode 100644 index 0000000..c6eb2db --- /dev/null +++ b/src/blitzmodels/types.py @@ -0,0 +1,4 @@ +## Type aliases + +AccountId = int +TankId = int diff --git a/src/blitzmodels/wg_api.py b/src/blitzmodels/wg_api.py index cb37294..cce533a 100644 --- a/src/blitzmodels/wg_api.py +++ b/src/blitzmodels/wg_api.py @@ -5,6 +5,7 @@ TypeVar, Sequence, Tuple, + Set, Self, Type, Dict, @@ -12,7 +13,6 @@ ) from types import TracebackType import logging -from sys import path import pyarrow # type: ignore from bson import ObjectId from pydantic import ( @@ -46,18 +46,15 @@ from pyutils.utils import epoch_now from pyutils import ThrottledClientSession -# Fix relative imports -from pathlib import Path - -path.insert(0, str(Path(__file__).parent.parent.resolve())) - -from blitzmodels.region import Region # noqa: E402 -from blitzmodels.tank import ( # noqa: E402 +from .region import Region +from .tank import ( Tank, EnumNation, EnumVehicleTypeStr, EnumVehicleTier, ) +from .types import AccountId, TankId + TYPE_CHECKING = True logger = logging.getLogger() @@ -83,8 +80,7 @@ class WGApiError(JSONExportable): message: str | None field: str | None value: str | None - # TODO[pydantic]: The following keys were removed: `allow_mutation`. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. + model_config = ConfigDict( frozen=False, validate_assignment=True, populate_by_name=True ) @@ -93,8 +89,6 @@ def str(self) -> str: return f"code: {self.code} {self.message}" - - class WGTankStatAll(JSONExportable): # fmt: off battles: int = Field(..., alias="b") @@ -132,8 +126,8 @@ class WGTankStatAll(JSONExportable): class AccountInfoStats(WGTankStatAll): - max_frags_tank_id : int = Field(default=0, alias="mft") - max_xp_tank_id : int = Field(default=0, alias="mxt") + max_frags_tank_id: int = Field(default=0, alias="mft") + max_xp_tank_id: int = Field(default=0, alias="mxt") class TankStat(JSONExportable): @@ -142,8 +136,8 @@ class TankStat(JSONExportable): region: Region | None = Field(default=None, alias="r") all: WGTankStatAll = Field(..., alias="s") last_battle_time: int = Field(..., alias="lb") - account_id: int = Field(..., alias="a") - tank_id: int = Field(..., alias="t") + account_id: TankId = Field(..., alias="a") + tank_id: TankId = Field(..., alias="t") mark_of_mastery: int = Field(default=0, alias="m") battle_life_time: int = Field(default=0, alias="l") release: str | None = Field(default=None, alias="u") @@ -274,7 +268,7 @@ def arrow_schema(cls) -> pyarrow.schema: @classmethod def mk_id( - cls, account_id: int, last_battle_time: int, tank_id: int = 0 + cls, account_id: AccountId, last_battle_time: int, tank_id: TankId = 0 ) -> ObjectId: return ObjectId( hex(account_id)[2:].zfill(10) @@ -317,6 +311,7 @@ def __str__(self) -> str: tank_id={self.tank_id} \ last_battle_time={self.last_battle_time}" + ########################################### # # AccountInfo() @@ -324,7 +319,6 @@ def __str__(self) -> str: ########################################### - class AccountInfo(JSONExportable): # fmt: off account_id: int = Field(alias="id") @@ -337,7 +331,7 @@ class AccountInfo(JSONExportable): # fmt: on model_config = ConfigDict( - # arbitrary_types_allowed=True, # should this be removed? + # arbitrary_types_allowed=True, # should this be removed? frozen=False, validate_assignment=True, populate_by_name=True, @@ -405,6 +399,7 @@ def set_region(self) -> Self: "nickname": "jylpah" }""" + class WGApiWoTBlitz(JSONExportable): # fmt: off status: str = Field(default="ok", alias="s") @@ -594,8 +589,7 @@ class PlayerAchievementsMain(JSONExportable): max_series: PlayerAchievementsMaxSeries | None = Field(default=None, alias="m") account_id: int | None = Field(default=None) updated: int | None = Field(default=None) - # TODO[pydantic]: The following keys were removed: `allow_mutation`. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. + model_config = ConfigDict( frozen=False, validate_assignment=True, populate_by_name=True ) @@ -609,8 +603,7 @@ class PlayerAchievementsMain(JSONExportable): class WGApiWoTBlitzPlayerAchievements(WGApiWoTBlitz): data: dict[str, PlayerAchievementsMain] | None = Field(default=None, alias="d") - # TODO[pydantic]: The following keys were removed: `allow_mutation`. - # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information. + model_config = ConfigDict( frozen=False, validate_assignment=True, populate_by_name=True ) @@ -674,8 +667,12 @@ class WGApiWoTBlitzTankopedia(WGApiWoTBlitz): data: Dict[str, Tank] = Field(default=dict(), alias="d") codes: Dict[str, Tank] = Field(default=dict(), alias="c") + # TODO: Implement tier cache + _tier_cache: Dict[int, Set[TankId]] = dict() + _exclude_export_DB_fields = {"codes": True} _exclude_export_src_fields = {"codes": True} + model_config = ConfigDict( frozen=False, validate_assignment=True, populate_by_name=True ) @@ -684,12 +681,14 @@ class WGApiWoTBlitzTankopedia(WGApiWoTBlitz): def _validate_code(self) -> Self: if len(self.codes) == 0: self._set_skip_validation("codes", self._update_codes(data=self.data)) + if len(self._tier_cache) == 0: + self._set_skip_validation("_tier_cache", self._update_tier_cache()) return self def __len__(self) -> int: return len(self.data) - def __getitem__(self, key: str | int) -> Tank: + def __getitem__(self, key: str | TankId) -> Tank: if isinstance(key, int): key = str(key) return self.data[key] @@ -703,6 +702,15 @@ def update_count(self) -> None: self.meta = dict() self.meta["count"] = len(self.data) + def _update_tier_cache(self) -> Dict[int, Set[TankId]]: + """Update tier cache and return new cache""" + res: Dict[int, Set[TankId]] = dict() + for tier in range(1, 11): + res[tier] = set() + for tank in self.data.values(): + res[tank.tier].add(tank.tank_id) + return res + def _code_add(self, tank: Tank, codes: dict[str, Tank]) -> bool: if tank.code is not None: codes[tank.code] = tank @@ -711,16 +719,18 @@ def _code_add(self, tank: Tank, codes: dict[str, Tank]) -> bool: def add(self, tank: Tank) -> None: self.data[str(tank.tank_id)] = tank + self._tier_cache[tank.tier].add(tank.tank_id) self._code_add(tank, self.codes) self.update_count() - def pop(self, tank_id: int) -> Tank: + def pop(self, tank_id: TankId) -> Tank: """Raises KeyError if tank_id is not found in self.data""" tank: Tank = self.data.pop(str(tank_id)) self.update_count() if tank.code is not None: try: del self.codes[tank.code] + self._tier_cache[tank.tier].remove(tank.tank_id) except Exception as err: debug(f"could not remove code for tank_id={tank.tank_id}: {err}") pass @@ -751,16 +761,18 @@ def update_codes(self) -> None: """update _code dict""" self._set_skip_validation("codes", self._update_codes(self.data)) - def update_tanks(self, new: "WGApiWoTBlitzTankopedia") -> Tuple[set[int], set[int]]: + def update_tanks( + self, new: "WGApiWoTBlitzTankopedia" + ) -> Tuple[set[TankId], set[TankId]]: """update tankopedia with another one""" - new_ids: set[int] = {tank.tank_id for tank in new} - old_ids: set[int] = {tank.tank_id for tank in self} - added: set[int] = new_ids - old_ids - updated: set[int] = new_ids & old_ids + new_ids: set[TankId] = {tank.tank_id for tank in new} + old_ids: set[TankId] = {tank.tank_id for tank in self} + added: set[TankId] = new_ids - old_ids + updated: set[TankId] = new_ids & old_ids updated = {tank_id for tank_id in updated if new[tank_id] != self[tank_id]} self.data.update({(str(tank_id), new[tank_id]) for tank_id in added}) - updated_ids: set[int] = set() + updated_ids: set[TankId] = set() for tank_id in updated: if self.data[str(tank_id)].update(new[tank_id]): updated_ids.add(tank_id) @@ -768,6 +780,11 @@ def update_tanks(self, new: "WGApiWoTBlitzTankopedia") -> Tuple[set[int], set[in self.update_codes() return (added, updated_ids) + def get_tank_ids_by_tier(self, tier: int) -> Set[TankId]: + if tier < 1 or tier > 10: + raise ValueError(f"tier must be between 1-10: {tier}") + return self._tier_cache[tier] + class WGApiTankString(JSONExportable): id: int diff --git a/tests/test_tank.py b/tests/test_tank.py index 47850a7..4b4d923 100644 --- a/tests/test_tank.py +++ b/tests/test_tank.py @@ -249,6 +249,12 @@ async def test_10_WGApiTankopedia( assert ( False ), f"Parsing test file List[Tank] failed: {basename(tanks_json_fn)}" + N: int = 0 + for tier in range(1, 11): + N += len(tankopedia.get_tank_ids_by_tier(tier=tier)) + assert N == len( + tankopedia + ), f"incorrect number of tanks in the tier cache: {N} != {len(tankopedia)}" for tank in tanks_json: tankopedia.add(tank) debug("read %d tanks", len(tankopedia.data)) @@ -296,9 +302,7 @@ async def test_11_WGApiTankopedia( await file.read() ) except Exception: - assert ( - False - ), f"Parsing test file WGApiWoTBlitzTankopedia() failed: {basename(tankopedia_fn)}" + assert False, f"Parsing test file WGApiWoTBlitzTankopedia() failed: {basename(tankopedia_fn)}" debug("read %d tanks", len(tankopedia.data)) assert tankopedia.meta is not None, "Failed to update meta" @@ -309,9 +313,7 @@ async def test_11_WGApiTankopedia( len(tankopedia.data) == tankopedia_tanks ), f"could not import all the tanks: got {tankopedia.data}, should be {tankopedia_tanks}" - assert ( - tankopedia.has_codes - ), f"could not generate all the codes: tanks={len(tankopedia.data)}, codes={len(tankopedia.codes)}" + assert tankopedia.has_codes, f"could not generate all the codes: tanks={len(tankopedia.data)}, codes={len(tankopedia.codes)}" # test tankopedia export import tankopedia_file: str = f"{tmp_path.resolve()}/tankopedia.json" try: @@ -347,9 +349,7 @@ async def test_12_WGApiTankopedia_sorted( await file.read() ) except Exception: - assert ( - False - ), f"Parsing test file WGApiWoTBlitzTankopedia() failed: {basename(tankopedia_fn)}" + assert False, f"Parsing test file WGApiWoTBlitzTankopedia() failed: {basename(tankopedia_fn)}" debug("read %d tanks", len(tankopedia.data)) tanks: list[Tank] = list() diff --git a/tests/test_wgapi.py b/tests/test_wgapi.py index 985aef0..2bc576e 100644 --- a/tests/test_wgapi.py +++ b/tests/test_wgapi.py @@ -4,8 +4,11 @@ import json from bson import ObjectId from typing import Dict, List -from blitzmodels import Account, Region, WGApi, AccountInfo from blitzmodels import ( + Account, + Region, + WGApi, + AccountInfo, PlayerAchievementsMaxSeries, TankStat, WGApiWoTBlitzTankopedia, @@ -13,6 +16,7 @@ WGApiTankString, ) + logger = logging.getLogger() error = logger.error message = logger.warning @@ -163,7 +167,9 @@ def tanks_updated() -> list[Tank]: @pytest.mark.asyncio @ACCOUNTS async def test_1_api_account_info(datafiles: Path) -> None: - assert (acc_info := AccountInfo.example_instance()) is not None, "AccountInfo.example_instance() failed" + assert ( + acc_info := AccountInfo.example_instance() + ) is not None, "AccountInfo.example_instance() failed" async with WGApi() as wg: for account_fn in datafiles.iterdir(): accounts: Dict[int, Account] = dict() @@ -285,9 +291,7 @@ def test_5_player_achievements() -> None: assert pa.indexes["account_id"] == 521458531, "indexes @property failed" assert "account_id" in f"{pa}", "'account_id' not found in str(TankStats)" except Exception as err: - assert ( - False - ), f"Could not validate PlayerAchievementsMaxSeries example instance: {type(err)}: {err}" + assert False, f"Could not validate PlayerAchievementsMaxSeries example instance: {type(err)}: {err}" assert ( len(PlayerAchievementsMaxSeries.backend_indexes()) > 0 @@ -313,10 +317,8 @@ async def test_6_api_tankopedia( assert len(tankopedia) > 0, "API returned empty tankopedia" assert ( - tankopedia := await wg.get_tankopedia() - ) is not None, ( - "could not fetch tankopedia from WG API from (default server = eu)" - ) + (tankopedia := await wg.get_tankopedia()) is not None + ), "could not fetch tankopedia from WG API from (default server = eu)" for tank_id in tanks_remove: tankopedia.pop(tank_id) @@ -329,11 +331,11 @@ async def test_6_api_tankopedia( (added, updated) = tankopedia.update_tanks(tankopedia_new) - assert len(added) == len( - tanks_remove + assert ( + len(added) == len(tanks_remove) ), f"incorrect number of added tanks reported {len(added) } != {len(tanks_remove)}" - assert len(updated) == len( - tanks_updated + assert ( + len(updated) == len(tanks_updated) ), f"incorrect number of updated tanks reported {len(updated) } != {len(tanks_updated)}" @@ -355,9 +357,7 @@ async def test_7_api_tankstrs( False ), f"failed to parse test file as WGApiTankString(): {fn.name}: {err}" if (tank := Tank.transform(tank_str)) is None: - assert ( - False - ), f"could not transform WGApiTankString() to Tank(): {tank_str.user_string}" + assert False, f"could not transform WGApiTankString() to Tank(): {tank_str.user_string}" async with WGApi() as wg: for user_str in wgapi_tankstrs_user_strings: