Rename cb2 to evenodd.
Bug fix for GPT interface.
SaltyChiang committed Dec 10, 2024
1 parent 230c2cf commit a8fccc3
Showing 12 changed files with 91 additions and 73 deletions.
51 changes: 29 additions & 22 deletions pyquda_core/pyquda/__init__.py
@@ -6,7 +6,6 @@
 from mpi4py import MPI

 from ._version import __version__  # noqa: F401
-from . import pyquda as quda
 from .field import LatticeInfo


@@ -141,7 +140,7 @@ def _getDefaultGrid(mpi_size: int, latt_size: List[int]):
     return min(min_grid)


-def _initEnviron(**kwargs):
+def _setEnviron(**kwargs):
     def _setEnviron(env, key, value):
         if value is not None:
             if env in environ:
@@ -154,7 +153,7 @@ def _setEnviron(env, key, value):
         _setEnviron(f"QUDA_{key.upper()}", key, kwargs[key])


-def _initEnvironWarn(**kwargs):
+def _setEnvironWarn(**kwargs):
     def _setEnviron(env, key, value):
         if value is not None:
             if env in environ:
@@ -172,9 +171,8 @@ def _setEnviron(env, key, value):

 def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1):
     global _CUDA_BACKEND, _HIP, _GPUID, _COMPUTE_CAPABILITY
-
     if isGridInitialized():
-        _MPI_LOGGER.critical("initGPU should be called before init", RuntimeError)
+        _MPI_LOGGER.critical("initGPU should be called before initGrid", RuntimeError)
     if _GPUID < 0:
         from platform import node as gethostname

@@ -239,17 +237,31 @@ def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1):
         _MPI_LOGGER.warning("GPU is already initialized", RuntimeWarning)


-def initQUDA(grid_size: List[int], gpuid: int):
+def initGrid(grid_size: List[int]):
+    global _GRID_SIZE, _GRID_COORD
+    if _GRID_SIZE is None:
+        Gx, Gy, Gz, Gt = grid_size
+        if _MPI_SIZE != Gx * Gy * Gz * Gt:
+            _MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError)
+        _GRID_SIZE = [Gx, Gy, Gz, Gt]
+        _GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE)
+        _MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}")
+    else:
+        _MPI_LOGGER.warning("Grid is already initialized", RuntimeWarning)
+
+
+def initQUDA(grid_size: List[int], gpuid: int, use_quda_allocator: bool = False):
     import atexit
     from . import pyquda as quda, malloc_pyquda

-    # if _CUDA_BACKEND == "cupy":
-    #     import cupy
-    #     from . import malloc_pyquda
-    #     allocator = cupy.cuda.PythonFunctionAllocator(
-    #         malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free
-    #     )
-    #     cupy.cuda.set_allocator(allocator.malloc)
+    if use_quda_allocator:
+        if _CUDA_BACKEND == "cupy":
+            import cupy
+
+            allocator = cupy.cuda.PythonFunctionAllocator(
+                malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free
+            )
+            cupy.cuda.set_allocator(allocator.malloc)

     quda.initCommsGridQuda(4, grid_size)
     quda.initQuda(gpuid)
@@ -293,28 +305,23 @@ def init(
     """
     Initialize MPI along with the QUDA library.
    """
-    global _GRID_SIZE, _GRID_COORD, _DEFAULT_LATTICE
+    global _DEFAULT_LATTICE
     if _GRID_SIZE is None:
         initGPU(backend)

         use_default_grid = grid_size is None and latt_size is not None
         use_default_latt = latt_size is not None and t_boundary is not None and anisotropy is not None
         if use_default_grid:
             grid_size = _getDefaultGrid(_MPI_SIZE, latt_size)
-        Gx, Gy, Gz, Gt = grid_size if grid_size is not None else [1, 1, 1, 1]
-        if _MPI_SIZE != Gx * Gy * Gz * Gt:
-            _MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError)
-        _GRID_SIZE = [Gx, Gy, Gz, Gt]
-        _GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE)
-        _MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}")
+        initGrid(grid_size if grid_size is not None else [1, 1, 1, 1])
         if use_default_grid and not use_default_latt:
            _MPI_LOGGER.info(f"Using the lattice size {latt_size} only for getting the default grid size {_GRID_SIZE}")
         if use_default_latt:
             _DEFAULT_LATTICE = LatticeInfo(latt_size, t_boundary, anisotropy)
             _MPI_LOGGER.info(f"Using the default lattice LatticeInfo({latt_size}, {t_boundary}, {anisotropy})")

-        _initEnvironWarn(resource_path=resource_path if resource_path != "" else None)
-        _initEnviron(
+        _setEnvironWarn(resource_path=resource_path if resource_path != "" else None)
+        _setEnviron(
             rank_verbosity=",".join(rank_verbosity) if rank_verbosity != [0] else None,
             enable_mps="1" if enable_mps else None,
             enable_gdr="1" if enable_gdr else None,
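For context, a minimal sketch (not part of the commit) of the call order this refactor enables; the grid values are example numbers for a 4-process job, and the initGPU/initGrid ordering follows the init() hunk above:

from pyquda import initGPU, initGrid, initQUDA

initGPU("cupy")         # select the backend and bind a device; must precede initGrid
initGrid([1, 1, 2, 2])  # Gx * Gy * Gz * Gt must equal the MPI size
initQUDA([1, 1, 2, 2], gpuid=0, use_quda_allocator=True)  # opt in to QUDA's device allocator for CuPy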
2 changes: 1 addition & 1 deletion pyquda_core/pyquda/_version.py
@@ -1 +1 @@
-__version__ = "0.9.10"
+__version__ = "0.9.11"
35 changes: 21 additions & 14 deletions pyquda_core/pyquda/field.py
@@ -130,40 +130,47 @@ def lexico(data: numpy.ndarray, axes: List[int], dtype=None):
     Npre = int(numpy.prod(shape[: axes[0]]))
     Nsuf = int(numpy.prod(shape[axes[-1] + 1 :]))
     dtype = data.dtype if dtype is None else dtype
-    data_cb2 = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf)
+    data_evenodd = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf)
     data_lexico = numpy.zeros((Npre, Lt, Lz, Ly, Lx, Nsuf), dtype)
     for t in range(Lt):
         for z in range(Lz):
             for y in range(Ly):
                 eo = (t + z + y) % 2
                 if eo == 0:
-                    data_lexico[:, t, z, y, 0::2] = data_cb2[:, 0, t, z, y, :]
-                    data_lexico[:, t, z, y, 1::2] = data_cb2[:, 1, t, z, y, :]
+                    data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 0, t, z, y, :]
+                    data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 1, t, z, y, :]
                 else:
-                    data_lexico[:, t, z, y, 1::2] = data_cb2[:, 0, t, z, y, :]
-                    data_lexico[:, t, z, y, 0::2] = data_cb2[:, 1, t, z, y, :]
+                    data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 0, t, z, y, :]
+                    data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 1, t, z, y, :]
     return data_lexico.reshape(*shape[: axes[0]], Lt, Lz, Ly, Lx, *shape[axes[-1] + 1 :])


-def cb2(data: numpy.ndarray, axes: List[int], dtype=None):
+def evenodd(data: numpy.ndarray, axes: List[int], dtype=None):
     shape = data.shape
     Lt, Lz, Ly, Lx = [shape[axis] for axis in axes]
     Npre = int(numpy.prod(shape[: axes[0]]))
     Nsuf = int(numpy.prod(shape[axes[-1] + 1 :]))
     dtype = data.dtype if dtype is None else dtype
     data_lexico = data.reshape(Npre, Lt, Lz, Ly, Lx, Nsuf)
-    data_cb2 = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype)
+    data_evenodd = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype)
     for t in range(Lt):
         for z in range(Lz):
             for y in range(Ly):
                 eo = (t + z + y) % 2
                 if eo == 0:
-                    data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
-                    data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
+                    data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
+                    data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
                 else:
-                    data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
-                    data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
-    return data_cb2.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :])
+                    data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
+                    data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
+    return data_evenodd.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :])
+
+
+def cb2(data: numpy.ndarray, axes: List[int], dtype=None):
+    from . import getLogger
+
+    getLogger().warning("cb2 is deprecated, use evenodd instead", DeprecationWarning)
+    return evenodd(data, axes, dtype)


 def checksum(latt_info: Union[LatticeInfo, GeneralInfo], data: numpy.ndarray) -> Tuple[int, int]:
@@ -675,9 +682,9 @@ def load(
         if Nc is not None:
             latt_info.Nc = Nc
         if not issubclass(cls, MultiField):
-            retval = cls(latt_info, cb2(value, [0, 1, 2, 3]))
+            retval = cls(latt_info, evenodd(value, [0, 1, 2, 3]))
         else:
-            retval = cls(latt_info, len(label), numpy.asarray([cb2(data, [0, 1, 2, 3]) for data in value]))
+            retval = cls(latt_info, len(label), numpy.asarray([evenodd(data, [0, 1, 2, 3]) for data in value]))
         secs = perf_counter() - s
         getLogger().debug(f"Loaded {filename} in {secs:.3f} secs, {gbytes / secs:.3f} GB/s")
         return retval
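For illustration, a minimal sketch (not part of the commit) of the renamed helper on a toy single-rank lattice; it assumes lexico takes the five even-odd axes [parity, t, z, y, x], matching its use elsewhere in pyquda:

import numpy
from pyquda.field import cb2, evenodd, lexico

Lx, Ly, Lz, Lt = 4, 4, 4, 8
data = numpy.arange(Lt * Lz * Ly * Lx).reshape(Lt, Lz, Ly, Lx)

# Split the x-extent into even/odd checkerboards: (Lt, Lz, Ly, Lx) -> (2, Lt, Lz, Ly, Lx // 2)
data_eo = evenodd(data, [0, 1, 2, 3])
# lexico is the inverse transform
assert (lexico(data_eo, [0, 1, 2, 3, 4]) == data).all()
# The old name still works but now emits a DeprecationWarning
assert (cb2(data, [0, 1, 2, 3]) == data_eo).all()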
5 changes: 4 additions & 1 deletion pyquda_utils/core.py
@@ -4,6 +4,7 @@

 from pyquda import (
     initGPU,
+    initGrid,
     initQUDA,
     init,
     getCoordFromRank,
@@ -15,6 +16,7 @@
     getGridCoord,
     setDefaultLattice,
     getDefaultLattice,
+    getCUDABackend,
     getLogger,
     setLoggerLevel,
     dirac as fermion,
@@ -41,7 +43,8 @@
     LatticePropagator,
     LatticeStaggeredPropagator,
     lexico,
-    cb2,
+    evenodd,
+    evenodd as cb2,
 )
 from pyquda.dirac.abstract import Multigrid, FermionDirac, StaggeredFermionDirac

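Note that core.py re-exports evenodd under its old name, so downstream imports keep working; a one-line check (not from the commit):

from pyquda_utils.core import cb2, evenodd

assert cb2 is evenodd  # the alias binds the new function directly, bypassing the deprecation shim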
8 changes: 4 additions & 4 deletions pyquda_utils/deprecated.py
@@ -1,6 +1,6 @@
 from typing import List

-from pyquda import getLogger, getGridSize, quda, enum_quda
+from pyquda import getLogger, getGridSize, pyquda as quda, enum_quda
 from pyquda.field import LatticeFermion, LatticeGauge, LatticeInfo, LatticePropagator, Nc, Ns
 from pyquda.dirac.abstract import FermionDirac

@@ -101,11 +101,11 @@ def getDslash(
     latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, xi)

     if clover_csw != 0.0:
-        from .dirac.clover_wilson import CloverWilsonDirac
+        from pyquda.dirac.clover_wilson import CloverWilsonDirac

         return CloverWilsonDirac(latt_info, mass, tol, maxiter, clover_csw, clover_xi, geo_block_size)
     else:
-        from .dirac.wilson import WilsonDirac
+        from pyquda.dirac.wilson import WilsonDirac

         return WilsonDirac(latt_info, mass, tol, maxiter, geo_block_size)

@@ -131,6 +131,6 @@ def getStaggeredDslash(
         t_boundary = 1
     latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, 1.0)

-    from .dirac.hisq import HISQDirac
+    from pyquda.dirac.hisq import HISQDirac

     return HISQDirac(latt_info, mass, tol, maxiter, naik_epsilon, None)
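The import change matters because deprecated.py lives in pyquda_utils, so the old relative form resolved against the wrong package; a sketch of the failure mode (not from the commit):

# Inside pyquda_utils/deprecated.py:
# from .dirac.wilson import WilsonDirac      # looks for pyquda_utils.dirac.wilson -> ModuleNotFoundError
from pyquda.dirac.wilson import WilsonDirac  # absolute import resolves against pyquda itself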
15 changes: 9 additions & 6 deletions pyquda_utils/gpt.py
@@ -1,15 +1,17 @@
 from typing import List
 import numpy

-from pyquda import getSublatticeSize, getGridSize
-from pyquda.field import cb2, LatticeGauge, LatticeInfo, LatticePropagator
+from .core import evenodd, getGridSize, LatticeGauge, LatticeInfo, LatticePropagator

 import gpt as g


 def LatticeInfoGPT(grid: g.grid, gen_simd_width: int):
     assert getGridSize() == grid.mpi
-    sublatt_size = getSublatticeSize(grid.fdimensions, grid.mpi)
+    GLx, GLy, GLz, GLt = grid.fdimensions
+    Gx, Gy, Gz, Gt = grid.mpi
+    Lx, Ly, Lz, Lt = GLx // Gx, GLy // Gy, GLz // Gz, GLt // Gt
+    sublatt_size = [Lx, Ly, Lz, Lt]
     Nd = len(sublatt_size)
     precision = grid.precision.nbytes
     n_simd = gen_simd_width // (2 * precision)
@@ -32,7 +34,7 @@ def LatticeGaugeGPT(lattice: List[g.lattice], gen_simd_width: int, gauge: Lattic
         value = []
         for index in range(latt_info.Nd):
             value.append(
-                cb2(
+                evenodd(
                     numpy.asarray(lattice[index].mview()[0])
                     .view(f"<c{2 * gpt_prec}")
                     .reshape(*gpt_latt[::-1], Nc, Nc, *gpt_simd[::-1])
@@ -49,7 +51,8 @@ def LatticeGaugeGPT(lattice: List[g.lattice], gen_simd_width: int, gauge: Lattic
         for index in range(latt_info.Nd):
             gpt_shape = [i for sl in zip(gpt_simd, gpt_latt) for i in sl]
             lattice[index].mview()[0][:] = (
-                gauge[index].lexico()
+                gauge[index]
+                .lexico()
                 .astype(f"<c{2 * gpt_prec}")
                 .reshape(*gpt_shape, Nc, Nc)
                 .transpose(1, 3, 5, 7, 8, 9, 0, 2, 4, 6)
@@ -65,7 +68,7 @@ def LatticePropagatorGPT(lattice: g.lattice, gen_simd_width: int, propagator: La
     Ns, Nc = latt_info.Ns, latt_info.Nc
     assert lattice.describe().startswith(f"ot_matrix_spin_color({Ns},{Nc})")
     if propagator is None:
-        value = cb2(
+        value = evenodd(
             numpy.asarray(lattice.mview()[0])
             .view(f"<c{2 * gpt_prec}")
             .reshape(*gpt_latt[::-1], Ns, Ns, Nc, Nc, *gpt_simd[::-1])
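The GPT-interface fix replaces the getSublatticeSize call (no longer imported above) with an elementwise division of the global dimensions by the process grid; a standalone sketch with made-up extents (not from the commit):

fdimensions = [24, 24, 24, 48]  # global lattice extents (x, y, z, t), example values
mpi = [2, 2, 1, 4]              # process grid; must match getGridSize()

sublatt_size = [GL // G for GL, G in zip(fdimensions, mpi)]
assert sublatt_size == [12, 12, 24, 12]  # per-rank sublattice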