Migrate Checkpointing to Kura (#8)
Migrate checkpointing into the Kura class itself for more flexible checkpointing.
ivanleomk authored Jan 19, 2025
1 parent 84c7701 commit 60a797a
Showing 6 changed files with 108 additions and 116 deletions.
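This commit moves conversations out of the Kura constructor and into cluster_conversations(), and makes the Kura class itself responsible for every checkpoint read and write. A minimal sketch of the new call pattern under stated assumptions: the import paths (from kura import Kura, from kura.types import Conversation) and the conversations.json file name are illustrative guesses, not confirmed by this diff; the constructor arguments and the from_claude_conversation_dump loader are taken from the code below.

    import asyncio

    from kura import Kura  # assumed import path
    from kura.types import Conversation  # assumed import path

    # Checkpoints now live in one directory owned by Kura; pass
    # disable_checkpoints=True to skip reading and writing them entirely.
    kura = Kura(checkpoint_dir="./checkpoints")

    # Hypothetical input file; the loader is the classmethod touched below.
    conversations = Conversation.from_claude_conversation_dump("conversations.json")

    # Conversations are now an argument to the pipeline, not the constructor,
    # and the projected clusters are returned to the caller.
    clusters = asyncio.run(kura.cluster_conversations(conversations))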
3 changes: 1 addition & 2 deletions kura/cli/server.py
@@ -71,9 +71,8 @@ async def analyse_conversations(conversation_data: ConversationData):
         checkpoint_dir=str(
             Path(os.path.abspath(os.environ["KURA_CHECKPOINT_DIR"]))
         ),
-        conversations=conversations,
     )
-    await kura.cluster_conversations()
+    clusters = await kura.cluster_conversations(conversations)
 
     with open(clusters_file) as f:
         clusters_data = []
29 changes: 1 addition & 28 deletions kura/cluster.py
@@ -19,35 +19,11 @@ def __init__(
         client=instructor.from_gemini(
             genai.GenerativeModel("gemini-1.5-flash-latest"), use_async=True
         ),
-        checkpoint_dir: str = "checkpoints",
-        checkpoint_file: str = "cluster_checkpoint.json",
     ):
         self.clustering_method = clustering_method
         self.embedding_model = embedding_model
         self.max_concurrent_requests = max_concurrent_requests
         self.client = client
-        self.checkpoint_dir = checkpoint_dir
-        self.checkpoint_file = checkpoint_file
-
-        if not os.path.exists(self.checkpoint_dir):
-            print(f"Creating checkpoint directory {self.checkpoint_dir}")
-            os.makedirs(self.checkpoint_dir)
-
-    def save_checkpoint(self, clusters: list[Cluster]):
-        with open(os.path.join(self.checkpoint_dir, self.checkpoint_file), "w") as f:
-            for cluster in clusters:
-                f.write(cluster.model_dump_json() + "\n")
-
-        print(
-            f"Saved checkpoint to {os.path.join(self.checkpoint_dir, self.checkpoint_file)}"
-        )
-
-    def load_checkpoint(self) -> list[Cluster]:
-        print(
-            f"Loading Cluster Checkpoint from {os.path.join(self.checkpoint_dir, self.checkpoint_file)}"
-        )
-        with open(os.path.join(self.checkpoint_dir, self.checkpoint_file), "r") as f:
-            return [Cluster.model_validate_json(line) for line in f]
 
     def get_contrastive_examples(
         self,
@@ -129,9 +105,6 @@ async def generate_cluster(
     async def cluster_summaries(
         self, summaries: list[ConversationSummary]
     ) -> list[Cluster]:
-        if os.path.exists(os.path.join(self.checkpoint_dir, self.checkpoint_file)):
-            return self.load_checkpoint()
-
         sem = Semaphore(self.max_concurrent_requests)
         embeddings: list[list[float]] = await tqdm_asyncio.gather(
             *[
@@ -162,5 +135,5 @@ async def cluster_summaries(
             ],
             desc="Generating Base Clusters",
         )
-        self.save_checkpoint(clusters)
+
         return clusters
19 changes: 0 additions & 19 deletions kura/dimensionality.py
@@ -13,25 +13,12 @@ class HDBUMAP(BaseDimensionalityReduction):
     def __init__(
         self,
         embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
-        checkpoint_dir: str = "checkpoints",
-        checkpoint_name: str = "dimensionality_checkpoints.json",
     ):
         self.embedding_model = embedding_model
-        self.checkpoint_dir = checkpoint_dir
-        self.checkpoint_name = checkpoint_name
 
     async def reduce_dimensionality(
         self, clusters: list[Cluster]
     ) -> list[ProjectedCluster]:
-        if os.path.exists(os.path.join(self.checkpoint_dir, self.checkpoint_name)):
-            with open(
-                os.path.join(self.checkpoint_dir, self.checkpoint_name), "r"
-            ) as f:
-                print(
-                    f"Loading UMAP Checkpoint from {self.checkpoint_dir}/{self.checkpoint_name}"
-                )
-                return [ProjectedCluster.model_validate_json(line) for line in f]
-
         # Embed all clusters
         sem = asyncio.Semaphore(50)
         cluster_embeddings = await asyncio.gather(
@@ -70,10 +57,4 @@ async def reduce_dimensionality(
             )
             res.append(projected)
 
-        with open(os.path.join(self.checkpoint_dir, self.checkpoint_name), "w") as f:
-            for c in res:
-                f.write(c.model_dump_json() + "\n")
-
-        print(f"Saved UMAP Checkpoint to {self.checkpoint_dir}/{self.checkpoint_name}")
-
         return res
140 changes: 103 additions & 37 deletions kura/kura.py
@@ -11,53 +11,85 @@
     BaseMetaClusterModel,
     BaseDimensionalityReduction,
 )
-import json
+from typing import Union
 import os
+from typing import TypeVar
+from pydantic import BaseModel
 
+from kura.types.dimensionality import ProjectedCluster
+from kura.types.summarisation import ConversationSummary
+
+T = TypeVar("T", bound=BaseModel)
+
 
 class Kura:
     def __init__(
         self,
-        conversations: list[Conversation] = [],
         embedding_model: BaseEmbeddingModel = OpenAIEmbeddingModel(),
         summarisation_model: BaseSummaryModel = SummaryModel(),
         cluster_model: BaseClusterModel = ClusterModel(),
         meta_cluster_model: BaseMetaClusterModel = MetaClusterModel(),
         dimensionality_reduction: BaseDimensionalityReduction = HDBUMAP(),
         max_clusters: int = 10,
         checkpoint_dir: str = "./checkpoints",
-        cluster_checkpoint_name: str = "clusters.json",
-        meta_cluster_checkpoint_name: str = "meta_clusters.json",
+        summary_checkpoint_name: str = "summaries.jsonl",
+        cluster_checkpoint_name: str = "clusters.jsonl",
+        meta_cluster_checkpoint_name: str = "meta_clusters.jsonl",
+        dimensionality_checkpoint_name: str = "dimensionality.jsonl",
+        disable_checkpoints: bool = False,
     ):
-        # TODO: Manage Checkpoints within Kura class itself so we can directly disable checkpointing easily
-        summarisation_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
-        cluster_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
-        meta_cluster_model.checkpoint_dir = checkpoint_dir  # pyright: ignore
-        dimensionality_reduction.checkpoint_dir = checkpoint_dir  # pyright: ignore
 
+        # Define Models that we're using
         self.embedding_model = embedding_model
-        self.embedding_model = embedding_model
         self.summarisation_model = summarisation_model
-        self.conversations = conversations
         self.max_clusters = max_clusters
         self.cluster_model = cluster_model
         self.meta_cluster_model = meta_cluster_model
         self.dimensionality_reduction = dimensionality_reduction
-        self.checkpoint_dir = checkpoint_dir
-        self.cluster_checkpoint_name = cluster_checkpoint_name
-        self.meta_cluster_checkpoint_name = meta_cluster_checkpoint_name
 
+        # Define Checkpoints
+        self.checkpoint_dir = os.path.join(checkpoint_dir)
+        self.cluster_checkpoint_name = os.path.join(
+            self.checkpoint_dir, cluster_checkpoint_name
+        )
+        self.meta_cluster_checkpoint_name = os.path.join(
+            self.checkpoint_dir, meta_cluster_checkpoint_name
+        )
+        self.dimensionality_checkpoint_name = os.path.join(
+            self.checkpoint_dir, dimensionality_checkpoint_name
+        )
+        self.summary_checkpoint_name = os.path.join(
+            self.checkpoint_dir, summary_checkpoint_name
+        )
+        self.disable_checkpoints = disable_checkpoints
+
+        if not os.path.exists(self.checkpoint_dir) and not self.disable_checkpoints:
+            os.makedirs(self.checkpoint_dir)
+
+    def load_checkpoint(
+        self, checkpoint_path: str, response_model: type[T]
+    ) -> Union[list[T], None]:
+        if not self.disable_checkpoints:
+            if os.path.exists(checkpoint_path):
+                print(
+                    f"Loading checkpoint from {checkpoint_path} for {response_model.__name__}"
+                )
+                with open(checkpoint_path, "r") as f:
+                    return [response_model.model_validate_json(line) for line in f]
+        return None
+
+    def save_checkpoint(self, checkpoint_path: str, data: list[T]) -> None:
+        if not self.disable_checkpoints:
+            with open(checkpoint_path, "w") as f:
+                for item in data:
+                    f.write(item.model_dump_json() + "\n")
+
     async def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
-        if os.path.exists(
-            os.path.join(self.checkpoint_dir, self.cluster_checkpoint_name)
-        ):
-            print(
-                f"Loading Meta Cluster Checkpoint from {self.checkpoint_dir}/{self.cluster_checkpoint_name}"
-            )
-            with open(
-                os.path.join(self.checkpoint_dir, self.cluster_checkpoint_name), "r"
-            ) as f:
-                return [Cluster(**json.loads(line)) for line in f]
+        checkpoint_items = self.load_checkpoint(
+            self.meta_cluster_checkpoint_name, Cluster
+        )
+        if checkpoint_items:
+            return checkpoint_items
 
         root_clusters = clusters
 
@@ -81,23 +113,57 @@ async def reduce_clusters(self, clusters: list[Cluster]) -> list[Cluster]:
 
         print(f"Reduced to {len(root_clusters)} clusters")
 
-        with open(
-            os.path.join(self.checkpoint_dir, self.cluster_checkpoint_name), "w"
-        ) as f:
-            print(f"Saving {len(clusters)} clusters to checkpoint")
-            for c in clusters:
-                f.write(c.model_dump_json() + "\n")
-
+        self.save_checkpoint(self.meta_cluster_checkpoint_name, clusters)
         return clusters
 
-    async def cluster_conversations(self):
-        summaries = await self.summarisation_model.summarise(self.conversations)
+    async def summarise_conversations(
+        self, conversations: list[Conversation]
+    ) -> list[ConversationSummary]:
+        checkpoint_items = self.load_checkpoint(
+            self.summary_checkpoint_name, ConversationSummary
+        )
+        if checkpoint_items:
+            return checkpoint_items
+
+        summaries = await self.summarisation_model.summarise(conversations)
+        self.save_checkpoint(self.summary_checkpoint_name, summaries)
+        return summaries
+
+    async def generate_base_clusters(self, summaries: list[ConversationSummary]):
+        base_cluster_checkpoint_items = self.load_checkpoint(
+            self.cluster_checkpoint_name, Cluster
+        )
+        if base_cluster_checkpoint_items:
+            return base_cluster_checkpoint_items
+
         clusters: list[Cluster] = await self.cluster_model.cluster_summaries(summaries)
-        processed_clusters = await self.reduce_clusters(clusters)
+        self.save_checkpoint(self.cluster_checkpoint_name, clusters)
+        return clusters
+
+    async def reduce_dimensionality(
+        self, clusters: list[Cluster]
+    ) -> list[ProjectedCluster]:
+        checkpoint_items = self.load_checkpoint(
+            self.dimensionality_checkpoint_name, ProjectedCluster
+        )
+        if checkpoint_items:
+            return checkpoint_items
+
         dimensionality_reduced_clusters = (
-            await self.dimensionality_reduction.reduce_dimensionality(
-                processed_clusters
-            )
+            await self.dimensionality_reduction.reduce_dimensionality(clusters)
         )
 
+        self.save_checkpoint(
+            self.dimensionality_checkpoint_name, dimensionality_reduced_clusters
+        )
         return dimensionality_reduced_clusters
+
+    async def cluster_conversations(self, conversations: list[Conversation]):
+        summaries = await self.summarise_conversations(conversations)
+        clusters: list[Cluster] = await self.generate_base_clusters(summaries)
+        processed_clusters: list[Cluster] = await self.reduce_clusters(clusters)
+        dimensionality_reduced_clusters = await self.reduce_dimensionality(
+            processed_clusters
+        )
+
+        return dimensionality_reduced_clusters
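All four stages now share one persistence format: a JSONL checkpoint per stage, one pydantic model per line, written with model_dump_json and re-validated with model_validate_json against that stage's response model. A self-contained sketch of that round-trip, with a hypothetical Item model standing in for Cluster, ConversationSummary, or ProjectedCluster:

    from pydantic import BaseModel

    # Hypothetical stand-in for the Kura types; not part of the library.
    class Item(BaseModel):
        id: int
        name: str

    items = [Item(id=1, name="a"), Item(id=2, name="b")]

    # save_checkpoint: serialise one JSON document per line
    with open("items.jsonl", "w") as f:
        for item in items:
            f.write(item.model_dump_json() + "\n")

    # load_checkpoint: validate each line against the response model
    with open("items.jsonl") as f:
        loaded = [Item.model_validate_json(line) for line in f]

    assert loaded == items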
29 changes: 0 additions & 29 deletions kura/summarisation.py
@@ -12,8 +12,6 @@ class SummaryModel(BaseSummaryModel):
     def __init__(
         self,
         max_concurrent_requests: int = 50,
-        checkpoint_dir: str = "checkpoints",
-        checkpoint_file: str = "summarisation_checkpoint.json",
     ):
         self.sem = Semaphore(max_concurrent_requests)
         self.client = instructor.from_gemini(
@@ -22,44 +20,17 @@
             ),
             use_async=True,
         )
-        self.checkpoint_dir = checkpoint_dir
-        self.checkpoint_file = checkpoint_file
-
-        if not os.path.exists(self.checkpoint_dir):
-            print(f"Creating checkpoint directory {self.checkpoint_dir}")
-            os.makedirs(self.checkpoint_dir)
-
-    def load_checkpoint(self):
-        print(
-            f"Loading Summary Checkpoint from {os.path.join(self.checkpoint_dir, self.checkpoint_file)}"
-        )
-        with open(os.path.join(self.checkpoint_dir, self.checkpoint_file), "r") as f:
-            return [ConversationSummary.model_validate_json(line) for line in f]
-
-    def save_checkpoint(self, summaries: list[ConversationSummary]):
-        with open(os.path.join(self.checkpoint_dir, self.checkpoint_file), "w") as f:
-            for summary in summaries:
-                f.write(summary.model_dump_json() + "\n")
-
-        print(
-            f"Saved {len(summaries)} summaries to {os.path.join(self.checkpoint_dir, self.checkpoint_file)}"
-        )
-
     async def summarise(
         self, conversations: list[Conversation]
     ) -> list[ConversationSummary]:
-        if os.path.exists(os.path.join(self.checkpoint_dir, self.checkpoint_file)):
-            return self.load_checkpoint()
-
         summaries = await tqdm_asyncio.gather(
             *[
                 self.summarise_conversation(conversation)
                 for conversation in conversations
             ],
             desc=f"Summarising {len(conversations)} conversations",
         )
 
-        self.save_checkpoint(summaries)
         return summaries
 
     async def apply_hooks(
4 changes: 3 additions & 1 deletion kura/types/conversation.py
@@ -25,7 +25,9 @@ def from_claude_conversation_dump(cls, file_path: str) -> list["Conversation"]:
             messages=[
                 Message(
                     created_at=message["created_at"],
-                    role=message["sender"],
+                    role="user"
+                    if message["sender"] == "human"
+                    else "assistant",
                     content="\n".join(
                         [
                             item["text"]
