Skip to content

Commit

Permalink
Condition sparse-dot-topn version on Python version
Browse files Browse the repository at this point in the history
  • Loading branch information
RUrlus committed Apr 26, 2024
1 parent 191dcfa commit 1ff33b4
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
6 changes: 1 addition & 5 deletions polyfuzz/models/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import importlib.util

import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix
from typing import List
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity as scikit_cosine_similarity

_HAVE_SPARSE_DOT = importlib.util.find_spec("sparse_dot_topn") is not None
if _HAVE_SPARSE_DOT:
from sparse_dot_topn import sp_matmul_topn

from polyfuzz.models._utils_sdtn import _HAVE_SPARSE_DOT, sp_matmul_topn

def cosine_similarity(from_vector: np.ndarray,
to_vector: np.ndarray,
Expand Down
33 changes: 33 additions & 0 deletions polyfuzz/models/_utils_sdtn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import sys
import importlib.util
from scipy.sparse import csr_matrix

from typing import Optional

# True when the optional `sparse_dot_topn` package can be located on the
# current interpreter's import path.
_HAVE_SPARSE_DOT = importlib.util.find_spec("sparse_dot_topn") is not None

if not _HAVE_SPARSE_DOT:

    def sp_matmul_topn(
        A: csr_matrix,
        B: csr_matrix,
        top_n: int,
        threshold: float,
        sort: bool = True,
        n_threads: Optional[int] = None,
    ):
        """Placeholder bound when `sparse_dot_topn` is not installed.

        Without this binding the name `sp_matmul_topn` would not exist in
        this module, so importing it by name (or via `import *`, since it is
        listed in `__all__`) would fail even for users who never take the
        sparse code path. Callers are expected to check `_HAVE_SPARSE_DOT`
        before calling.

        Raises:
            ImportError: always, because the optional dependency is missing.
        """
        raise ImportError(
            "sparse_dot_topn is not installed; install it to use sp_matmul_topn."
        )

elif sys.version_info >= (3, 8):
    # Python 3.8+: use the function shipped by sparse_dot_topn directly.
    from sparse_dot_topn import sp_matmul_topn
else:
    # Older interpreters: only `awesome_cossim_topn` is imported, so expose it
    # behind the same `sp_matmul_topn` signature used on the other branch.
    from sparse_dot_topn import awesome_cossim_topn

    def sp_matmul_topn(
        A: csr_matrix,
        B: csr_matrix,
        top_n: int,
        threshold: float,
        sort: bool = True,
        n_threads: Optional[int] = None,
    ):
        """Adapter that forwards to the legacy `awesome_cossim_topn`.

        Args:
            A: Left-hand sparse matrix, passed through unchanged.
            B: Right-hand sparse matrix; its transpose (`B.T`) is what gets
                passed to the legacy function.
            top_n: Requested number of results per row. Values below 2 are
                raised to 2 before being passed as `ntop`.
            threshold: Passed as `lower_bound`.
            sort: Accepted for signature parity only; it is not forwarded to
                the legacy function.
            n_threads: Number of threads. `None` or 0 is treated as 1;
                threading is enabled only when the value is greater than 1.

        Returns:
            Whatever `awesome_cossim_topn` returns for the given inputs.
        """
        n_threads = n_threads or 1
        use_threads = n_threads > 1
        # NOTE(review): B is transposed here; confirm callers pass B
        # un-transposed so both branches compute the same product.
        return awesome_cossim_topn(
            A,
            B.T,
            ntop=max(top_n, 2),
            lower_bound=threshold,
            use_threads=use_threads,
            n_jobs=n_threads,
        )


__all__ = ["sp_matmul_topn"]
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
"sentence-transformers>=0.4.1"
]

fast_cosine = ["sparse_dot_topn>=1.1.1"]
# Optional "fast cosine" dependency. The environment markers select the
# sparse_dot_topn release by interpreter version: releases below 1.0 for
# Python older than 3.8, and 1.1.1 or newer for Python 3.8 and later.
fast_cosine = [
"sparse_dot_topn<1.0; python_version < '3.8'",
"sparse_dot_topn>=1.1.1; python_version >= '3.8'",
]

embeddings_packages = [
"torch>=1.4.0",
Expand Down

0 comments on commit 1ff33b4

Please sign in to comment.