Skip to content

Commit

Permalink
Merge pull request #68 from ubermag/new_factor
Browse files Browse the repository at this point in the history
New factor
  • Loading branch information
samjrholt authored Aug 7, 2024
2 parents 29a35b4 + df763fb commit 6c3fc99
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 74 deletions.
44 changes: 23 additions & 21 deletions docs/SANS.ipynb

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions mag2exp/ltem.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def phase(field, /, kcx=0.1, kcy=0.1):
"""
# More readable notation, direction arg will be removed soon.
m_int = field.integrate(direction="z")
m_ft = m_int.fftn()
m_ft = m_int.fftn(norm="ortho")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -106,7 +106,7 @@ def phase(field, /, kcx=0.1, kcy=0.1):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
ft_phase = (m_ft & k).ft_z * denom * prefactor
phase = ft_phase.ifftn().real
phase = ft_phase.ifftn(norm="ortho").real
phase.mesh.translate(field.mesh.region.center[:2], inplace=True)
return phase, ft_phase

Expand Down Expand Up @@ -219,7 +219,7 @@ def defocus_image(phase, /, cs=0, df_length=0.2e-3, voltage=None, wavelength=Non
:py:func:`~mag2exp.ltem.relativistic_wavelength`
"""
ft_wavefn = np.exp(phase * 1j).fftn()
ft_wavefn = np.exp(phase * 1j).fftn(norm="ortho")

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -242,7 +242,7 @@ def defocus_image(phase, /, cs=0, df_length=0.2e-3, voltage=None, wavelength=Non
cts = -df_length + 0.5 * wavelength**2 * cs * ksquare
exp = np.exp(np.pi * cts * 1j * ksquare * wavelength)
ft_def_wf_cts = ft_wavefn * exp
def_wf_cts = ft_def_wf_cts.ifftn()
def_wf_cts = ft_def_wf_cts.ifftn(norm="ortho")
intensity_cts = def_wf_cts.conjugate * def_wf_cts
intensity_cts.mesh.translate(phase.mesh.region.center, inplace=True)
return intensity_cts.real
Expand Down Expand Up @@ -339,10 +339,13 @@ def relativistic_wavelength(voltage):
1.9687489006848795e-12
"""
return constants.h / np.sqrt(
2
* constants.m_e
* voltage
* constants.e
* (1 + constants.e * voltage / (2 * constants.m_e * constants.c**2))
)
return (
constants.h
/ np.sqrt(
2
* constants.m_e
* voltage
* constants.e
* (1 + constants.e * voltage / (2 * constants.m_e * constants.c**2))
)
).item()
31 changes: 20 additions & 11 deletions mag2exp/sans.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def cross_section(field, /, method, polarisation=(0, 0, 1)):
The scattering cross sections can be calculated using
.. math::
\frac{d\sum}{d\Omega} \sim |{\bf Q} \cdot {\bf \sigma}|^2,
\frac{d\sum}{d\Omega} \propto |{\bf Q} \cdot {\bf \sigma}|^2,
where :math:`{\bf \sigma}` is the Pauli vector
.. math::
{\bf \sigma} = \begin{bmatrix} \sigma_x \\
\sigma_y \\
\sigma_z \end{bmatrix},
\sigma_z \end{bmatrix}
and
Expand Down Expand Up @@ -87,6 +87,10 @@ def cross_section(field, /, method, polarisation=(0, 0, 1)):
\frac{d\sum^{--}}{d\Omega} + \frac{d\sum^{-+}}{d\Omega} \right).
\end{align}
Note: Similar to the magnetisation, the value of the cross section
exists within the whole cell and therefore will change based on the size
of the cell and region in the input field.
Parameters
----------
field : discretisedfield.field
Expand Down Expand Up @@ -164,20 +168,25 @@ def cross_section(field, /, method, polarisation=(0, 0, 1)):
cross_s = _cross_section_matrix(field, polarisation=polarisation)
# TODO: make more efficient!

# norm_field = df.Field(field.mesh, nvdim=1, value=(field.norm.array != 0))
# volume = norm_field.integrate()

factor = 1 # 8 * np.pi**3 * (2.91e8)**2 / (field.mesh.dV)

if method in ("polarised_pp", "pp"):
return cross_s.pp
return factor * cross_s.pp
elif method in ("polarised_pn", "pn"):
return cross_s.pn
return factor * cross_s.pn
elif method in ("polarised_np", "np"):
return cross_s.np
return factor * cross_s.np
elif method in ("polarised_nn", "nn"):
return cross_s.nn
return factor * cross_s.nn
elif method in ("half_polarised_p", "p"):
return cross_s.pp + cross_s.pn
return factor * cross_s.pp + cross_s.pn
elif method in ("half_polarised_n", "n"):
return cross_s.nn + cross_s.np
return factor * cross_s.nn + cross_s.np
elif method in ("unpolarised", "unpol"):
return 0.5 * (cross_s.pp + cross_s.pn + cross_s.np + cross_s.nn)
return 0.5 * factor * (cross_s.pp + cross_s.pn + cross_s.np + cross_s.nn)
else:
msg = f"Method {method} is unknown."
raise ValueError(msg)
Expand Down Expand Up @@ -241,8 +250,8 @@ def chiral_function(field, /, polarisation=(0, 0, 1)):


def _cross_section_matrix(field, /, polarisation):
m_fft = field.fftn()
m_fft *= field.mesh.dV * 1e16 # TODO: Normalisation
m_fft = field.fftn(norm="ortho")
m_fft *= field.mesh.dV
q = df.Field(
m_fft.mesh,
nvdim=3,
Expand Down
48 changes: 25 additions & 23 deletions mag2exp/tests/test_sans.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import mag2exp

THRESH = 1e-64


def test_sans_analytical_parallel_bloch():
region = df.Region(p1=(-50e-9, -100e-9, 0), p2=(50e-9, 100e-9, 30e-9))
Expand All @@ -21,35 +23,35 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="pp").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="nn").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="pn").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="np").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2


Expand All @@ -69,21 +71,21 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="pp").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="nn").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, method="pn").sel(k_z=0)
Expand All @@ -110,7 +112,7 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, polarisation=[1, 0, 0], method="nn").sel(k_z=0)
Expand All @@ -123,14 +125,14 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, polarisation=[1, 0, 0], method="np").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2


Expand All @@ -151,7 +153,7 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(abs(q), 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 2

sans = mag2exp.sans.cross_section(m, polarisation=[1, 0, 0], method="nn").sel(k_z=0)
Expand All @@ -164,14 +166,14 @@ def m_fun(pos):
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(q, 1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 1

sans = mag2exp.sans.cross_section(m, polarisation=[1, 0, 0], method="np").sel(k_z=0)
idx = np.unravel_index(sans.array.argmax(), sans.array.shape)[0:2]
q = sans.mesh.index2point(idx)[0]
assert np.isclose(q, -1 / qx)
peaks = (sans.array > 10).sum()
peaks = (sans.array > THRESH).sum()
assert peaks == 1


Expand Down Expand Up @@ -213,12 +215,12 @@ def m_fun(pos):
idx = np.unravel_index(cf.array.argmax(), cf.array.shape)[0:2]
q = cf.mesh.index2point(idx)[0]
assert np.isclose(q, 1 / qx)
peaks = (cf.array > 10).sum()
peaks = (cf.array > THRESH).sum()
assert peaks == 1
idx = np.unravel_index(cf.array.argmin(), cf.array.shape)[0:2]
q = cf.mesh.index2point(idx)[0]
assert np.isclose(q, -1 / qx)
peaks = (cf.array < -10).sum()
peaks = (cf.array < -THRESH).sum()
assert peaks == 1

def m_fun(pos):
Expand All @@ -235,29 +237,29 @@ def test_sans_normalisation():

def m_fun(pos):
x, y, z = pos
qx = 25e-9
qx = 20e-9
return (0, np.sin(2 * np.pi * x / qx), np.cos(2 * np.pi * x / qx))

region = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
region = df.Region(p1=(0, 0, 0), p2=(80e-9, 80e-9, 80e-9))
mesh = df.Mesh(region=region, cell=(4e-9, 4e-9, 4e-9))
field1 = df.Field(mesh, nvdim=3, value=m_fun, norm=Ms)
sans1 = mag2exp.sans.cross_section(field1, method="unpol")
m1 = abs(sans1.array).max()

region2 = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
region2 = df.Region(p1=(0, 0, 0), p2=(80e-9, 80e-9, 80e-9))
mesh2 = df.Mesh(region=region2, cell=(2e-9, 2e-9, 2e-9))
field2 = df.Field(mesh2, nvdim=3, value=m_fun, norm=Ms)
sans2 = mag2exp.sans.cross_section(field2, method="unpol")
m2 = abs(sans2.array).max()

region = df.Region(p1=(0, 0, 0), p2=(150e-9, 150e-9, 150e-9))
mesh = df.Mesh(region=region, cell=(5e-9, 5e-9, 5e-9))
region = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
mesh = df.Mesh(region=region, cell=(4e-9, 4e-9, 4e-9))
field3 = df.Field(mesh, nvdim=3, value=m_fun, norm=Ms)
sans3 = mag2exp.sans.cross_section(field3, method="unpol")
m3 = abs(sans3.array).max()

assert np.isclose(m1, m2)
assert np.isclose(m1 / m3, (100 / 150) ** 6)
assert np.isclose(m2 / m1, ((2 / 4) ** 3))
assert np.isclose(m3 / m1, (100 / 80) ** 3)


def test_sans_cross_section_methods():
Expand Down
13 changes: 7 additions & 6 deletions mag2exp/tests/test_x_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,24 @@ def m_fun(pos):
qx = 25e-9
return (0, np.sin(2 * np.pi * x / qx), np.cos(2 * np.pi * x / qx))

region = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
region = df.Region(p1=(0, 0, 0), p2=(80e-9, 80e-9, 80e-9))
mesh = df.Mesh(region=region, cell=(4e-9, 4e-9, 4e-9))
field1 = df.Field(mesh, nvdim=3, value=m_fun, norm=Ms)
saxs1 = mag2exp.x_ray.saxs(field1)
m1 = abs(saxs1.array).max()

region2 = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
region2 = df.Region(p1=(0, 0, 0), p2=(80e-9, 80e-9, 80e-9))
mesh2 = df.Mesh(region=region2, cell=(2e-9, 2e-9, 2e-9))
field2 = df.Field(mesh2, nvdim=3, value=m_fun, norm=Ms)
saxs2 = mag2exp.x_ray.saxs(field2)
m2 = abs(saxs2.array).max()

region = df.Region(p1=(0, 0, 0), p2=(150e-9, 150e-9, 150e-9))
mesh = df.Mesh(region=region, cell=(5e-9, 5e-9, 5e-9))
region = df.Region(p1=(0, 0, 0), p2=(100e-9, 100e-9, 100e-9))
mesh = df.Mesh(region=region, cell=(4e-9, 4e-9, 4e-9))
field3 = df.Field(mesh, nvdim=3, value=m_fun, norm=Ms)
saxs3 = mag2exp.x_ray.saxs(field3)
m3 = abs(saxs3.array).max()

assert np.isclose(m1, m2)
assert np.isclose(m1 / m3, (100 / 150) ** 6)
assert np.isclose(m2 / m1, (2 / 4) ** -3, rtol=0.05)
assert m3
# assert np.isclose(m3 / m1, (100 / 80) ** 3)
3 changes: 1 addition & 2 deletions mag2exp/x_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,5 @@ def saxs(field):
>>> xrs.mpl.scalar()
"""
m_fft = field.fftn().ft_z.sel(k_z=0)
m_fft *= field.mesh.dV * 1e16
m_fft = field.fftn(norm="ortho").ft_z.sel(k_z=0)
return abs(m_fft) ** 2

0 comments on commit 6c3fc99

Please sign in to comment.