Skip to content

Commit

Permalink
ensure right ignore for the docs
Browse files Browse the repository at this point in the history
  • Loading branch information
liyin2015 committed May 20, 2024
1 parent 4c2e205 commit 104d57e
Show file tree
Hide file tree
Showing 15 changed files with 139 additions and 98 deletions.
5 changes: 1 addition & 4 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@ docs/source/
llama3/
docs/build/
docs/_build/
docs/source/_static
docs/source/_templates
docs/source/apis/tests*.rst
tests/log_test/*.log
docs/source/documents/tests*.rst
docs/source/documents/
docts/source/apis/*
lib/
docs/source/apis/tests*.rst
li_test/
.mypy_cache/
.pytest_cache/
Expand Down
4 changes: 2 additions & 2 deletions components/agent/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,9 @@ def call(
self.reset()
return answer

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
s = f"tools={self.tools}, max_steps={self.max_steps}, "
s += super().extra_repr()
s += super()._extra_repr()
return s


Expand Down
2 changes: 1 addition & 1 deletion components/retriever/faiss_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,6 @@ def __call__(
response = self.retrieve(query_or_queries=query_or_queries, top_k=top_k)
return response

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
s = f"top_k={self.top_k}, dimensions={self.dimensions}, "
return s
48 changes: 2 additions & 46 deletions core/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,7 @@ class APIClient(Component):

def __init__(self, *args, **kwargs) -> None:
super().__init__()
# take the *args and **kwargs to be compatible with the Component class
# comvert args to attributes
for i, arg in enumerate(args):
super().__setattr__(f"arg_{i}", arg)
# convert kwargs to attributes
for key, value in kwargs.items():
super().__setattr__(key, value)

# TODO: recheck to see if we need to initialize the client here

self.sync_client = self._init_sync_client()
self.async_client = None

Expand Down Expand Up @@ -72,7 +64,7 @@ def convert_input_to_api_kwargs(
raise NotImplementedError(
f"{type(self).__name__} must implement _combine_input_and_model_kwargs method"
)

def parse_chat_completion(self, completion: Any) -> str:
r"""
Parse the chat completion to a structure your sytem standarizes. (here is str)
Expand All @@ -81,7 +73,6 @@ def parse_chat_completion(self, completion: Any) -> str:
f"{type(self).__name__} must implement parse_chat_completion method"
)


@staticmethod
def _process_text(text: str) -> str:
"""
Expand All @@ -90,15 +81,6 @@ def _process_text(text: str) -> str:
text = text.replace("\n", " ")
return text

# def format_input(self, *, input: Any) -> Any:
# """
# This is specific to APIClient.
# # convert your component input to the API-specific format
# """
# raise NotImplementedError(
# f"{type(self).__name__} must implement format_input method"
# )

def _track_usage(self, **kwargs):
pass

Expand All @@ -112,29 +94,3 @@ def _track_usage(self, **kwargs):

def __call__(self, *args, **kwargs):
return super().__call__(*args, **kwargs)

# def call(
# self,
# input: Any,
# model_kwargs: dict = {},
# model_type: ModelType = ModelType.UNDEFINED,
# ) -> Any:
# # adapt the format and the key for input and model_kwargs
# combined_model_kwargs = self._combine_input_and_model_kwargs(
# input, model_kwargs, model_type=model_type
# )
# return self._call(api_kwargs=combined_model_kwargs, model_type=model_type)

# async def acall(
# self,
# *,
# input: Any,
# model_kwargs: dict = {},
# model_type: ModelType = ModelType.UNDEFINED,
# ) -> Any:
# combined_model_kwargs = self._combine_input_and_model_kwargs(
# input, model_kwargs, model_type=model_type
# )
# return await self._acall(
# api_kwargs=combined_model_kwargs, model_type=model_type
# )
6 changes: 3 additions & 3 deletions core/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _call_unimplemented(self, *input: Any) -> None:

class Component:
r"""
Component defines all functional base classes such as Embedder, Retriever, Generator.
Component is the base class for all LightRAG components, such as Prompt, APIClient, Embedder, Retriever, Generator, etc.
We purposly avoid using the name "Module" to avoid confusion with PyTorch's nn.Module.
As we consider 'Component' to be an extension to 'Moduble' as if you use a local llm model
Expand Down Expand Up @@ -118,7 +118,7 @@ def __delattr__(self, name: str) -> None:
else:
super().__delattr__(name)

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
"""
Normally implemented by subcomponents to print additional positional or keyword arguments.
# NOTE: Dont add components as it will have its own __repr__
Expand All @@ -132,7 +132,7 @@ def _get_name(self):
def __repr__(self):
# We treat the extra repr like the sub-module, one item per line
extra_lines = []
extra_repr = self.extra_repr()
extra_repr = self._extra_repr()
# empty string will be split into list ['']
if extra_repr:
extra_lines = extra_repr.split("\n")
Expand Down
2 changes: 1 addition & 1 deletion core/data_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ def __call__(
retriever_output=input, deduplicate=self.deduplicate
)

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
s = f"deduplicate={self.deduplicate}"
return s
8 changes: 4 additions & 4 deletions core/default_prompt_template.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
"""
See the Promp in developer notes for why we design the system prompt this way.
"""

DEFAULT_LIGHTRAG_SYSTEM_PROMPT = r"""{# task desc #}
{% if task_desc_str %}
{{task_desc_str}}
Expand Down Expand Up @@ -37,3 +33,7 @@
</STEPS>
{% endif %}
"""
"""This is the default system prompt template used in the LightRAG.
Use :ref:`Prompt<core-prompt_builder>` class to manage it.
"""
2 changes: 1 addition & 1 deletion core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def update_default_model_kwargs(self, **model_kwargs) -> Dict:
def print_prompt(self, **kwargs) -> str:
self.system_prompt.print_prompt(**kwargs)

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
s = f"model_kwargs={self.model_kwargs}, model_type={self.model_type}"
return s

Expand Down
72 changes: 53 additions & 19 deletions core/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from core.component import Component

from core.default_prompt_template import DEFAULT_LIGHTRAG_SYSTEM_PROMPT
import logging

logger = logging.getLogger(__name__)


# cache the environment for faster template rendering
@lru_cache(None)
def get_jinja2_environment():
r"""Helper function for Prompt component to get the Jinja2 environment with the default settings."""
try:
default_environment = Environment(
undefined=jinja2.StrictUndefined,
Expand All @@ -27,11 +30,35 @@ def get_jinja2_environment():


class Prompt(Component):
"""
A component that renders a text string from a template using Jinja2 templates.
As inherited from component, it is highly flexible, it can have
other subcomponents which might do things like query expansion, document retrieval if you prefer
to have it here.
__doc__ = r"""A component that renders a text string from a template using Jinja2 templates.
In default, we use the :ref:`DEFAULT_LIGHTRAG_SYSTEM_PROMPT<core-default_prompt_template>` as the template.
Args:
template (str, optional): The Jinja2 template string. Defaults to DEFAULT_LIGHTRAG_SYSTEM_PROMPT.
preset_prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to {}.
Examples:
>>> from core.prompt_builder import Prompt
>>> prompt = Prompt(preset_prompt_kwargs={"task_desc_str": "You are a helpful assistant."})
>>> print(prompt)
>>> prompt.print_prompt_template()
>>> prompt.print_prompt(context_str="This is a context string.")
>>> prompt.call(context_str="This is a context string.")
When examples_str itself is another template with variables, You can use another Prompt to render it.
>>> EXAMPLES_TEMPLATE = r'''
>>> {% if examples %}
>>> {% for example in examples %}
>>> {{loop.index}}. {{example}}
>>> {% endfor %}
>>> {% endif %}
>>> '''
>>> examples_prompt = Prompt(template=EXAMPLES_TEMPLATE)
>>> examples_str = examples_prompt.call(examples=["Example 1", "Example 2"])
>>> # pass it to the main prompt
>>> prompt.print_prompt(examples_str=examples_str)
"""

def __init__(
Expand All @@ -40,6 +67,7 @@ def __init__(
template: str = DEFAULT_LIGHTRAG_SYSTEM_PROMPT,
preset_prompt_kwargs: Optional[Dict] = {},
):

super().__init__()
self._template_string = template
self.template: Template = None
Expand All @@ -55,9 +83,15 @@ def __init__(
self.preset_prompt_kwargs = preset_prompt_kwargs

def update_preset_prompt_kwargs(self, **kwargs):
r"""Update the preset prompt kwargs after Prompt is initialized."""
self.preset_prompt_kwargs.update(kwargs)

def get_prompt_kwargs(self) -> Dict:
r"""Get the prompt kwargs."""
return self.prompt_kwargs

def is_key_in_template(self, key: str) -> bool:
r"""Check if the key exists in the template."""
return key in self.prompt_kwargs

def _find_template_variables(self, template_str: str):
Expand All @@ -66,19 +100,26 @@ def _find_template_variables(self, template_str: str):
return jinja2.meta.find_undeclared_variables(parsed_content)

def compose_prompt_kwargs(self, **kwargs) -> Dict:
r"""Compose the final prompt kwargs by combining the preset_prompt_kwargs and the provided kwargs."""
composed_kwargs = self.prompt_kwargs.copy()
if self.preset_prompt_kwargs:
composed_kwargs.update(self.preset_prompt_kwargs)
# runtime kwargs will overwrite the preset kwargs
if kwargs:
for key, _ in kwargs.items():
if key not in composed_kwargs:
logger.warning(f"Key {key} does not exist in the prompt_kwargs.")
composed_kwargs.update(kwargs)
return composed_kwargs

def print_prompt_template(self):
r"""Print the template string."""
print("Template:")
print(f"-------")
print(f"{self._template_string}")
print(f"-------")

def print_prompt(self, **kwargs):
r"""To better visualize the prompt: as close as the final prompt string.
For task-specific variables, such as task_desc_str, tools_str, we replace the them with the actual values from the preset_prompt_kwargs.
For per-query variables such as query_str, chat_history_str, we leave it as it is in the template using the custom filter none_filter.
"""
r"""Print the rendered prompt string using the preset_prompt_kwargs and the provided kwargs."""
try:
pass_kwargs = self.compose_prompt_kwargs(**kwargs)

Expand All @@ -88,16 +129,9 @@ def print_prompt(self, **kwargs):
except Exception as e:
raise ValueError(f"Error rendering Jinja2 template: {e}")

def print_prompt_template(self):
print("Template:")
print(f"-------")
print(f"{self._template_string}")
print(f"-------")

def call(self, **kwargs) -> str:
"""
Renders the prompt template with the provided variables.
TODO: if there are submodules,
"""
try:
pass_kwargs = self.compose_prompt_kwargs(**kwargs)
Expand All @@ -108,7 +142,7 @@ def call(self, **kwargs) -> str:
except Exception as e:
raise ValueError(f"Error rendering Jinja2 template: {e}")

def extra_repr(self) -> str:
def _extra_repr(self) -> str:
s = f"template: {self._template_string}"
if self.preset_prompt_kwargs:
s += f", preset_prompt_kwargs: {self.preset_prompt_kwargs}"
Expand Down
27 changes: 19 additions & 8 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
exclude_patterns = ["tests", "test_*"]

# exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '**.module.rst', '**/tests/*', '**/test_*.py', '*test.rst']

Expand Down Expand Up @@ -91,16 +91,27 @@ def setup(app):
# relevant when documenting Python modules and their contents, such as classes, functions, and methods.
add_module_names = False


autodoc_docstring_signature = True

# autodoc_default_options = {
# "autosummary-no-titles": True,
# "autosummary-force-inline": True,
# "autosummary-nosignatures": True,
# "members": True,
# "private-members": False, # (those starting with _).
# "special-members": False, # (those starting and ending with __).
# "member-order": "bysource",
# "show-inheritance": True,
# # "undoc-members": True,
# "autosectionlabel_prefix_document": True,
# }
autodoc_default_options = {
"autosummary-no-titles": True,
"autosummary-force-inline": True,
"autosummary-nosignatures": True,
"members": True,
"private-members": False, # (those starting with _).
"special-members": False, # (those starting and ending with __).
"undoc-members": True,
"member-order": "bysource",
"show-inheritance": True,
"undoc-members": True,
"private-members": False, # Ensure this is True if you want to document private members
"special-members": False, # (those starting and ending with __).
# "special-members": "__init__", # Document special members like __init__
"autosectionlabel_prefix_document": True,
}
4 changes: 2 additions & 2 deletions docs/source/developer_notes/component.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
abstract class Component
Component
============
Component is the base classes for all components. It is similar to PyTorch's `nn.Module` class.
:ref:`Component<core-component>` is the base class for all LightRAG components. It is similar to PyTorch's `nn.Module` class.
We name it differently to avoid confusion and also for better compatibility with `PyTorch`.
You write the code similar to how you write a PyTorch model.

Expand Down
Loading

0 comments on commit 104d57e

Please sign in to comment.