Commit 4f6260a

movchan74 committed Feb 2, 2024
1 parent f373e33 commit 4f6260a
Showing 6 changed files with 176 additions and 57 deletions.
46 changes: 46 additions & 0 deletions README.md
@@ -135,6 +135,52 @@ that Ruff problems appear while you edit, and formatting is applied
automatically on save.


## Testing

The project uses pytest for testing. To run the tests, use the following command:

```bash
poetry run pytest
```

If you are using VS Code, you can run the tests using the Test Explorer that is installed with the [Python extension](https://code.visualstudio.com/docs/python/testing).
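
To run only the integration tests (the directory below matches the test layout shown in this commit), you can point pytest at the integration test folder:

```bash
poetry run pytest aana/tests/integration
```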

There are a few environment variables that can be set to control the behavior of the tests:
- `USE_DEPLOYMENT_CACHE`: If set to `true`, the tests will use the deployment cache to avoid downloading the models and running the deployments. This is useful for running the integration tests faster and in environments where a GPU is not available.
- `SAVE_DEPLOYMENT_CACHE`: If set to `true`, the tests will save the deployment cache after running the deployments. This is useful for updating the deployment cache if new deployments or tests are added.

### How to use the deployment cache environment variables

Here are some examples of how to use the deployment cache environment variables.

#### Do you want to run the tests normally on the GPU?

```bash
USE_DEPLOYMENT_CACHE=false
SAVE_DEPLOYMENT_CACHE=false
```

This is the default behavior. The tests will run normally on the GPU, and the deployment cache will be ignored entirely.

#### Do you want to run the tests faster without a GPU?

```bash
USE_DEPLOYMENT_CACHE=true
SAVE_DEPLOYMENT_CACHE=false
```

This will run the tests against the deployment cache, avoiding model downloads and deployment execution. The cache will not be updated after the run, so only use this mode if you are sure the deployment cache is up to date.

#### Do you want to update the deployment cache?

```bash
USE_DEPLOYMENT_CACHE=false
SAVE_DEPLOYMENT_CACHE=true
```

This will run the tests normally on the GPU and save the deployment cache after running the deployments. Use this mode if you have added new deployments or tests and need to update the cache.
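
The variables can also be set inline for a single run. As a minimal sketch combining the cache-only mode with the standard test command:

```bash
USE_DEPLOYMENT_CACHE=true SAVE_DEPLOYMENT_CACHE=false poetry run pytest
```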


## Databases
The project uses two databases: a vector database and a traditional SQL database,
referred to internally as the vectorstore and datastore, respectively.
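
As an illustrative sketch of how the datastore can be configured (the class names mirror the test fixtures in this commit, but the import path is an assumption, not confirmed by the source):

```python
from pathlib import Path

# NOTE: the import path below is an assumption for illustration purposes.
from aana.configs.db import DBConfig, DbType, SQLiteConfig

# Configure a SQLite datastore backed by a local file.
db_config = DBConfig(
    datastore_type=DbType.SQLITE,
    datastore_config=SQLiteConfig(path=Path("/tmp/aana.db")),
)
```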
63 changes: 47 additions & 16 deletions aana/tests/conftest.py
@@ -1,5 +1,15 @@
# This file is used to define fixtures that are used in the integration tests.
# The fixtures are used to set up Ray and Ray Serve, and to call the endpoints.
# The fixtures depend on each other to set up the test environment.
# Here is a dependency graph of the fixtures:
# ray_setup (session scope, starts Ray cluster)
# -> setup_deployment (module scope, starts Ray deployment, args: deployment)
# -> ray_serve_setup (module scope, starts Ray Serve app, args: endpoints, nodes, context, runtime_env)
# -> app_setup (module scope, starts Ray Serve app for a specific target, args: target)
# -> call_endpoint (module scope, calls endpoint, args: endpoint_path, data, ignore_expected_output, expected_error)
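#
# Example (an illustrative sketch, not part of this commit): an integration test
# typically requests `call_endpoint` indirectly with its target app and then calls
# an endpoint, e.g. (the endpoint constant and payload below are placeholders):
#
#   @pytest.mark.parametrize("call_endpoint", [TARGET], indirect=True)
#   @pytest.mark.parametrize("video", [{"path": "...", "media_id": "..."}])
#   def test_video_transcribe(call_endpoint, video):
#       call_endpoint(VIDEO_TRANSCRIBE_ENDPOINT, {"video": video})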


# ruff: noqa: S101
import json
import os
import tempfile
from pathlib import Path
@@ -26,29 +36,31 @@
is_gpu_available,
is_using_deployment_cache,
)
from aana.utils.json import json_serializer_default
from aana.utils.general import jsonify


@pytest.fixture(scope="session")
def ray_setup():
"""Setup Ray instance."""
"""Setup Ray cluster."""
ray.init(num_cpus=6) # pretend we have 6 cpus
yield
ray.shutdown()


@pytest.fixture(scope="module")
def setup_deployment(ray_setup, request):
"""Setup Ray Deployment."""
def setup_deployment(ray_setup, request): # noqa: D417
"""Setup Ray Deployment.
Args:
deployment: The deployment to start.
bind (bool): Whether to bind the deployment. Defaults to False.
"""

def start_deployment(deployment, bind=False):
"""Start deployment."""
port = portpicker.pick_unused_port()
name = request.node.name.replace("/", "_")
route_prefix = f"/test/{name}"
print(
f"Starting deployment {name} on port {port} with route prefix {route_prefix}"
)
if bind:
if not is_gpu_available() and is_using_deployment_cache():
# if GPU is not available and we are using deployment cache,
@@ -64,8 +76,15 @@ def start_deployment(deployment, bind=False):


@pytest.fixture(scope="module")
def ray_serve_setup(setup_deployment, request):
"""Setup the Ray Serve from specified endpoints and nodes."""
def ray_serve_setup(setup_deployment, request): # noqa: D417
"""Setup the Ray Serve app from specified endpoints and nodes.
Args:
endpoints: App endpoints.
nodes: App nodes.
context: App context.
runtime_env: The runtime environment. Defaults to None.
"""

def start_ray_serve(endpoints, nodes, context, runtime_env=None):
if runtime_env is None:
@@ -81,16 +100,21 @@ def start_ray_serve(endpoints, nodes, context, runtime_env=None):


@pytest.fixture(scope="module")
def app_setup(ray_serve_setup):
"""Setup app for a specific target."""
def app_setup(ray_serve_setup): # noqa: D417
"""Setup Ray Serve app for a specific target.
Args:
target: The target deployment.
"""
# create temporary database
tmp_database_path = Path(tempfile.mkstemp(suffix=".db")[1])
db_config = DBConfig(
datastore_type=DbType.SQLITE,
datastore_config=SQLiteConfig(path=tmp_database_path),
)
# set environment variable for the database config so Ray can find it
os.environ["DB_CONFIG"] = json.dumps(db_config, default=json_serializer_default)
os.environ["DB_CONFIG"] = jsonify(db_config)

# set database config in aana settings
aana_settings.db_config = db_config

@@ -110,7 +134,7 @@ def start_app(target):
deployments = configuration["deployments"]
runtime_env = {
"env_vars": {
"DB_CONFIG": json.dumps(db_config, default=json_serializer_default)
"DB_CONFIG": jsonify(db_config),
}
}
context = {"deployments": {}}
@@ -133,8 +157,15 @@ def start_app(target):


@pytest.fixture(scope="module")
def call_endpoint(app_setup, request):
"""Call endpoint."""
def call_endpoint(app_setup, request): # noqa: D417
"""Call endpoint.
Args:
endpoint_path: The endpoint path.
data: The data to send.
ignore_expected_output: Whether to ignore the expected output. Defaults to False.
expected_error: The expected error. Defaults to None.
"""
target = request.param
handle, port, route_prefix = app_setup(target)

1 change: 1 addition & 0 deletions aana/tests/integration/test_chat_with_video.py
@@ -108,6 +108,7 @@ def test_chat_with_video(call_endpoint, video):
)


# FIX: See https://github.com/mobiusml/aana_sdk/issues/42
# @pytest.mark.parametrize(
# "endpoint, data",
# [
25 changes: 0 additions & 25 deletions aana/tests/integration/test_whisper.py
@@ -77,28 +77,3 @@ def test_video_transcribe(call_endpoint, video):
VIDEO_TRANSCRIBE_ENDPOINT,
{"video": video},
)


# @pytest.mark.skipif(
# not is_gpu_available() and not is_using_deployment_cache(),
# reason="GPU is not available",
# )
# @pytest.mark.parametrize("call_endpoint", [TARGET], indirect=True)
# @pytest.mark.parametrize(
# "video",
# [
# {
# "path": str(resources.path("aana.tests.files.videos", "physicsworks.webm")),
# "media_id": "physicsworks.webm",
# }
# ],
# )
# def test_video_transcribe2(call_endpoint, video):
# """Test video transcribe endpoint."""
# media_id = video["media_id"]

# # transcribe video
# call_endpoint(
# VIDEO_TRANSCRIBE_ENDPOINT,
# {"video": video},
# )
56 changes: 46 additions & 10 deletions aana/tests/utils.py
Expand Up @@ -11,9 +11,9 @@
drop_all_tables,
run_alembic_migrations,
)
from aana.configs.settings import Settings
from aana.tests.const import ALLOWED_LEVENSTEIN_ERROR_RATE
from aana.utils.general import get_endpoint
from aana.utils.json import json_serializer_default
from aana.utils.general import get_endpoint, jsonify


def is_gpu_available() -> bool:
@@ -106,15 +106,16 @@ def call_endpoint(
"""Call an endpoint.
Args:
target (str): the name of the target deployment
target (str): the name of the target.
port (int): Port of the server.
route_prefix (str): Route prefix of the server.
endpoint_path (str): Endpoint to call.
data (dict): Data to send to the endpoint.
Returns:
dict | list: Output of the endpoint. If an error occurs, the output will be a dict
with the error message.
dict | list: Output of the endpoint. If the endpoint is a streaming endpoint, the output will be a list of output chunks.
If the endpoint is not a streaming endpoint, the output will be a dict.
If an error occurs, the output will be a dict with the error message.
"""
endpoint = get_endpoint(target, endpoint_path)
if endpoint.streaming:
@@ -205,8 +206,14 @@ def compare_output(expected_output: dict, output: dict):
compare_texts(expected_json, actual_json)


def clear_database(aana_settings):
"""Clear the database."""
def clear_database(aana_settings: Settings):
"""Clear the database.
It drops all tables and runs alembic migrations to create the tables again.
Args:
aana_settings (Settings): AANA settings.
"""
drop_all_tables(aana_settings)
run_alembic_migrations(aana_settings)

@@ -219,7 +226,20 @@ def check_output(
ignore_expected_output=False,
expected_error=None,
):
"""Compare output with expected output."""
"""Compare output with expected output.
Args:
target (str): the name of the target.
endpoint_path (str): Endpoint path.
key (str): Key of the expected output.
output (dict | list): Output of the endpoint.
ignore_expected_output (bool, optional): If True, do not compare the output with the expected output. Defaults to False.
expected_error (str | None, optional): Expected error. If not None, the output will be compared with the expected error
and the expected output will be ignored. Defaults to None.
Raises:
AssertionError: if the output is different from the expected output.
"""
endpoint = get_endpoint(target, endpoint_path)
# if we expect an error, then we only check the error
if expected_error:
@@ -263,8 +283,24 @@ def call_and_check_endpoint(
ignore_expected_output: bool = False,
expected_error: str | None = None,
) -> dict | list:
"""Call and check endpoint."""
data_json = json.dumps(data, default=json_serializer_default)
"""Call endpoint and compare the output with the expected output.
Args:
target (str): the name of the target.
port (int): Port of the server.
route_prefix (str): Route prefix of the server.
endpoint_path (str): Endpoint to call.
data (dict): Data to send to the endpoint.
ignore_expected_output (bool, optional): If True, do not compare the output with the expected output. Defaults to False.
expected_error (str | None, optional): Expected error. If not None, the output will be compared with the expected error
and the expected output will be ignored. Defaults to None.
Returns:
dict | list: Output of the endpoint. If the endpoint is a streaming endpoint, the output will be a list of output chunks.
If the endpoint is not a streaming endpoint, the output will be a dict.
If an error occurs, the output will be a dict with the error message.
"""
data_json = jsonify(data)
# "aana.tests.files.videos" will be resolved to a different path on different systems
# so we need to replace it with a path that is the same on all systems
# to make sure that the hash of the data is the same on all systems