-
Notifications
You must be signed in to change notification settings - Fork 0
/
00.seq_analysis.py
103 lines (97 loc) · 2.5 KB
/
00.seq_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# %% import and definition
import os
import numpy as np
import plotly.express as px
import xarray as xr
from ds_utils.utils.num import norm
from scipy.signal import medfilt
from scipy.stats import zscore
from seqnmf import seqnmf
from routine.io import load_F
# Input: suite2p output folder for a single imaging plane.
IN_DPATH = "./data/ANMP215/A215-20230118/04/suite2p/plane0/"
# Outputs: factorization results (netCDF) and HTML figures.
OUT_PATH, FIG_PATH = "./intermediate/seqnmf/", "./figs/seqnmf/"

# Preprocessing parameters.
PARAM_FILT_WND = 31  # medfilt kernel size in raw frames (scipy requires odd)
PARAM_DS = 4  # temporal downsampling factor (coarsen + mean over frames)
PARAM_F_CLIP = (0, 5)  # (lower, upper) clip bounds after z-score + downsampling

# seqNMF parameters.
PARAM_NCOMP = 50  # number of factors (seqnmf K)
PARAM_TLEN = 300  # factor length in downsampled frames (seqnmf L)

# Create the output folders up front so later writes cannot fail on a
# missing directory.
os.makedirs(OUT_PATH, exist_ok=True)
os.makedirs(FIG_PATH, exist_ok=True)
# %% seqNMF analysis
# NOTE(review): load_F is project-local; the code below assumes it returns a
# DataArray with a "frame" dimension (and a "cell" coordinate, used for W).
F = load_F(IN_DPATH)  # TODO: check why some values are negative
F = (
    xr.apply_ufunc(
        # Per trace: median-filter over frames, then z-score.
        lambda x: zscore(medfilt(x, kernel_size=PARAM_FILT_WND)),
        F,
        input_core_dims=[["frame"]],
        output_core_dims=[["frame"]],
        vectorize=True,
    )
    # Downsample in time: average non-overlapping windows of PARAM_DS frames;
    # trailing frames that do not fill a whole window are dropped.
    .coarsen({"frame": PARAM_DS}, boundary="trim")
    .mean()
    # zscore returns NaN for a constant trace (zero std). A single NaN row
    # would propagate into the seqNMF input and likely spoil every factor,
    # so treat such traces as silent instead.
    .fillna(0)
    # Lower bound 0 keeps the input non-negative, as NMF requires; the upper
    # bound caps large values.
    .clip(*PARAM_F_CLIP)
)
# Fit seqNMF on an explicit (cell, frame) numpy array so the row/column
# layout is guaranteed to match the coords attached to W and H below.
# cost, loadings and power are returned by seqnmf but not used here.
W, H, cost, loadings, power = seqnmf(
    F.transpose("cell", "frame").values,
    K=PARAM_NCOMP,
    L=PARAM_TLEN,
    Lambda=1e-4,
    lambda_L1H=0,
    lambda_L1W=0,
    max_iter=10,  # NOTE(review): very few iterations; confirm not a debug leftover
    shift=False,
)
# W: (cell, comp, t) -- each component's template, PARAM_TLEN frames long.
W = xr.DataArray(
    W,
    dims=["cell", "comp", "t"],
    coords={
        "cell": F.coords["cell"],
        "comp": np.arange(PARAM_NCOMP),
        "t": np.arange(PARAM_TLEN),
    },
    name="W",
)
# H: (comp, frame) -- each component's activation over (downsampled) frames.
H = xr.DataArray(
    H,
    dims=["comp", "frame"],
    coords={"comp": np.arange(PARAM_NCOMP), "frame": F.coords["frame"]},
    name="H",
)
# xr.merge requires every DataArray to be named, and the plotting cell reads
# the traces back as seq_ds["F"], so name F explicitly here.
seq_ds = xr.merge([F.rename("F"), W, H])
seq_ds.to_netcdf(os.path.join(OUT_PATH, "seq_ds.nc"))
# %% plotting
plt_F_nfm = int(5e3)  # number of leading (downsampled) frames of F to plot
plt_ncomp = 10  # number of components per W figure
# load_dataset reads the file into memory and closes it. open_dataset would
# keep seq_ds.nc open, so re-running the analysis cell in the same session
# could fail when to_netcdf tries to overwrite the still-open file.
seq_ds = xr.load_dataset(os.path.join(OUT_PATH, "seq_ds.nc"))
F, W, H = seq_ds["F"], seq_ds["W"], seq_ds["H"]
# Rescale each component's activation trace for display.
# NOTE(review): norm comes from ds_utils; presumably a quantile-based
# min/max scaling given the q=(lo, hi) kwarg -- confirm.
H = xr.apply_ufunc(
    norm,
    H,
    input_core_dims=[["frame"]],
    output_core_dims=[["frame"]],
    vectorize=True,
    kwargs={"q": (0, 0.97)},
)
# Rescale each component's (cell, t) template as a whole. apply_ufunc moves
# core dims to the end, so W comes out as (comp, cell, t).
W = xr.apply_ufunc(
    norm,
    W,
    input_core_dims=[["cell", "t"]],
    output_core_dims=[["cell", "t"]],
    vectorize=True,
    kwargs={"q": (0, 0.999)},
)
# Heatmap of the first plt_F_nfm frames of the preprocessed traces.
fig_F = px.imshow(F.isel(frame=slice(0, plt_F_nfm)))
fig_F.write_html(os.path.join(FIG_PATH, "F.html"))
# Heatmap of all component activations.
fig_H = px.imshow(H)
fig_H.write_html(os.path.join(FIG_PATH, "H.html"))
# W templates: one HTML file per chunk of plt_ncomp components, faceted by
# component, plt_wrap facets per row, 800 px per row.
plt_wrap = 5
ncomp = W.sizes["comp"]
for c0 in range(0, ncomp, plt_ncomp):
    # Clamp so a partial last chunk is sized and named by its real extent.
    c1 = min(c0 + plt_ncomp, ncomp)
    fig_W = px.imshow(
        W.isel(comp=slice(c0, c1)),
        facet_col="comp",
        facet_col_wrap=plt_wrap,
        facet_row_spacing=0.01,
    )
    nrow = (c1 - c0 + plt_wrap - 1) // plt_wrap  # ceil division
    fig_W.update_layout(height=nrow * 800)
    fig_W.write_html(os.path.join(FIG_PATH, "W-{}_{}.html".format(c0, c1)))