Skip to content

Commit

Permalink
refactor: Type hints cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
SilverRainZ committed Sep 2, 2023
1 parent 4dc9952 commit 38a1fb3
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 80 deletions.
13 changes: 5 additions & 8 deletions src/sphinxnotes/any/directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from typing import Type

from docutils import nodes
from docutils.nodes import Node, Element, fully_normalize_name
from docutils.statemachine import StringList
from docutils.parsers.rst import directives
from docutils.nodes import fully_normalize_name

from sphinx import addnodes
from sphinx.util.docutils import SphinxDirective
Expand Down Expand Up @@ -81,10 +81,7 @@ def _build_object(self) -> Object:
content='\n'.join(list(self.content.data)))


def _setup_nodes(self, obj:Object,
sectnode:nodes.Element,
ahrnode:nodes.Element|None,
contnode:nodes.Element) -> None:
def _setup_nodes(self, obj:Object, sectnode:Element, ahrnode:Element|None, contnode:Element) -> None:
"""
Attach necessary informations to nodes and note them.
Expand Down Expand Up @@ -127,7 +124,7 @@ def _setup_nodes(self, obj:Object,
contnode)


def _run_section(self, obj:Object) -> list[nodes.Node]:
def _run_section(self, obj:Object) -> list[Node]:
# Get the title of the "section" where the directive is located
sectnode = self.state.parent
titlenode = sectnode.next_node(nodes.title)
Expand Down Expand Up @@ -156,7 +153,7 @@ def _run_section(self, obj:Object) -> list[nodes.Node]:
return []


def _run_objdesc(self, obj:Object) -> list[nodes.Node]:
def _run_objdesc(self, obj:Object) -> list[Node]:
descnode = addnodes.desc()

# Generate signature node
Expand All @@ -179,7 +176,7 @@ def _run_objdesc(self, obj:Object) -> list[nodes.Node]:
return [descnode]


def run(self) -> list[nodes.Node]:
def run(self) -> list[Node]:
obj = self._build_object()
if self.schema.title_of(obj) == '_':
# If no argument is given, or the first argument is '_',
Expand Down
26 changes: 13 additions & 13 deletions src/sphinxnotes/any/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""

from __future__ import annotations
from typing import Tuple, Any, Iterator, Type, Set, Optional, TYPE_CHECKING
from typing import Any, Iterator, TYPE_CHECKING

from docutils.nodes import Element, literal, Text

Expand Down Expand Up @@ -44,14 +44,14 @@ class AnyDomain(Domain):
#: Type (usually directive) name -> ObjType instance
object_types:dict[str,ObjType]= {}
#: Directive name -> directive class
directives:dict[str,Type[AnyDirective]] = {}
directives:dict[str,type[AnyDirective]] = {}
#: Role name -> role callable
roles:dict[str,RoleFunction] = {}
#: A list of Index subclasses
indices:list[Type[AnyIndex]] = []
#: AnyDomain specific: Type -> index class
_indices_for_reftype:dict[str,Type[AnyIndex]] = {}
#: AnyDomain specific: Type -> Schema instance
indices:list[type[AnyIndex]] = []
#: AnyDomain specific: type -> index class
_indices_for_reftype:dict[str,type[AnyIndex]] = {}
#: AnyDomain specific: type -> Schema instance
_schemas:dict[str,Schema] = {}

initial_data:dict[str,Any] = {
Expand All @@ -62,12 +62,12 @@ class AnyDomain(Domain):
}

@property
def objects(self) -> dict[Tuple[str,str], Tuple[str,str,Object]]:
def objects(self) -> dict[tuple[str,str], tuple[str,str,Object]]:
"""(objtype, objid) -> (docname, anchor, obj)"""
return self.data.setdefault('objects', {})

@property
def references(self) -> dict[Tuple[str,str,str],Set[str]]:
def references(self) -> dict[tuple[str,str,str],set[str]]:
"""(objtype, objfield, objref) -> set(objid)"""
return self.data.setdefault('references', {})

Expand Down Expand Up @@ -105,7 +105,7 @@ def clear_doc(self, docname:str) -> None:
def resolve_xref(self, env:BuildEnvironment, fromdocname:str,
builder:Builder, typ:str, target:str,
node:pending_xref, contnode:Element,
) -> Optional[Element]:
) -> Element|None:
assert isinstance(contnode, literal)

logger.debug('[any] resolveing xref of %s', (typ, target))
Expand Down Expand Up @@ -151,7 +151,7 @@ def resolve_xref(self, env:BuildEnvironment, fromdocname:str,


# Override parent method
def get_objects(self) -> Iterator[Tuple[str, str, str, str, str, int]]:
def get_objects(self) -> Iterator[tuple[str, str, str, str, str, int]]:
for (objtype, objid), (docname, anchor, _) in self.data['objects'].items():
yield objid, objid, objtype, docname, anchor, 1

Expand Down Expand Up @@ -186,7 +186,7 @@ def add_schema(cls, schema:Schema) -> None:
cls._indices_for_reftype[r] = index


def _get_index_anchor(self, reftype:str, refval:str) -> Tuple[str,str]:
def _get_index_anchor(self, reftype:str, refval:str) -> tuple[str,str]:
"""
Return the docname and anchor name of index page. Can be used for ``make_refnode()``.
Expand All @@ -198,7 +198,7 @@ def _get_index_anchor(self, reftype:str, refval:str) -> Tuple[str,str]:


def warn_missing_reference(app: Sphinx, domain: Domain, node: pending_xref
) -> Optional[bool]:
) -> bool|None:
if domain and domain.name != AnyDomain.name:
return None

Expand All @@ -210,7 +210,7 @@ def warn_missing_reference(app: Sphinx, domain: Domain, node: pending_xref
return True


def reftype_to_objtype_and_objfield(reftype:str) -> Tuple[str,Optional[str]]:
def reftype_to_objtype_and_objfield(reftype:str) -> tuple[str,str|None]:
"""Helper function for converting reftype(role name) to object infos."""
v = reftype.split('.', maxsplit=1)
return v[0], v[1] if len(v) == 2 else None
Expand Down
14 changes: 7 additions & 7 deletions src/sphinxnotes/any/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
:copyright: Copyright 2021 Shengyu Zhang
:license: BSD, see LICENSE for details.
"""
from typing import Iterable, List, Tuple, Dict, Set, Optional, Type
from typing import Iterable
from sphinx.domains import Index, IndexEntry

from .schema import Schema
Expand All @@ -19,14 +19,14 @@ class AnyIndex(Index):

schema:Schema
# TODO: document
field:Optional[str] = None
field:str|None = None

name:str
localname:str
shortname:str

@classmethod
def derive(cls, schema:Schema, field:str|None=None) -> Type["AnyIndex"]:
def derive(cls, schema:Schema, field:str|None=None) -> type["AnyIndex"]:
"""Generate an AnyIndex child class for indexing object."""
if field:
typ = f'Any{schema.objtype.title()}{field.title()}Index'
Expand All @@ -45,14 +45,14 @@ def derive(cls, schema:Schema, field:str|None=None) -> Type["AnyIndex"]:


def generate(self, docnames:Iterable[str]|None = None
) -> Tuple[List[Tuple[str,List[IndexEntry]]], bool]:
) -> tuple[list[tuple[str,list[IndexEntry]]], bool]:
"""Override parent method."""
content = {} # type: Dict[str, List[IndexEntry]]
# List of all references
content = {} # type: dict[str, list[IndexEntry]]
# list of all references
objrefs = sorted(self.domain.data['references'].items())

# Reference value -> object IDs
objs_with_same_ref:Dict[str,Set[str]] = {}
objs_with_same_ref:dict[str,set[str]] = {}

for (objtype, objfield, objref), objids in objrefs:
if objtype != self.schema.objtype:
Expand Down
25 changes: 0 additions & 25 deletions src/sphinxnotes/any/perset.py

This file was deleted.

3 changes: 1 addition & 2 deletions src/sphinxnotes/any/roles.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
:license: BSD, see LICENSE for details.
"""
from __future__ import annotations
from typing import Type

from sphinx.util import logging
from sphinx.roles import XRefRole
Expand All @@ -29,7 +28,7 @@ class AnyRole(XRefRole):
schema:Schema

@classmethod
def derive(cls, schema:Schema, field:str|None=None) -> Type["AnyRole"]:
def derive(cls, schema:Schema, field:str|None=None) -> type["AnyRole"]:
"""Generate an AnyRole child class for referencing object."""
return type('Any%s%sRole' % (schema.objtype.title(), field.title() if field else ''),
(cls,),
Expand Down
44 changes: 22 additions & 22 deletions src/sphinxnotes/any/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
:copyright: Copyright 2021 Shengyu Zhang
:license: BSD, see LICENSE for details.
"""
from typing import Tuple, Dict, Iterator, List, Set, Optional, Union, Any
from typing import Iterator, Any
from enum import Enum, auto
from dataclasses import dataclass
import pickle
Expand All @@ -32,7 +32,7 @@ class SchemaError(AnyExtensionError):
class Object(object):
objtype:str
name:str
attrs:Dict[str,str]
attrs:dict[str,str]
content:str

def hexdigest(self) -> str:
Expand Down Expand Up @@ -90,20 +90,20 @@ def _as_plain(self, rawval:str) -> str:
return rawval


def _as_words(self, rawval:str) -> List[str]:
def _as_words(self, rawval:str) -> list[str]:
assert self.form == self.Form.WORDS
assert rawval is not None
return [x.strip() for x in rawval.split(' ') if x.strip() != '']


def _as_lines(self, rawval:str) -> List[str]:
def _as_lines(self, rawval:str) -> list[str]:
assert self.form == self.Form.LINES
assert rawval is not None
return rawval.split('\n')



def value_of(self, rawval:Optional[str]) -> Union[None,str,List[str]]:
def value_of(self, rawval:str|None) -> None|str|list[str]:
if rawval is None:
assert not self.required
return None
Expand Down Expand Up @@ -139,7 +139,7 @@ class Schema(object):

# Object fields
name:Field
attrs:Dict[str,Field]
attrs:dict[str,Field]
content:Field

# Class-wide shared template environment
Expand All @@ -160,9 +160,9 @@ class Schema(object):
ambiguous_reference_template:str

def __init__(self, objtype:str,
name:Optional[Field]=Field(unique=True, referenceable=True),
attrs:Dict[str,Field]={},
content:Optional[Field]=Field(),
name:Field|None=Field(unique=True, referenceable=True),
attrs:dict[str,Field]={},
content:Field|None=Field(),
description_template:str='{{ content }}',
reference_template:str='{{ title }}',
missing_reference_template:str='{{ title }} (missing reference)',
Expand Down Expand Up @@ -199,7 +199,7 @@ def __init__(self, objtype:str,
has_unique = field.unique


def object(self, name:Optional[str], attrs:Dict[str,str], content:Optional[str]) -> Object:
def object(self, name:str|None, attrs:dict[str,str], content:str|None) -> Object:
"""Generate a object"""
obj = Object(objtype=self.objtype,
name=name,
Expand All @@ -211,11 +211,11 @@ def object(self, name:Optional[str], attrs:Dict[str,str], content:Optional[str])
return obj


def fields_of(self, obj:Object) -> Iterator[Tuple[str,Field,Union[None,str,List[str]]]]:
def fields_of(self, obj:Object) -> Iterator[tuple[str,Field,None|str|list[str]]]:
"""
Helper method for returning all fields of object and its raw values.
-> Iterator[field_name, field_instance, field_value],
while the field_value is Union[string_value, string_list_value].
while the field_value is string_value|string_list_value.
"""
if self.name:
yield (self.NAME_KEY, self.name, self.name.value_of(obj.name) if obj else None)
Expand All @@ -225,22 +225,22 @@ def fields_of(self, obj:Object) -> Iterator[Tuple[str,Field,Union[None,str,List[
yield (self.CONTENT_KEY, self.content, self.content.value_of(obj.content) if obj else None)


def name_of(self, obj:Object) -> Union[None,str,List[str]]:
def name_of(self, obj:Object) -> None|str|list[str]:
assert obj
return self.content.value_of(obj.name)


def attrs_of(self, obj:Object) -> Dict[str,Union[None,str,List[str]]]:
def attrs_of(self, obj:Object) -> dict[str,None|str|list[str]]:
assert obj
return {k: f.value_of(obj.attrs.get(k)) for k, f in self.attrs.items()}


def content_of(self, obj:Object) -> Union[None,str,List[str]]:
def content_of(self, obj:Object) -> None|str|list[str]:
assert obj
return self.content.value_of(obj.content)


def identifier_of(self, obj:Object) -> Tuple[Optional[str],str]:
def identifier_of(self, obj:Object) -> tuple[str|None,str]:
"""
Return unique identifier of object.
If there is not any unique field, return (None, obj.hexdigest()) instead.
Expand All @@ -259,7 +259,7 @@ def identifier_of(self, obj:Object) -> Tuple[Optional[str],str]:
return None, obj.hexdigest()


def title_of(self, obj:Object) -> Optional[str]:
def title_of(self, obj:Object) -> str|None:
"""Return title (display name) of object."""
assert obj
name = self.name.value_of(obj.name)
Expand All @@ -271,7 +271,7 @@ def title_of(self, obj:Object) -> Optional[str]:
return None


def references_of(self, obj:Object) -> Set[Tuple[str,str]]:
def references_of(self, obj:Object) -> set[tuple[str,str]]:
"""Return all references (referenceable fields) of object"""
assert obj
refs = []
Expand All @@ -287,16 +287,16 @@ def references_of(self, obj:Object) -> Set[Tuple[str,str]]:
return set(refs)


def _context_without_object(self) -> Dict[str,Union[str,List[str]]]:
def _context_without_object(self) -> dict[str,str|list[str]]:
return {
self.TYPE_KEY: self.objtype,
}


def _context_of(self, obj:Object) -> Dict[str,Union[str,List[str]]]:
def _context_of(self, obj:Object) -> dict[str,str|list[str]]:
context = self._context_without_object()

def set_if_not_none(key:str, val:Union[str,List[str]]) -> None:
def set_if_not_none(key, val) -> None:
if val is not None:
context[key] = val
set_if_not_none(self.NAME_KEY, self.name_of(obj))
Expand All @@ -308,7 +308,7 @@ def set_if_not_none(key:str, val:Union[str,List[str]]) -> None:
return context


def render_description(self, obj:Object) -> List[str]:
def render_description(self, obj:Object) -> list[str]:
assert obj
tmpl = TemplateEnvironment().from_string(self.description_template)
description = tmpl.render(self._context_of(obj))
Expand Down
Loading

0 comments on commit 38a1fb3

Please sign in to comment.