Skip to content

Commit

Permalink
Revert "32 precision for cinemaot"
Browse files Browse the repository at this point in the history
This reverts commit 0c7e180.
  • Loading branch information
Zethson committed Feb 22, 2025
1 parent 0c7e180 commit 6a97036
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions pertpy/tools/_cinemaot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def causaleffect(
dim = self.get_dim(adata, use_rep=use_rep)

transformer = FastICA(n_components=dim, random_state=0, whiten="arbitrary-variance")
X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float32)
X_transformed = np.array(transformer.fit_transform(adata.obsm[use_rep][:, :dim]), dtype=np.float64)
groupvec = (adata.obs[pert_key] == control * 1).values # control
xi = np.zeros(dim)
j = 0
Expand All @@ -100,9 +100,9 @@ def causaleffect(
xi[j] = xi_obj.correlation
j = j + 1

cf = np.array(X_transformed[:, xi < thres], np.float32)
cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float32)
cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float32)
cf = np.array(X_transformed[:, xi < thres], np.float64)
cf1 = np.array(cf[adata.obs[pert_key] == control, :], np.float64)
cf2 = np.array(cf[adata.obs[pert_key] != control, :], np.float64)
if sum(xi < thres) == 1:
sklearn.metrics.pairwise_distances(cf1.reshape(-1, 1), cf2.reshape(-1, 1))
elif sum(xi < thres) == 0:
Expand Down Expand Up @@ -170,7 +170,7 @@ def causaleffect(
else:
_solver = sinkhorn.Sinkhorn(threshold=eps)
ot_sink = _solver(ot_prob)
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float32)
ot_matrix = np.array(ot_sink.matrix.T, dtype=np.float64)
embedding = X_transformed[adata.obs[pert_key] != control, :] - np.matmul(
ot_matrix / np.sum(ot_matrix, axis=1)[:, None], X_transformed[adata.obs[pert_key] == control, :]
)
Expand Down

0 comments on commit 6a97036

Please sign in to comment.