Skip to content

Commit

Permalink
* chore(embedding.py): reformat import statements and fix typing impo…
Browse files Browse the repository at this point in the history
…rt order

* feat(embedding.py): add docstring to Embedding class
* feat(embedding.py): add docstring to FlagEmbedding class
  • Loading branch information
NirantK committed Aug 22, 2023
1 parent b0f8a33 commit 6e72ec0
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions fastembed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tarfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Iterable
from typing import Iterable, List

import numpy as np
import onnxruntime as ort
Expand Down Expand Up @@ -33,6 +33,29 @@ class ONNXProviders:


class Embedding(ABC):
"""
Abstract class for embeddings.
Args:
ABC ():
Raises:
NotImplementedError: _description_
PermissionError: _description_
ValueError: _description_
ValueError: _description_
ValueError: _description_
ValueError: _description_
ValueError: _description_
NotImplementedError: _description_
Returns:
_type_: _description_
Yields:
_type_: _description_
"""

@abstractmethod
def encode(self, texts: List[str]) -> List[np.ndarray]:
raise NotImplementedError
Expand Down Expand Up @@ -154,7 +177,7 @@ def retrieve_model(self, model_name: str, cache_dir: str) -> Path:
class FlagEmbedding(Embedding):
def __init__(
self,
model_name: str,
model_name: str = "BAAI/bge-small-en",
onnx_providers=None,
max_length: int = 512,
cache_dir: str = None,
Expand Down Expand Up @@ -227,7 +250,7 @@ def encode(self, documents: List[str], batch_size: int = 256) -> Iterable[np.nda
class DefaultEmbedding(FlagEmbedding):
def __init__(
self,
model_name: str = "BAAI/bge-base-en",
model_name: str = "BAAI/bge-small-en",
onnx_providers: List[str] = None,
max_length: int = 512,
cache_dir: str = None,
Expand Down

0 comments on commit 6e72ec0

Please sign in to comment.