Skip to content

Commit

Permalink
remove batch_vectorizer.py under common
Browse files Browse the repository at this point in the history
  • Loading branch information
xionghuaidong committed Oct 24, 2024
1 parent b2dfe9e commit b46e78a
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 314 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
*.jc
*.pyc
/dist
.vscode/
.vscode/
__pycache__/
112 changes: 107 additions & 5 deletions kag/builder/component/vectorizer/batch_vectorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,115 @@

from kag.builder.model.sub_graph import SubGraph
from knext.common.base.runnable import Input, Output
from kag.common.vectorizer import Vectorizer, Neo4jBatchVectorizer
from kag.common.vectorizer import Vectorizer
from kag.interface.builder.vectorizer_abc import VectorizerABC
from knext.schema.client import SchemaClient
from knext.project.client import ProjectClient
from knext.schema.model.base import IndexTypeEnum


class EmbeddingVectorPlaceholder(object):
def __init__(self, number, properties, vector_field, property_key, property_value):
self._number = number
self._properties = properties
self._vector_field = vector_field
self._property_key = property_key
self._property_value = property_value
self._embedding_vector = None

def replace(self):
if self._embedding_vector is not None:
self._properties[self._vector_field] = self._embedding_vector

def __repr__(self):
return repr(self._number)


class EmbeddingVectorManager(object):
def __init__(self):
self._placeholders = []

def _create_vector_field_name(self, property_key):
from kag.common.utils import to_snake_case

name = f"{property_key}_vector"
name = to_snake_case(name)
return "_" + name

def get_placeholder(self, properties, vector_field):
for property_key, property_value in properties.items():
field_name = self._create_vector_field_name(property_key)
if field_name != vector_field:
continue
if not property_value:
return None
if not isinstance(property_value, str):
message = f"property {property_key!r} must be string to generate embedding vector"
raise RuntimeError(message)
num = len(self._placeholders)
placeholder = EmbeddingVectorPlaceholder(
num, properties, vector_field, property_key, property_value
)
self._placeholders.append(placeholder)
return placeholder
return None

def _get_text_batch(self):
text_batch = dict()
for placeholder in self._placeholders:
property_value = placeholder._property_value
if property_value not in text_batch:
text_batch[property_value] = list()
text_batch[property_value].append(placeholder)
return text_batch

def _generate_vectors(self, vectorizer, text_batch):
texts = list(text_batch)
vectors = vectorizer.vectorize(texts)
return vectors

def _fill_vectors(self, vectors, text_batch):
for vector, (_text, placeholders) in zip(vectors, text_batch.items()):
for placeholder in placeholders:
placeholder._embedding_vector = vector

def batch_generate(self, vectorizer):
text_batch = self._get_text_batch()
vectors = self._generate_vectors(vectorizer, text_batch)
self._fill_vectors(vectors, text_batch)

def patch(self):
for placeholder in self._placeholders:
placeholder.replace()


class EmbeddingVectorGenerator(object):
def __init__(self, vectorizer, vector_index_meta=None, extra_labels=("Entity",)):
self._vectorizer = vectorizer
self._extra_labels = extra_labels
self._vector_index_meta = vector_index_meta or {}

def batch_generate(self, node_batch):
manager = EmbeddingVectorManager()
vector_index_meta = self._vector_index_meta
for node_item in node_batch:
label, properties = node_item
labels = [label]
if self._extra_labels:
labels.extend(self._extra_labels)
for label in labels:
if label not in vector_index_meta:
continue
for vector_field in vector_index_meta[label]:
if vector_field in properties:
continue
placeholder = manager.get_placeholder(properties, vector_field)
if placeholder is not None:
properties[vector_field] = placeholder
manager.batch_generate(self._vectorizer)
manager.patch()


class BatchVectorizer(VectorizerABC):

def __init__(self, project_id: str = None, **kwargs):
Expand Down Expand Up @@ -70,7 +172,7 @@ def _create_vector_field_name(self, property_key):
name = to_snake_case(name)
return "_" + name

def _neo4j_batch_vectorize(self, vectorizer: Vectorizer, input: SubGraph) -> SubGraph:
def _generate_embedding_vectors(self, vectorizer: Vectorizer, input: SubGraph) -> SubGraph:
node_list = []
node_batch = []
for node in input.nodes:
Expand All @@ -80,8 +182,8 @@ def _neo4j_batch_vectorize(self, vectorizer: Vectorizer, input: SubGraph) -> Sub
properties.update(node.properties)
node_list.append((node, properties))
node_batch.append((node.label, properties.copy()))
batch_vectorizer = Neo4jBatchVectorizer(vectorizer, self.vec_meta)
batch_vectorizer.batch_vectorize(node_batch)
generator = EmbeddingVectorGenerator(vectorizer, self.vec_meta)
generator.batch_generate(node_batch)
for (node, properties), (_node_label, new_properties) in zip(
node_list, node_batch
):
Expand All @@ -92,5 +194,5 @@ def _neo4j_batch_vectorize(self, vectorizer: Vectorizer, input: SubGraph) -> Sub
return input

def invoke(self, input: Input, **kwargs) -> List[Output]:
modified_input = self._neo4j_batch_vectorize(self.vectorizer, input)
modified_input = self._generate_embedding_vectors(self.vectorizer, input)
return [modified_input]
2 changes: 0 additions & 2 deletions kag/common/vectorizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,13 @@
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied.

from kag.common.vectorizer.batch_vectorizer import Neo4jBatchVectorizer
from kag.common.vectorizer.local_bge_m3_vectorizer import LocalBGEM3Vectorizer
from kag.common.vectorizer.local_bge_vectorizer import LocalBGEVectorizer
from kag.common.vectorizer.openai_vectorizer import OpenAIVectorizer
from kag.common.vectorizer.vectorizer import Vectorizer


__all__ = [
"Neo4jBatchVectorizer",
"LocalBGEM3Vectorizer",
"LocalBGEVectorizer",
"OpenAIVectorizer",
Expand Down
112 changes: 0 additions & 112 deletions kag/common/vectorizer/batch_vectorizer.py

This file was deleted.

4 changes: 2 additions & 2 deletions kag/solver/logic/core_modules/common/text_sim_by_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from kag.common.vectorizer.vectorizer import Vectorizer
from kag.common.vectorizer import Vectorizer


def cosine_similarity(vector1, vector2):
Expand Down Expand Up @@ -86,4 +86,4 @@ def text_type_sim(self, mention, candidates, topk=1):
res = self.text_sim_result(mention, candidates, topk)
if len(res) == 0:
return [('Entity', 1.)]
return res
return res
Loading

0 comments on commit b46e78a

Please sign in to comment.