Skip to content

Commit

Permalink
Merge pull request #550 from gnzng/drop-pd
Browse files Browse the repository at this point in the history
Drop pandas dependencies + remove by test created files
  • Loading branch information
newville authored Feb 8, 2025
2 parents f58c8c2 + 57644c2 commit b4f5d6d
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 66 deletions.
116 changes: 81 additions & 35 deletions larch/math/deglitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import logging
import numpy as np
from scipy.ndimage import median_filter

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -119,14 +120,39 @@ def remove_spikes_pymca(y_spiky, kernel_size=9, threshold=0.66):
return ynew


def remove_spikes_pandas(y, window=3, threshold=3):
"""remove spikes using pandas
def remove_spikes_scipy(y, window=3, threshold=3):
    """Remove spikes from a 1D signal using a SciPy rolling median filter.

    A rolling median (:func:`scipy.ndimage.median_filter`) is compared with
    the original signal; points deviating by more than ``threshold`` sigma
    are replaced with the filtered value.

    Parameters
    ----------
    y : array 1D
        input signal
    window : int (optional)
        size of the rolling-median window, must be odd (adjusted if even) [3]
    threshold : int (optional)
        number of sigma difference with original data [3]

    Return
    ------
    ynew : array like y
        de-spiked signal (zeros if the filter raised an error)
    """
    ynew = np.zeros_like(y)
    if window % 2 == 0:
        # median_filter needs an odd size to stay centered on each point
        window += 1
        _logger.warning("'window' must be odd -> adjusted to %d", window)
    try:
        yf = median_filter(y, size=window, mode='nearest')
        diff = yf - y
        mean = diff.mean()
        # spread of the original data around the mean difference
        sigma = np.sqrt(np.sum((y - mean) ** 2) / len(y))
        ynew = np.where(abs(diff) > threshold * sigma, yf, y)
    except Exception as e:
        # fixed: message previously named remove_spikes_pandas
        _logger.error("Error in remove_spikes_scipy: %s", e)
    return ynew


def remove_spikes_numpy(y, window=3, threshold=3):
    """Remove spikes from a 1D signal using a pure-NumPy rolling median.

    Parameters
    ----------
    y : array 1D
        input signal
    window : int (optional)
        size of the rolling-median window, must be odd (adjusted if even) [3]
    threshold : int (optional)
        number of sigma difference with original data [3]

    Return
    ------
    ynew : array like y
        de-spiked signal (zeros if an error occurred)
    """
    ynew = np.zeros_like(y)
    if window % 2 == 0:
        window += 1
        _logger.warning("'window' must be odd -> adjusted to %d", window)
    try:
        # Rolling median with edge padding so the output keeps y's length
        pad_width = window // 2
        y_padded = np.pad(y, pad_width, mode='edge')
        yf = np.zeros_like(y)
        for i in range(len(y)):
            yf[i] = np.median(y_padded[i:i + window])

        # Compute the difference and statistics
        diff = yf - y
        mean = np.mean(diff)
        sigma = np.sqrt(np.sum((y - mean) ** 2) / len(y))

        # Replace values where the difference exceeds the threshold
        ynew = np.where(np.abs(diff) > threshold * sigma, yf, y)

    except Exception as e:
        # fixed: use the module logger (was a bare print) for consistency
        # with remove_spikes_scipy
        _logger.error("Error in remove_spikes_numpy: %s", e)
    return ynew
df = pd.DataFrame(y)
try:
yf = (
pd.rolling_median(df, window=window, center=True)
.fillna(method="bfill")
.fillna(method="ffill")
)
diff = yf.as_matrix() - y
mean = diff.mean()
sigma = (y - mean) ** 2
sigma = np.sqrt(sigma.sum() / float(len(sigma)))
ynew = np.where(abs(diff) > threshold * sigma, yf.as_matrix(), y)
except Exception:
yf = (
df.rolling(window, center=True)
.median()
.fillna(method="bfill")
.fillna(method="ffill")
)

diff = yf.values - y
mean = diff.mean()
sigma = (y - mean) ** 2
sigma = np.sqrt(sigma.sum() / float(len(sigma)))
ynew = np.where(abs(diff) > threshold * sigma, yf.values, y)

# ynew = np.array(yf.values).reshape(len(x))
return ynew


def remove_spikes_pandas(y, window=3, threshold=3):
    """Remove spikes from a 1D signal (deprecated pandas backend).

    .. deprecated::
        The pandas implementation was removed; this wrapper is kept only
        for backward compatibility and simply delegates to
        :func:`remove_spikes_scipy`.

    Parameters
    ----------
    y : array 1D
        input signal
    window : int (optional)
        window in rolling median [3]
    threshold : int (optional)
        number of sigma difference with original data

    Return
    ------
    ynew : array like y
    """
    _logger.warning("pandas backend is not supported, using scipy instead")
    return remove_spikes_scipy(y, window=window, threshold=threshold)
38 changes: 9 additions & 29 deletions larch/xrd/struct2xas.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,6 @@
from larch.io import read_ascii
from larch.math.convolution1D import lin_gamma, conv

try:
import pandas as pd
from pandas.io.formats.style import Styler

HAS_PANDAS = True
except ImportError:
HAS_PANDAS = False

try:
import py3Dmol

Expand Down Expand Up @@ -474,17 +466,9 @@ def get_abs_sites_info(self):
"idx_in_struct",
]
abs_sites = self.get_abs_sites()
if HAS_PANDAS:
df = pd.DataFrame(
abs_sites,
columns=header,
)
df = Styler(df).hide(axis="index")
return df
else:
matrix = [header]
matrix.extend(abs_sites)
_pprint(matrix)
matrix = [header]
matrix.extend(abs_sites)
_pprint(matrix)

def get_atoms_from_abs(self, radius):
"""Get atoms in sphere from absorbing atom with certain radius"""
Expand Down Expand Up @@ -744,13 +728,9 @@ def get_coord_envs_info(self):
)
print(coord_sym)
header = ["Element", "Distance"]
if HAS_PANDAS:
df = pd.DataFrame(data=elems_dist, columns=header)
return df
else:
matrix = [header]
matrix.extend(elems_dist)
_pprint(matrix)
matrix = [header]
matrix.extend(elems_dist)
_pprint(matrix)

def make_cluster(self, radius):
"""Create a cluster with absorber atom site at the center.
Expand Down Expand Up @@ -942,7 +922,7 @@ def make_input_fdmnes(
np.allclose(unique_sites[i][0].coords, selected_site[4], atol=0.01)
is True
):
replacements["absorber"] = f"absorber\n{i+1}"
replacements["absorber"] = f"absorber\n{i + 1}"

# absorber = f"{absorber}"
# replacements["absorber"] = f"Z_absorber\n{round(Element(elem).Z)}"
Expand All @@ -963,7 +943,7 @@ def make_input_fdmnes(
absorber = f"{absorber}"
for i in range(len(atoms)):
if np.allclose(atoms[i][1], [0, 0, 0], atol=0.01) is True:
replacements["absorber"] = f"absorber\n{i+1}"
replacements["absorber"] = f"absorber\n{i + 1}"

replacements["group"] = ""

Expand Down Expand Up @@ -1210,7 +1190,7 @@ def make_input_feff(
for i in range(len(at)):
if self.full_occupancy:
atoms += "\n" + (
f"{at[i][0][0]:10.6f} {at[i][0][1]:10.6f} {at[i][0][2]:10.6f} { int(at[i][1])} {at[i][2]:>5} {at[i][3]:10.5f} *1 "
f"{at[i][0][0]:10.6f} {at[i][0][1]:10.6f} {at[i][0][2]:10.6f} {int(at[i][1])} {at[i][2]:>5} {at[i][3]:10.5f} *1 "
)
else:
choice = np.random.choice(
Expand Down
5 changes: 4 additions & 1 deletion tests/test_athena_addgroup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy.testing
from pathlib import Path
from larch.io import read_ascii, AthenaProject


Expand All @@ -9,8 +10,10 @@ def test_add_athena_group():
b.filename = 'cu_10k_copy.xmu'
del b.mu


p = AthenaProject('x1.prj')
p.add_group(a)
p.add_group(b)
p.save()

# remove file after test
Path('x1.prj').unlink(missing_ok=True)
4 changes: 4 additions & 0 deletions tests/test_larchexamples_xafs.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def test17_feffit3extra(self):
self.isNear('_dlo', 0.000315, places=4)


def test_remove_files():
    # Clean up the output files written by the feffit doc tests above.
    outputs = ('doc_feffit1.out', 'doc_feffit2.out', 'doc_feffit3.out')
    feffit_dir = base_dir / 'examples' / 'feffit'
    for name in outputs:
        (feffit_dir / name).unlink(missing_ok=True)


if __name__ == '__main__': # pragma: no cover
Expand Down
41 changes: 41 additions & 0 deletions tests/test_math_deglitch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy as np
from larch.math.deglitch import remove_spikes_numpy, remove_spikes_pandas, remove_spikes_scipy


def test_remove_spikes_numpy():
    """remove_spikes_numpy must flatten large isolated spikes."""
    # Seeded generator -> reproducible runs (np.random.random was unseeded)
    rng = np.random.default_rng(42)
    y = rng.random(100)
    y[24] = 10
    y[56] = 11

    # Run the function
    result = remove_spikes_numpy(y, window=3, threshold=3)

    # Background values are < 1, so anything >= 5 is a missed spike
    assert np.all(result < 5), "Test failed: spikes not detected"


def test_remove_spikes_scipy():
    """remove_spikes_scipy must flatten large isolated spikes."""
    # Seeded generator -> reproducible runs (np.random.random was unseeded)
    rng = np.random.default_rng(42)
    y = rng.random(100)
    y[24] = 10
    y[56] = 11

    # Run the function
    result = remove_spikes_scipy(y, window=3, threshold=3)

    # Background values are < 1, so anything >= 5 is a missed spike
    assert np.all(result < 5), "Test failed: spikes not detected"


def test_remove_spikes_pandas():
    """remove_spikes_pandas (deprecated wrapper) must still remove spikes."""
    # Seeded generator -> reproducible runs (np.random.random was unseeded)
    rng = np.random.default_rng(42)
    y = rng.random(100)
    y[24] = 10
    y[56] = 11

    # Run the function
    result = remove_spikes_pandas(y, window=3, threshold=3)

    # Background values are < 1, so anything >= 5 is a missed spike
    assert np.all(result < 5), "Test failed: spikes not detected"
5 changes: 4 additions & 1 deletion tests/test_symbol_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from larch import Interpreter
linp = Interpreter()


def onVarChange(group=None, symbolname=None, value=None, **kws):
    """Callback fired by the Interpreter when a watched symbol changes."""
    print('var changed ', group, symbolname, value, kws)


linp('x = 100.0')
linp.symtable.set_symbol('x', 30)
Expand Down

0 comments on commit b4f5d6d

Please sign in to comment.