Skip to content

Commit

Permalink
[latutil] hide interpolator dep
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoStrand committed Apr 24, 2019
1 parent 2138620 commit 3154d65
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions python/triqs_tprf/lattice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@

# ----------------------------------------------------------------------

from scipy.interpolate import LinearNDInterpolator
from scipy.interpolate import RegularGridInterpolator
from scipy.interpolate import NearestNDInterpolator

# ----------------------------------------------------------------------

import pytriqs.utility.mpi as mpi

from pytriqs.gf import Gf
Expand Down Expand Up @@ -322,6 +316,7 @@ def get_abs_k_chi_interpolator(values, bzmesh, bz, extend_bz=[0]):
k_vec = np.vstack(k_vec_ext)
values = np.hstack(values_ext)

from scipy.interpolate import LinearNDInterpolator
interp = LinearNDInterpolator(k_vec, values, fill_value=float('nan'))

return interp
Expand All @@ -347,11 +342,14 @@ def get_rel_k_chi_interpolator(values, bzmesh, bz, nk,
# -- select interpolator type

if interpolator is 'regular':
from scipy.interpolate import RegularGridInterpolator
interp = RegularGridInterpolator(
(kx, ky, kz), values, fill_value=float('nan'), bounds_error=False)
elif interpolator is 'nearest':
from scipy.interpolate import NearestNDInterpolator
interp = NearestNDInterpolator(k_vec_rel, values.flatten())
elif interpolator is 'linear':
from scipy.interpolate import LinearNDInterpolator
interp = LinearNDInterpolator(k_vec_rel, values.flatten(), fill_value=float('nan'))
else:
raise NotImplementedError
Expand Down

0 comments on commit 3154d65

Please sign in to comment.