diff --git a/experimentation/embeddings_warmup.py b/experimentation/embeddings_warmup.py index f4f76dd..a604875 100644 --- a/experimentation/embeddings_warmup.py +++ b/experimentation/embeddings_warmup.py @@ -5,6 +5,7 @@ DEFAULT_EMBEDDING_MODEL, AutoModelEmbeddings, ) +from nxontology_ml.utils import CACHE_DIR def warmup_cache( @@ -14,6 +15,7 @@ def warmup_cache( # Warm up the embedding cache ame = ame or AutoModelEmbeddings.from_pretrained( DEFAULT_EMBEDDING_MODEL, + cache_dir=CACHE_DIR, ) nxo = get_efo_otar_slim() X, _ = read_training_data(take=take) diff --git a/experimentation/model_runner.py b/experimentation/model_runner.py index 98aeb2b..89294f8 100644 --- a/experimentation/model_runner.py +++ b/experimentation/model_runner.py @@ -139,12 +139,7 @@ def run_experiments( SubsetsFeatures(enabled=exp.subsets_enabled), TherapeuticAreaFeatures(enabled=exp.ta_enabled), GptTagFeatures.from_config(exp.gpt_tagger_config), - TextEmbeddingsTransformer.from_config( - enabled=exp.embedding_enabled, - pca_components=exp.pca_components, - use_lda=exp.use_lda, - embedding_model=ame, - ), + TextEmbeddingsTransformer.from_config(conf=exp, embedding_model=ame), CatBoostDataFormatter(), ) mmb.steps_from_pipeline(feature_pipeline) diff --git a/experimentation/tests/embeddings_warmup_test.py b/experimentation/tests/embeddings_warmup_test.py index 7554a64..fb425e5 100644 --- a/experimentation/tests/embeddings_warmup_test.py +++ b/experimentation/tests/embeddings_warmup_test.py @@ -9,8 +9,7 @@ def test_warmup_cache() -> None: ame = AutoModelEmbeddings.from_pretrained( DEFAULT_EMBEDDING_MODEL, - cache_path=ROOT_DIR - / "nxontology_ml/text_embeddings/tests/test_resources/embeddings_cache.ldb", + cache_dir=ROOT_DIR / "nxontology_ml/text_embeddings/tests/test_resources", ) warmup_cache(ame=ame, take=10) assert dict(ame._counter) == {"AutoModelEmbeddings/CACHE_HIT": 10} diff --git a/experimentation/tests/gpt_tags_warmup_test.py b/experimentation/tests/gpt_tags_warmup_test.py index 1a9dca7..18bf4ad 100644 --- a/experimentation/tests/gpt_tags_warmup_test.py +++ b/experimentation/tests/gpt_tags_warmup_test.py @@ -1,12 +1,15 @@ from experimentation.gpt_tags_warmup import warmup_gpt_tags +from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.gpt_tagger.tests._utils import mk_test_gpt_tagger +from nxontology_ml.gpt_tagger.tests.conftest import precision_config # noqa -def test_warmup_gpt_tags() -> None: +def test_warmup_gpt_tags(precision_config: TaskConfig) -> None: # noqa: F811 tagger = mk_test_gpt_tagger( + config=precision_config, cache_content={ "/a93f3eabc24f867ae4f1d6b371ba6734e38ea0a4": b'["medium"]', - } + }, ) warmup_gpt_tags( tagger=tagger, diff --git a/experimentation/tests/model_runner_test.py b/experimentation/tests/model_runner_test.py index 00960f9..9dee9f4 100644 --- a/experimentation/tests/model_runner_test.py +++ b/experimentation/tests/model_runner_test.py @@ -9,7 +9,9 @@ def test_run_experiments(tmp_path: Path) -> None: - ame = AutoModelEmbeddings.from_pretrained(DEFAULT_EMBEDDING_MODEL) + ame = AutoModelEmbeddings.from_pretrained( + DEFAULT_EMBEDDING_MODEL, cache_dir=tmp_path + ) experiments = [ ModelConfig( eval_metric="BiasedMaeMetric", diff --git a/nxontology_ml/gpt_tagger/_cache.py b/nxontology_ml/gpt_tagger/_cache.py index f635e9f..50da3f0 100644 --- a/nxontology_ml/gpt_tagger/_cache.py +++ b/nxontology_ml/gpt_tagger/_cache.py @@ -10,7 +10,6 @@ from nxontology_ml.gpt_tagger._models import TaskConfig from nxontology_ml.gpt_tagger._utils import config_to_cache_namespace, counter_or_empty -from nxontology_ml.utils import ROOT_DIR class _Cache: @@ -70,12 +69,10 @@ def from_config( cls, config: TaskConfig, counter: Counter[str] | None = None, - cache_path: Path | None = None, ) -> "_Cache": cache_namespace = config_to_cache_namespace(config) - if not cache_path: - cache_path = ROOT_DIR / f".cache/{cache_namespace}.ldb" - cache_path.parent.mkdir(parents=True, exist_ok=True) + config.cache_dir.mkdir(parents=True, exist_ok=True) + cache_path = config.cache_dir / f"{cache_namespace}.ldb" return cls( storage=LazyLSM(cache_path.as_posix()), namespace="", # Namespace is already part of the storage path diff --git a/nxontology_ml/gpt_tagger/_models.py b/nxontology_ml/gpt_tagger/_models.py index d99002b..aa2d9bf 100644 --- a/nxontology_ml/gpt_tagger/_models.py +++ b/nxontology_ml/gpt_tagger/_models.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pathlib import Path -from nxontology_ml.utils import ROOT_DIR +from nxontology_ml.utils import CACHE_DIR, ROOT_DIR LOG_DIR = ROOT_DIR / "logs/openai-api" @@ -70,6 +70,8 @@ class TaskConfig: # Optionally persist logs to disk logs_path: Path | None = LOG_DIR + cache_dir: Path = CACHE_DIR + @dataclass class LabelledNode: diff --git a/nxontology_ml/gpt_tagger/tests/_cache_test.py b/nxontology_ml/gpt_tagger/tests/_cache_test.py index ab516d6..8a4533e 100644 --- a/nxontology_ml/gpt_tagger/tests/_cache_test.py +++ b/nxontology_ml/gpt_tagger/tests/_cache_test.py @@ -1,33 +1,38 @@ from pathlib import Path from tempfile import TemporaryDirectory +from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.gpt_tagger._cache import LazyLSM, _Cache -from nxontology_ml.gpt_tagger.tests._utils import precision_config from nxontology_ml.utils import ROOT_DIR def test_from_config() -> None: - expected_cache_path = ROOT_DIR / ".cache/precision_v1_n1.ldb" - cache = _Cache.from_config(precision_config) + config = TaskConfig( + name="precision", + prompt_path=ROOT_DIR / "prompts/precision_v1.txt", + openai_model_name="gpt-4", + node_attributes=["efo_id", "efo_label", "efo_definition"], + model_n=3, + ) + expected_cache_path = Path("/tmp/nxontology-ml/cache/precision_v1_n3.ldb") + cache = _Cache.from_config(config) assert isinstance(cache._storage, LazyLSM) assert Path(cache._storage._filename) == expected_cache_path assert cache._key_hash_fn == "sha1" assert cache._namespace == "" -def test_main() -> None: - with TemporaryDirectory() as tmpdir: - cache_path = Path(tmpdir) / "precision_v1_n1.ldb" - cache = _Cache.from_config(precision_config, cache_path=cache_path) +def test_main(precision_config: TaskConfig) -> None: + cache = _Cache.from_config(precision_config) - assert cache.get("KEY", "DEFAULT") == "DEFAULT" - cache["KEY"] = "value" - assert cache.get("KEY", "DEFAULT") == "value" + assert cache.get("KEY", "DEFAULT") == "DEFAULT" + cache["KEY"] = "value" + assert cache.get("KEY", "DEFAULT") == "value" - cache2 = _Cache.from_config(precision_config, cache_path=cache_path) - cache2["KEY"] = "value" - del cache2["KEY"] - assert cache2.get("KEY", "DEFAULT") == "DEFAULT" + cache2 = _Cache.from_config(precision_config) + cache2["KEY"] = "value" + del cache2["KEY"] + assert cache2.get("KEY", "DEFAULT") == "DEFAULT" def test_LazyLSM() -> None: diff --git a/nxontology_ml/gpt_tagger/tests/_chat_completion_middleware_test.py b/nxontology_ml/gpt_tagger/tests/_chat_completion_middleware_test.py index ad0abfd..e898d24 100644 --- a/nxontology_ml/gpt_tagger/tests/_chat_completion_middleware_test.py +++ b/nxontology_ml/gpt_tagger/tests/_chat_completion_middleware_test.py @@ -8,6 +8,7 @@ import pytest from _pytest._py.path import LocalPath +from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.gpt_tagger._chat_completion_middleware import ( _ChatCompletionMiddleware, ) @@ -19,7 +20,6 @@ from nxontology_ml.gpt_tagger._utils import node_to_str_fn from nxontology_ml.gpt_tagger.tests._utils import ( mk_stub_ccm, - precision_config, sanitize_json_format, ) from nxontology_ml.tests.utils import get_test_nodes, read_test_resource @@ -67,7 +67,7 @@ def test_ctor_verify() -> None: _mk_test_ccm(prompt_template="foo") -def test_create(tmpdir: LocalPath) -> None: +def test_create(tmpdir: LocalPath, precision_config: TaskConfig) -> None: logdir = Path(tmpdir) / "logs" config = copy(precision_config) config.model_temperature = 1 @@ -103,7 +103,7 @@ def test_create(tmpdir: LocalPath) -> None: assert sanitize_json_format(resp_file.read_text()) == json_resp -def test_from_config() -> None: +def test_from_config(precision_config: TaskConfig) -> None: ccm = _ChatCompletionMiddleware.from_config(precision_config) assert ccm._partial_payload["model"] == "gpt-3.5-turbo" assert ccm._partial_payload["messages"][0]["content"] == "__PLACEHOLDER__" diff --git a/nxontology_ml/gpt_tagger/tests/_features_test.py b/nxontology_ml/gpt_tagger/tests/_features_test.py index 2954489..f15bba0 100644 --- a/nxontology_ml/gpt_tagger/tests/_features_test.py +++ b/nxontology_ml/gpt_tagger/tests/_features_test.py @@ -8,10 +8,10 @@ from nxontology_ml.data import get_efo_otar_slim from nxontology_ml.features import PrepareNodeFeatures -from nxontology_ml.gpt_tagger import GptTagger +from nxontology_ml.gpt_tagger import GptTagger, TaskConfig from nxontology_ml.gpt_tagger._features import DEFAULT_CONF, GptTagFeatures from nxontology_ml.gpt_tagger._openai_models import Response -from nxontology_ml.gpt_tagger.tests._utils import mk_test_gpt_tagger, precision_config +from nxontology_ml.gpt_tagger.tests._utils import mk_test_gpt_tagger from nxontology_ml.sklearn_transformer import NodeFeatures from nxontology_ml.tests.utils import read_test_resource from nxontology_ml.utils import ROOT_DIR @@ -27,13 +27,19 @@ def sampled_nxo() -> NXOntology[str]: @pytest.fixture -def tagger() -> GptTagger: +def tagger(precision_config: TaskConfig) -> GptTagger: expected_req = read_test_resource("precision_payload.json") stub_resp = Response(**json.loads(read_test_resource("precision_resp.json"))) # type: ignore[misc] - return mk_test_gpt_tagger(stub_content={expected_req: stub_resp}, cache_content={}) + return mk_test_gpt_tagger( + config=precision_config, + stub_content={expected_req: stub_resp}, + cache_content={}, + ) -def test_transform(tagger: GptTagger, sampled_nxo: NXOntology[str]) -> None: +def test_transform( + tagger: GptTagger, sampled_nxo: NXOntology[str], precision_config: TaskConfig +) -> None: p = make_pipeline( PrepareNodeFeatures(sampled_nxo), GptTagFeatures( @@ -50,7 +56,9 @@ def test_transform(tagger: GptTagger, sampled_nxo: NXOntology[str]) -> None: assert_frame_equal(df, expected_df) -def test_disabled(tagger: GptTagger, sampled_nxo: NXOntology[str]) -> None: +def test_disabled( + tagger: GptTagger, sampled_nxo: NXOntology[str], precision_config: TaskConfig +) -> None: p = make_pipeline( PrepareNodeFeatures(sampled_nxo), GptTagFeatures( diff --git a/nxontology_ml/gpt_tagger/tests/_gpt_tagger_test.py b/nxontology_ml/gpt_tagger/tests/_gpt_tagger_test.py index 2664c39..6a871ee 100644 --- a/nxontology_ml/gpt_tagger/tests/_gpt_tagger_test.py +++ b/nxontology_ml/gpt_tagger/tests/_gpt_tagger_test.py @@ -9,15 +9,15 @@ from nxontology_ml.data import get_efo_otar_slim from nxontology_ml.gpt_tagger._gpt_tagger import GptTagger -from nxontology_ml.gpt_tagger._models import LabelledNode +from nxontology_ml.gpt_tagger._models import LabelledNode, TaskConfig from nxontology_ml.gpt_tagger._openai_models import Response -from nxontology_ml.gpt_tagger.tests._utils import mk_test_gpt_tagger, precision_config +from nxontology_ml.gpt_tagger.tests._utils import mk_test_gpt_tagger from nxontology_ml.tests.utils import get_test_nodes, read_test_resource -def test_fetch_labels() -> None: +def test_fetch_labels(precision_config: TaskConfig) -> None: cache_content: dict[str, bytes] = {} - tagger = mk_test_gpt_tagger(cache_content) + tagger = mk_test_gpt_tagger(precision_config, cache_content) labels = tagger.fetch_labels(get_test_nodes()) assert list(labels) == [ LabelledNode(node_efo_id="DOID:0050890", labels=["medium"]), @@ -41,13 +41,13 @@ def test_fetch_labels() -> None: } -def test_fetch_labels_cached() -> None: +def test_fetch_labels_cached(precision_config: TaskConfig) -> None: # Pre-loaded cache cache_content = { "/7665404d4f2728a09ed26b8ebf2b3be612bd7da2": b'["medium"]', "/962b25d69f79f600f23a17e2c3fe79948013b4de": b'["medium"]', } - tagger = mk_test_gpt_tagger(cache_content) + tagger = mk_test_gpt_tagger(precision_config, cache_content) labels = tagger.fetch_labels(get_test_nodes()) assert list(labels) == [ LabelledNode(node_efo_id="DOID:0050890", labels=["medium"]), @@ -56,13 +56,13 @@ def test_fetch_labels_cached() -> None: assert tagger.get_metrics() == Counter({"Cache/get": 2, "Cache/hits": 2}) -def test_fetch_many_records() -> None: +def test_fetch_many_records(precision_config: TaskConfig) -> None: # Disable caching class PassthroughDict(dict[str, bytes]): def __setitem__(self, key: str, value: bytes) -> None: return - tagger = mk_test_gpt_tagger(cache_content=PassthroughDict()) + tagger = mk_test_gpt_tagger(precision_config, cache_content=PassthroughDict()) def _r(n: int) -> Response: r = json.loads(read_test_resource("precision_resp.json")) @@ -92,8 +92,8 @@ def _r(n: int) -> Response: ) -def test_get_metrics() -> None: - tagger = mk_test_gpt_tagger(cache_content={}) +def test_get_metrics(precision_config: TaskConfig) -> None: + tagger = mk_test_gpt_tagger(precision_config, cache_content={}) tagger._counter["test"] += 42 # Defensive copy: No effect @@ -107,7 +107,7 @@ def test_get_metrics() -> None: assert tagger.get_metrics() == Counter({"test": 43}) -def test_from_config() -> None: +def test_from_config(precision_config: TaskConfig) -> None: counter: Counter[str] = Counter() tagger = GptTagger.from_config(precision_config, counter=counter) @@ -118,13 +118,15 @@ def test_from_config() -> None: assert id(tagger._cache._counter) == counter_id -def test_resp_truncated() -> None: +def test_resp_truncated(precision_config: TaskConfig) -> None: stub_resp = Response(**json.loads(read_test_resource("precision_resp.json"))) # type: ignore[misc] assert stub_resp["choices"][0]["finish_reason"] == "stop" stub_resp["choices"][0]["finish_reason"] = "length" # Simulate resp truncation expected_req = read_test_resource("precision_payload.json") tagger = mk_test_gpt_tagger( - stub_content={expected_req: stub_resp}, cache_content={} + config=precision_config, + stub_content={expected_req: stub_resp}, + cache_content={}, ) with pytest.raises( ValueError, @@ -142,11 +144,13 @@ def _assert_user_warning_starts_with(warn: WarningMessage, s: str) -> None: assert warn_msg.startswith(s) -def test_resp_id_mismatch() -> None: +def test_resp_id_mismatch(precision_config: TaskConfig) -> None: expected_req = read_test_resource("mismatch_payload.json") stub_resp = Response(**json.loads(read_test_resource("mismatch_resp.json"))) # type: ignore[misc] tagger = mk_test_gpt_tagger( - stub_content={expected_req: stub_resp}, cache_content={} + config=precision_config, + stub_content={expected_req: stub_resp}, + cache_content={}, ) nxo = get_efo_otar_slim() valid_resp_node = "DOID:0050890" diff --git a/nxontology_ml/gpt_tagger/tests/_integration_test.py b/nxontology_ml/gpt_tagger/tests/_integration_test.py index fc1d80b..5cc0679 100644 --- a/nxontology_ml/gpt_tagger/tests/_integration_test.py +++ b/nxontology_ml/gpt_tagger/tests/_integration_test.py @@ -12,13 +12,12 @@ from nxontology_ml.gpt_tagger._models import TaskConfig from nxontology_ml.gpt_tagger._openai_models import Response from nxontology_ml.gpt_tagger._utils import node_to_str_fn -from nxontology_ml.gpt_tagger.tests._utils import precision_config from nxontology_ml.tests.utils import get_test_nodes, read_test_resource from nxontology_ml.utils import ROOT_DIR @pytest.mark.skip(reason="IT: Makes a real openai api call") -def test_chat_completion_precision_it() -> None: +def test_chat_completion_precision_it(precision_config: TaskConfig) -> None: # NOTE: Flaky API response, even with temp=0 :( # NOTE: Needs an OPENAI_API_KEY setup, see main README.md ccm = _ChatCompletionMiddleware.from_config(precision_config) diff --git a/nxontology_ml/gpt_tagger/tests/_tiktoken_batcher_test.py b/nxontology_ml/gpt_tagger/tests/_tiktoken_batcher_test.py index 2c3c892..dc0a339 100644 --- a/nxontology_ml/gpt_tagger/tests/_tiktoken_batcher_test.py +++ b/nxontology_ml/gpt_tagger/tests/_tiktoken_batcher_test.py @@ -6,9 +6,9 @@ import tiktoken from tiktoken import Encoding +from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.gpt_tagger._openai_models import _4K from nxontology_ml.gpt_tagger._tiktoken_batcher import _TiktokenBatcher -from nxontology_ml.gpt_tagger.tests._utils import precision_config from nxontology_ml.tests.utils import get_test_resource_path @@ -54,7 +54,9 @@ def test_add_tokens(tiktoken_cl100k_encoding: Encoding) -> None: batcher._do_add_record_to_buffer(record) -def test_from_config(tiktoken_cl100k_encoding: Encoding) -> None: +def test_from_config( + tiktoken_cl100k_encoding: Encoding, precision_config: TaskConfig +) -> None: # Valid config batcher = _TiktokenBatcher.from_config(precision_config) assert batcher._tiktoken_encoding == tiktoken_cl100k_encoding diff --git a/nxontology_ml/gpt_tagger/tests/_utils.py b/nxontology_ml/gpt_tagger/tests/_utils.py index 5172b63..bba5e95 100644 --- a/nxontology_ml/gpt_tagger/tests/_utils.py +++ b/nxontology_ml/gpt_tagger/tests/_utils.py @@ -10,17 +10,7 @@ from nxontology_ml.gpt_tagger._models import TaskConfig from nxontology_ml.gpt_tagger._openai_models import Response from nxontology_ml.gpt_tagger._tiktoken_batcher import _TiktokenBatcher -from nxontology_ml.tests.utils import get_test_resource_path, read_test_resource - -precision_config = TaskConfig( - name="precision", - prompt_path=get_test_resource_path("precision_v1.txt"), - node_attributes=["efo_id", "efo_label", "efo_definition"], - openai_model_name="gpt-3.5-turbo", - model_temperature=0, - allowed_labels=frozenset({"low", "medium", "high"}), - logs_path=None, # Don't log during tests (unless integration) -) +from nxontology_ml.tests.utils import read_test_resource def sanitize_json_format(s: str | dict[str, Any]) -> str: @@ -32,12 +22,10 @@ def sanitize_json_format(s: str | dict[str, Any]) -> str: def mk_stub_ccm( - config: TaskConfig | None = None, + config: TaskConfig, stub_content: dict[str, Response] | None = None, counter: Counter[str] | None = None, ) -> _ChatCompletionMiddleware: - if not config: - config = precision_config if not stub_content: stub_payload_json = read_test_resource("precision_payload.json") stub_resp = Response(**json.loads(read_test_resource("precision_resp.json"))) # type: ignore @@ -54,9 +42,9 @@ def create_fn_stub(**kwargs: ParamSpecKwargs) -> Response: def mk_test_gpt_tagger( + config: TaskConfig, cache_content: dict[str, bytes], stub_content: dict[str, Response] | None = None, - config: TaskConfig = precision_config, ) -> GptTagger: """ Helper to build test GptTagger instances diff --git a/nxontology_ml/gpt_tagger/tests/_utils_test.py b/nxontology_ml/gpt_tagger/tests/_utils_test.py index f4daa1d..5202a8c 100644 --- a/nxontology_ml/gpt_tagger/tests/_utils_test.py +++ b/nxontology_ml/gpt_tagger/tests/_utils_test.py @@ -1,5 +1,6 @@ from textwrap import dedent +from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.gpt_tagger._utils import ( config_to_cache_namespace, efo_id_from_yaml, @@ -7,15 +8,14 @@ node_to_str_fn, parse_model_output, ) -from nxontology_ml.gpt_tagger.tests._utils import precision_config from nxontology_ml.tests.utils import get_test_nodes -def test_config_to_cache_namespace() -> None: +def test_config_to_cache_namespace(precision_config: TaskConfig) -> None: assert config_to_cache_namespace(precision_config) == "precision_v1_n1" -def test_node_to_str_fn() -> None: +def test_node_to_str_fn(precision_config: TaskConfig) -> None: test_nodes = get_test_nodes() fn = node_to_str_fn(precision_config) expected_str = """\ @@ -47,7 +47,7 @@ def test_node_efo_id() -> None: assert node_efo_id(node) == "DOID:0050890" -def test_efo_id_from_yaml() -> None: +def test_efo_id_from_yaml(precision_config: TaskConfig) -> None: node = next(iter(get_test_nodes())) fn = node_to_str_fn(precision_config) assert efo_id_from_yaml(fn(node)) == "DOID:0050890" diff --git a/nxontology_ml/gpt_tagger/tests/conftest.py b/nxontology_ml/gpt_tagger/tests/conftest.py new file mode 100644 index 0000000..6b622ff --- /dev/null +++ b/nxontology_ml/gpt_tagger/tests/conftest.py @@ -0,0 +1,20 @@ +from pathlib import Path + +import pytest + +from nxontology_ml.gpt_tagger import TaskConfig +from nxontology_ml.tests.utils import get_test_resource_path + + +@pytest.fixture +def precision_config(tmp_path: Path) -> TaskConfig: + return TaskConfig( + name="precision", + prompt_path=get_test_resource_path("precision_v1.txt"), + node_attributes=["efo_id", "efo_label", "efo_definition"], + openai_model_name="gpt-3.5-turbo", + model_temperature=0, + allowed_labels=frozenset({"low", "medium", "high"}), + logs_path=None, # Don't log during tests (unless integration) + cache_dir=tmp_path / "cache", + ) diff --git a/nxontology_ml/model/config.py b/nxontology_ml/model/config.py index f64d09e..6f21a82 100644 --- a/nxontology_ml/model/config.py +++ b/nxontology_ml/model/config.py @@ -4,7 +4,7 @@ from nxontology_ml.gpt_tagger import TaskConfig from nxontology_ml.model.utils import BiasedMaeMetric -from nxontology_ml.utils import ROOT_DIR +from nxontology_ml.utils import CACHE_DIR, ROOT_DIR EXPERIMENT_MODEL_DIR = ROOT_DIR / "data/experiments" @@ -33,6 +33,7 @@ class ModelConfig(BaseModel): # type: ignore[misc] iterations: int = 5000 eval_metric: str = "MultiClass" base_dir: Path = EXPERIMENT_MODEL_DIR + cache_dir: Path = CACHE_DIR @property def name(self) -> str: # noqa: C901 diff --git a/nxontology_ml/model/train.py b/nxontology_ml/model/train.py index c45b851..98615df 100644 --- a/nxontology_ml/model/train.py +++ b/nxontology_ml/model/train.py @@ -31,7 +31,8 @@ def train_model( (X, y) = training_set or read_training_data( filter_out_non_disease=True, nxo=nxo, take=take ) - + if conf.gpt_tagger_config: + assert conf.gpt_tagger_config.cache_dir == conf.cache_dir feature_pipeline: Pipeline = make_pipeline( PrepareNodeFeatures(nxo=nxo), NodeInfoFeatures(), @@ -39,11 +40,7 @@ def train_model( SubsetsFeatures(enabled=conf.subsets_enabled), TherapeuticAreaFeatures(enabled=conf.ta_enabled), GptTagFeatures.from_config(conf.gpt_tagger_config), - TextEmbeddingsTransformer.from_config( - enabled=conf.embedding_enabled, - pca_components=conf.pca_components, - use_lda=conf.use_lda, - ), + TextEmbeddingsTransformer.from_config(conf=conf), CatBoostDataFormatter(), ) X_transform = feature_pipeline.fit_transform(X, y) diff --git a/nxontology_ml/text_embeddings/embeddings_model.py b/nxontology_ml/text_embeddings/embeddings_model.py index b447963..035a6a2 100644 --- a/nxontology_ml/text_embeddings/embeddings_model.py +++ b/nxontology_ml/text_embeddings/embeddings_model.py @@ -18,16 +18,15 @@ from transformers.modeling_outputs import ModelOutput from nxontology_ml.gpt_tagger._cache import LazyLSM -from nxontology_ml.utils import ROOT_DIR DEFAULT_EMBEDDING_MODEL = "michiyasunaga/BioLinkBERT-base" EMBEDDING_SIZES: dict[str, int] = {DEFAULT_EMBEDDING_MODEL: 768} _model_poolers: dict[str, str] = {DEFAULT_EMBEDDING_MODEL: "pooler_output"} -def _cache_path(pretrained_model_name: str) -> Path: +def _cache_filename(pretrained_model_name: str) -> str: safe_prefix = re.sub("[^0-9a-zA-Z]+", "_", pretrained_model_name) - return ROOT_DIR / f".cache/{safe_prefix}.ldb" + return f"{safe_prefix}.ldb" class _LazyAutoModel: @@ -110,7 +109,7 @@ def embed_text(self, text: str) -> np.ndarray: def from_pretrained( cls, pretrained_model_name: str, - cache_path: Path | None = None, + cache_dir: Path, lazy_model: _LazyAutoModel | None = None, counter: Counter[str] | None = None, ) -> "AutoModelEmbeddings": @@ -118,11 +117,12 @@ def from_pretrained( Note: pretrained_model_name should be an encoder only model (e.g. BERT) """ # FIXME: should we add truncation of input?? - cache_filename = (cache_path or _cache_path(pretrained_model_name)).as_posix() - logging.info(f"Caching embeddings into: {cache_filename}") + + cache_file = cache_dir / _cache_filename(pretrained_model_name) + logging.info(f"Caching embeddings into: {cache_file}") return cls( lazy_model=lazy_model or _LazyAutoModel(pretrained_model_name), pooler_attr=_model_poolers[pretrained_model_name], - cache=LazyLSM(filename=cache_filename), + cache=LazyLSM(filename=cache_file.as_posix()), counter=counter or Counter(), ) diff --git a/nxontology_ml/text_embeddings/tests/conftest.py b/nxontology_ml/text_embeddings/tests/conftest.py index b74eb90..4ecf4a6 100644 --- a/nxontology_ml/text_embeddings/tests/conftest.py +++ b/nxontology_ml/text_embeddings/tests/conftest.py @@ -8,9 +8,9 @@ @pytest.fixture -def embeddings_test_cache() -> Path: +def embeddings_cache_dir() -> Path: # We don't want to fetch embeddings over the internet during unit tests - return get_test_resource_path("embeddings_cache.ldb") + return get_test_resource_path("") @pytest.fixture diff --git a/nxontology_ml/text_embeddings/tests/embeddings_model_test.py b/nxontology_ml/text_embeddings/tests/embeddings_model_test.py index 6a4b5f1..994f1e2 100644 --- a/nxontology_ml/text_embeddings/tests/embeddings_model_test.py +++ b/nxontology_ml/text_embeddings/tests/embeddings_model_test.py @@ -13,16 +13,15 @@ DEFAULT_EMBEDDING_MODEL, EMBEDDING_SIZES, AutoModelEmbeddings, - _cache_path, + _cache_filename, _LazyAutoModel, ) -from nxontology_ml.utils import ROOT_DIR -def test_embed_node(nxo: NXOntology[str], embeddings_test_cache: Path) -> None: +def test_embed_node(nxo: NXOntology[str], embeddings_cache_dir: Path) -> None: ame = AutoModelEmbeddings.from_pretrained( pretrained_model_name=DEFAULT_EMBEDDING_MODEL, - cache_path=embeddings_test_cache, + cache_dir=embeddings_cache_dir, ) X, _ = read_training_data(nxo=nxo, take=10) vecs = np.array([ame.embed_node(nxo.node_info(node_id)) for node_id in X]) @@ -57,7 +56,7 @@ def model(self) -> PreTrainedModel: ame = AutoModelEmbeddings.from_pretrained( pretrained_model_name=DEFAULT_EMBEDDING_MODEL, lazy_model=model_mock, - cache_path=Path(tmp_path / "cache.ldb"), + cache_dir=tmp_path, ) test_node = "DOID:0050890" vec = ame.embed_node(nxo.node_info(test_node)) @@ -95,9 +94,9 @@ def test_lazy_automodel() -> None: tokenizer_cls_mock.from_pretrained.assert_called_once_with(model_name) -def test_cache_path() -> None: - p = _cache_path(pretrained_model_name=DEFAULT_EMBEDDING_MODEL) - assert p == ROOT_DIR / ".cache/michiyasunaga_BioLinkBERT_base.ldb" +def test_cache_filename() -> None: + fn = _cache_filename(pretrained_model_name=DEFAULT_EMBEDDING_MODEL) + assert fn == "michiyasunaga_BioLinkBERT_base.ldb" @pytest.mark.skip(reason="Pulls resources off the internet") @@ -105,7 +104,7 @@ def test_builder(tmp_path: Path) -> None: ame = AutoModelEmbeddings.from_pretrained( pretrained_model_name=DEFAULT_EMBEDDING_MODEL, # Ensure that cache doesn't exit - cache_path=Path(tmp_path / "cache.ldb"), + cache_dir=tmp_path, ) vec = ame.embed_text("Sunitinib is a tyrosine kinase inhibitor") assert vec.shape == (EMBEDDING_SIZES[DEFAULT_EMBEDDING_MODEL],) diff --git a/nxontology_ml/text_embeddings/tests/test_resources/embeddings_cache.ldb b/nxontology_ml/text_embeddings/tests/test_resources/michiyasunaga_BioLinkBERT_base.ldb similarity index 100% rename from nxontology_ml/text_embeddings/tests/test_resources/embeddings_cache.ldb rename to nxontology_ml/text_embeddings/tests/test_resources/michiyasunaga_BioLinkBERT_base.ldb diff --git a/nxontology_ml/text_embeddings/tests/text_embeddings_transformer_test.py b/nxontology_ml/text_embeddings/tests/text_embeddings_transformer_test.py index 0b99d4b..1e7641f 100644 --- a/nxontology_ml/text_embeddings/tests/text_embeddings_transformer_test.py +++ b/nxontology_ml/text_embeddings/tests/text_embeddings_transformer_test.py @@ -6,6 +6,7 @@ from nxontology_ml.data import read_training_data from nxontology_ml.features import PrepareNodeFeatures +from nxontology_ml.model.config import ModelConfig from nxontology_ml.sklearn_transformer import ( NodeFeatures, ) @@ -19,19 +20,21 @@ ) -def test_end_to_end(nxo: NXOntology[str], embeddings_test_cache: Path) -> None: +def test_end_to_end( + nxo: NXOntology[str], + embeddings_cache_dir: Path, +) -> None: cached_ame = AutoModelEmbeddings.from_pretrained( pretrained_model_name=DEFAULT_EMBEDDING_MODEL, - cache_path=embeddings_test_cache, + cache_dir=embeddings_cache_dir, ) pnf = PrepareNodeFeatures(nxo) X, y = read_training_data(nxo=nxo, take=10) + conf = ModelConfig(cache_dir=embeddings_cache_dir) ## # Disabled testing - tet = TextEmbeddingsTransformer.from_config( - enabled=False, embedding_model=cached_ame - ) + tet = TextEmbeddingsTransformer.from_config(conf=conf, embedding_model=cached_ame) nf = make_pipeline(pnf, tet).fit_transform(X, y) assert isinstance(nf, NodeFeatures) assert len(nf.cat_features) == 0 @@ -39,9 +42,8 @@ def test_end_to_end(nxo: NXOntology[str], embeddings_test_cache: Path) -> None: ## # Full embedding Testing - tet = TextEmbeddingsTransformer.from_config( - use_lda=False, embedding_model=cached_ame - ) + conf.embedding_enabled = True + tet = TextEmbeddingsTransformer.from_config(conf, embedding_model=cached_ame) nf = make_pipeline(pnf, tet).fit_transform(X, y) assert isinstance(nf, NodeFeatures) assert len(nf.cat_features) == 0 @@ -53,9 +55,8 @@ def test_end_to_end(nxo: NXOntology[str], embeddings_test_cache: Path) -> None: ## # LDA Testing - tet = TextEmbeddingsTransformer.from_config( - use_lda=True, embedding_model=cached_ame - ) + conf.use_lda = True + tet = TextEmbeddingsTransformer.from_config(conf=conf, embedding_model=cached_ame) nf = make_pipeline(pnf, tet).fit_transform(X, y) assert isinstance(nf, NodeFeatures) assert len(nf.cat_features) == 0 @@ -67,9 +68,9 @@ def test_end_to_end(nxo: NXOntology[str], embeddings_test_cache: Path) -> None: ## # PCA Testing - tet = TextEmbeddingsTransformer.from_config( - use_lda=False, pca_components=8, embedding_model=cached_ame - ) + conf.use_lda = False + conf.pca_components = 8 + tet = TextEmbeddingsTransformer.from_config(conf, embedding_model=cached_ame) nf = make_pipeline(pnf, tet).fit_transform(X, y) assert isinstance(nf, NodeFeatures) assert len(nf.cat_features) == 0 diff --git a/nxontology_ml/text_embeddings/text_embeddings_transformer.py b/nxontology_ml/text_embeddings/text_embeddings_transformer.py index 065da94..612c1ef 100644 --- a/nxontology_ml/text_embeddings/text_embeddings_transformer.py +++ b/nxontology_ml/text_embeddings/text_embeddings_transformer.py @@ -8,6 +8,7 @@ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA from tqdm import tqdm +from nxontology_ml.model.config import ModelConfig from nxontology_ml.sklearn_transformer import ( NodeFeatures, ) @@ -99,18 +100,17 @@ def _nodes_to_vec(self, X: NodeFeatures) -> np.ndarray: @classmethod def from_config( cls, - enabled: bool = True, - use_lda: bool = True, - pca_components: int | None = None, + conf: ModelConfig, pretrained_model_name: str = DEFAULT_EMBEDDING_MODEL, embedding_model: AutoModelEmbeddings | None = None, ) -> "TextEmbeddingsTransformer": return cls( - enabled=enabled, - lda=LDA() if use_lda else None, - pca=PCA(n_components=pca_components) if pca_components else None, + enabled=conf.embedding_enabled, + lda=LDA() if conf.use_lda else None, + pca=PCA(n_components=conf.pca_components) if conf.pca_components else None, embedding_model=embedding_model or AutoModelEmbeddings.from_pretrained( - pretrained_model_name=pretrained_model_name + pretrained_model_name=pretrained_model_name, + cache_dir=conf.cache_dir, ), ) diff --git a/nxontology_ml/utils.py b/nxontology_ml/utils.py index f8686a7..b51e45c 100644 --- a/nxontology_ml/utils.py +++ b/nxontology_ml/utils.py @@ -5,6 +5,9 @@ ROOT_DIR: Path = Path(__file__).parent.parent +# Will require override on Windows +CACHE_DIR = Path("/tmp/nxontology-ml/cache") + def get_output_directory(nxo: NXOntology[NodeT], parent_dir: Path = ROOT_DIR) -> Path: """Get output directory for an nxontology, using the ontology name for the directory."""