diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 07097cf4..cccac175 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -21,6 +21,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + sudo apt-get install -y libopenblas-dev python -m pip install --upgrade pip python -m pip install flake8 pytest pytest-runner coverage if [ -f requirements.txt ]; then pip install -r requirements.txt; fi diff --git a/.gitignore b/.gitignore index b31216a7..c63f88df 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ src/imagecpp.cpp fdasrsf_python.code-workspace *.so .coverage +src/*.dSYM/ +src/rbfgs +src/crbfgs.cpp diff --git a/CHANGES.txt b/CHANGES.txt index 73b24a07..fb2cd32b 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,4 @@ -v2.5.X, 2023-12-05 -- add projection and variance explained to fpca +v2.5.X, 2023-12-05 -- add projection and variance explained to fpca, add c version of rbfgs v2.5.2, 2023-11-18 -- bugfixes v2.5.0, 2023-11-17 -- bugfixes, add elastic changepoint v2.4.3, 2023-08-21 -- bugfixes, expose lam in curve, add example diff --git a/fdasrsf/time_warping.py b/fdasrsf/time_warping.py index 8399b9a9..34d6cd18 100644 --- a/fdasrsf/time_warping.py +++ b/fdasrsf/time_warping.py @@ -73,7 +73,6 @@ def __init__(self, f, time): self.time = time self.rsamps = False - def srsf_align(self, method="mean", omethod="DP2", center=True, smoothdata=False, MaxItr=20, parallel=False, lam=0.0, cores=-1, grid_dim=7, verbose=True): @@ -83,14 +82,16 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, :param method: (string) warp calculate Karcher Mean or Median (options = "mean" or "median") (default="mean") - :param omethod: optimization method (DP, DP2, RBFGS) (default = DP2) + :param omethod: optimization method (DP, DP2, RBFGS, cRBFGS) + (default = DP2) :param center: center warping functions (default = T) :param smoothdata: Smooth the data using a box filter (default = F) :param MaxItr: Maximum number of iterations (default = 20) :param parallel: run in parallel (default = F) :param lam: controls the elasticity (default = 0) :param cores: number of cores for parallel (default = -1 (all)) - :param grid_dim: size of the grid, for the DP2 method only (default = 7) + :param grid_dim: size of the grid, for the DP2 method only + (default = 7) :param verbose: print status output (default = T) :type lam: double :type smoothdata: bool @@ -149,12 +150,14 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, gam = np.array(out) gam = gam.transpose() else: - gam = np.zeros((M,N)) - for k in range(0,N): - gam[:,k] = uf.optimum_reparam(mq,self.time,q[:,k],omethod,lam,grid_dim) + gam = np.zeros((M, N)) + for k in range(0, N): + gam[:, k] = uf.optimum_reparam(mq, self.time, q[:, k], + omethod, lam, grid_dim) gamI = uf.SqrtMeanInverse(gam) - mf = np.interp((self.time[-1] - self.time[0]) * gamI + self.time[0], self.time, mf) + mf = np.interp((self.time[-1] - self.time[0]) * gamI + self.time[0], + self.time, mf) mq = uf.f_to_srsf(mf, self.time) # Compute Karcher Mean @@ -189,13 +192,14 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, # Matching Step if parallel: out = Parallel(n_jobs=cores)(delayed(uf.optimum_reparam)(mq[:, r], - self.time, q[:, n, 0], omethod, lam, grid_dim) for n in range(N)) + self.time, q[:, n, 0], omethod, lam, grid_dim) for n in range(N)) gam = np.array(out) gam = gam.transpose() else: - for k in range(0,N): - gam[:,k] = uf.optimum_reparam(mq[:, r], self.time, q[:, k, 0], - omethod, lam, grid_dim) + for k in range(0 ,N): + gam[:, k] = uf.optimum_reparam(mq[:, r], self.time, + q[:, k, 0], omethod, lam, + grid_dim) gam_dev = np.zeros((M, N)) vtil = np.zeros((M,N)) @@ -205,9 +209,9 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, + self.time[0], self.time, f[:, k, 0]) q[:, k, r + 1] = uf.f_to_srsf(f[:, k, r + 1], self.time) gam_dev[:, k] = np.gradient(gam[:, k], 1 / float(M - 1)) - v = q[:, k, r + 1] - mq[:,r] + v = q[:, k, r + 1] - mq[:, r] d = np.sqrt(trapz(v*v, self.time)) - vtil[:,k] = v/d + vtil[:, k] = v/d dtil[k] = 1.0/d mqt = mq[:, r] @@ -259,8 +263,8 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, gam = np.array(out) gam = gam.transpose() else: - for k in range(0,N): - gam[:,k] = uf.optimum_reparam(mq[:, r], self.time, q[:, k, 0], omethod, + for k in range(0, N): + gam[:, k] = uf.optimum_reparam(mq[:, r], self.time, q[:, k, 0], omethod, lam, grid_dim) gam_dev = np.zeros((M, N)) @@ -310,7 +314,6 @@ def srsf_align(self, method="mean", omethod="DP2", center=True, return - def plot(self): """ plot functional alignment results @@ -322,7 +325,7 @@ def plot(self): plot.f_plot(self.time, self.f, title="f Original Data") fig, ax = plot.f_plot(np.arange(0, M) / float(M - 1), self.gam, - title="Warping Functions") + title="Warping Functions") ax.set_aspect('equal') plot.f_plot(self.time, self.fn, title="Warped Data") @@ -380,12 +383,12 @@ def gauss_model(self, n=1, sort_samples=False): fs = np.zeros((M, n)) for k in range(0, n): fs[:, k] = uf.cumtrapzmid(time, q_s[0:M, k] * np.abs(q_s[0:M, k]), - np.sign(q_s[M, k]) * (q_s[M, k] ** 2), - mididx) + np.sign(q_s[M, k]) * (q_s[M, k] ** 2), + mididx) fbar = fn.mean(axis=1) fsbar = fs.mean(axis=1) - err = np.transpose(np.tile(fbar-fsbar, (n,1))) + err = np.transpose(np.tile(fbar-fsbar, (n, 1))) fs += err # random warping generation @@ -415,7 +418,7 @@ def gauss_model(self, n=1, sort_samples=False): ft = np.zeros((M, n)) for k in range(0, n): ft[:, k] = np.interp(gams[:, seq2[k]], np.arange(0, M) / - np.double(M - 1), fs[:, seq1[k]]) + np.double(M - 1), fs[:, seq1[k]]) tmp = np.isnan(ft[:, k]) while tmp.any(): rgam2 = uf.randomGamma(gam, 1) @@ -426,23 +429,21 @@ def gauss_model(self, n=1, sort_samples=False): ft = np.zeros((M, n)) for k in range(0, n): ft[:, k] = np.interp(gams[:, k], np.arange(0, M) / - np.double(M - 1), fs[:, k]) + np.double(M - 1), fs[:, k]) tmp = np.isnan(ft[:, k]) while tmp.any(): rgam2 = uf.randomGamma(gam, 1) ft[:, k] = np.interp(gams[:, k], np.arange(0, M) / - np.double(M - 1), uf.invertGamma(rgam2)) - - + np.double(M - 1), uf.invertGamma(rgam2)) + self.rsamps = True self.fs = fs self.gams = rgam self.ft = ft - self.qs = q_s[0:M,:] + self.qs = q_s[0:M, :] return - def joint_gauss_model(self, n=1, no=3): """ This function models the functional data using a joint Gaussian model @@ -458,7 +459,6 @@ def joint_gauss_model(self, n=1, no=3): fn = self.fn time = self.time qn = self.qn - gam = self.gam M = time.size @@ -480,22 +480,21 @@ def joint_gauss_model(self, n=1, no=3): vals = np.random.multivariate_normal(np.zeros(s.shape), np.diag(s), n) tmp = np.matmul(U, np.transpose(vals)) - qhat = np.tile(mqn.T,(n,1)).T + tmp[0:M+1,:] + qhat = np.tile(mqn.T, (n, 1)).T + tmp[0:M+1, :] tmp = np.matmul(U, np.transpose(vals)/C) - vechat = tmp[(M+1):,:] - psihat = np.zeros((M,n)) - gamhat = np.zeros((M,n)) + vechat = tmp[(M+1):, :] + psihat = np.zeros((M, n)) + gamhat = np.zeros((M, n)) for ii in range(n): - psihat[:,ii] = geo.exp_map(mu_psi,vechat[:,ii]) - gam_tmp = cumtrapz(psihat[:,ii]**2,np.linspace(0,1,M),initial=0.0) - gamhat[:,ii] = (gam_tmp - gam_tmp.min())/(gam_tmp.max()-gam_tmp.min()) + psihat[:, ii] = geo.exp_map(mu_psi, vechat[:, ii]) + gam_tmp = cumtrapz(psihat[:, ii]**2, np.linspace(0, 1, M), initial=0.0) + gamhat[:, ii] = (gam_tmp - gam_tmp.min())/(gam_tmp.max()-gam_tmp.min()) - ft = np.zeros((M,n)) - fhat = np.zeros((M,n)) + ft = np.zeros((M, n)) + fhat = np.zeros((M, n)) for ii in range(n): - fhat[:,ii] = uf.cumtrapzmid(time, qhat[0:M,ii]*np.fabs(qhat[0:M,ii]), np.sign(qhat[M,ii])*(qhat[M,ii]*qhat[M,ii]), mididx) - ft[:,ii] = uf.warp_f_gamma(np.linspace(0,1,M),fhat[:,ii],gamhat[:,ii]) - + fhat[:, ii] = uf.cumtrapzmid(time, qhat[0:M, ii]*np.fabs(qhat[0:M, ii]), np.sign(qhat[M,ii])*(qhat[M,ii]*qhat[M,ii]), mididx) + ft[:, ii] = uf.warp_f_gamma(np.linspace(0, 1, M), fhat[:, ii], gamhat[:,ii]) self.rsamps = True self.fs = fhat @@ -506,22 +505,25 @@ def joint_gauss_model(self, n=1, no=3): return def multiple_align_functions(self, mu, omethod="DP2", smoothdata=False, - parallel=False, lam=0.0, cores=-1, grid_dim=7): + parallel=False, lam=0.0, cores=-1, + grid_dim=7): """ - This function aligns a collection of functions using the elastic square-root - slope (srsf) framework. + This function aligns a collection of functions using the elastic + square-root slope (srsf) framework. Usage: obj.multiple_align_functions(mu) obj.multiple_align_functions(lambda) obj.multiple_align_functions(lambda, ...) :param mu: vector of function to align to - :param omethod: optimization method (DP, DP2, RBFGS) (default = DP) + :param omethod: optimization method (DP, DP2, RBFGS, cRBFGS) + (default = DP2) :param smoothdata: Smooth the data using a box filter (default = F) :param parallel: run in parallel (default = F) :param lam: controls the elasticity (default = 0) :param cores: number of cores for parallel (default = -1 (all)) - :param grid_dim: size of the grid, for the DP2 method only (default = 7) + :param grid_dim: size of the grid, for the DP2 method only + (default = 7) :type lam: double :type smoothdata: bool @@ -548,31 +550,29 @@ def multiple_align_functions(self, mu, omethod="DP2", smoothdata=False, if parallel: out = Parallel(n_jobs=cores)(delayed(uf.optimum_reparam)(mq, self.time, - q[:, n], omethod, lam, grid_dim) for n in range(N)) + q[:, n], omethod, lam, grid_dim) for n in range(N)) gam = np.array(out) gam = gam.transpose() else: - gam = np.zeros((M,N)) - for k in range(0,N): - gam[:,k] = uf.optimum_reparam(mq,self.time,q[:,k],omethod,lam,grid_dim) + gam = np.zeros((M, N)) + for k in range(0, N): + gam[:, k] = uf.optimum_reparam(mq, self.time, q[:, k], omethod, + lam, grid_dim) self.gamI = uf.SqrtMeanInverse(gam) - fn = np.zeros((M,N)) - qn = np.zeros((M,N)) + fn = np.zeros((M, N)) + qn = np.zeros((M, N)) for k in range(0, N): fn[:, k] = np.interp((self.time[-1] - self.time[0]) * gam[:, k] - + self.time[0], self.time, f[:, k]) + + self.time[0], self.time, f[:, k]) qn[:, k] = uf.f_to_srsf(f[:, k], self.time) - # Aligned data & stats self.fn = fn self.qn = qn self.q0 = q - mean_f0 = f.mean(axis=1) std_f0 = f.std(axis=1) - mean_fn = self.fn.mean(axis=1) std_fn = self.fn.std(axis=1) self.gam = gam self.mqn = mq @@ -597,12 +597,13 @@ def pairwise_align_functions(f1, f2, time, omethod="DP2", lam=0, grid_dim=7): slope (srsf) framework. Usage: out = pairwise_align_functions(f1, f2, time) - out = pairwise_align_functions(f1, f2, time, omethod, lam, grid_dim) + out = pairwise_align_functions(f1, f2, time, omethod, lam, + grid_dim) :param f1: vector defining M samples of function 1 :param f2: vector defining M samples of function 2 :param time: time vector of length M - :param omethod: optimization method (DP, DP2, RBFGS) (default = DP) + :param omethod: optimization method (DP, DP2, RBFGS, cRBFGS) (default = DP) :param lam: controls the elasticity (default = 0) :param grid_dim: size of the grid, for the DP2 method only (default = 7) @@ -618,10 +619,9 @@ def pairwise_align_functions(f1, f2, time, omethod="DP2", lam=0, grid_dim=7): gam = uf.optimum_reparam(q1, time, q2, omethod, lam, grid_dim) - f2n = uf.warp_f_gamma(time, f2 , gam) + f2n = uf.warp_f_gamma(time, f2, gam) q2n = uf.f_to_srsf(f2n, time) - return (f2n, gam, q2n) @@ -690,7 +690,6 @@ def pairwise_align_bayes(f1i, f2i, time, mcmcopts=None): raise Exception('Length of mcmcopts.initcoef must be even') # Number of sig figs to report in gamma_mat - SIG_GAM = 13 iter = mcmcopts["iter"] # parameter settings @@ -721,22 +720,23 @@ def propose_g_coef(g_coef_curr): # normalize time to [0,1] time = (time - time.min())/(time.max()-time.min()) - timet = np.linspace(0,1,numSimPoints) - f1 = uf.f_predictfunction(f1i,timet,0) - f2 = uf.f_predictfunction(f2i,timet,0) + timet = np.linspace(0, 1, numSimPoints) + f1 = uf.f_predictfunction(f1i, timet, 0) + f2 = uf.f_predictfunction(f2i, timet, 0) # srsf transformation - q1 = uf.f_to_srsf(f1,timet) - q1i = uf.f_to_srsf(f1i,time) - q2 = uf.f_to_srsf(f2,timet) + q1 = uf.f_to_srsf(f1, timet) + q1i = uf.f_to_srsf(f1i, time) + q2 = uf.f_to_srsf(f2, timet) - tmp = uf.f_exp1(uf.f_basistofunction(g_basis["x"],0,g_coef_ini,g_basis)) + tmp = uf.f_exp1(uf.f_basistofunction(g_basis["x"], 0, + g_coef_ini, g_basis)) if tmp.min() < 0: raise Exception("Invalid initial value of g") # result vectors - g_coef = np.zeros((iter,g_coef_ini.shape[0])) + g_coef = np.zeros((iter, g_coef_ini.shape[0])) sigma1 = np.zeros(iter) logl = np.zeros(iter) SSE = np.zeros(iter) @@ -746,16 +746,21 @@ def propose_g_coef(g_coef_curr): # init g_coef_curr = g_coef_ini sigma1_curr = sigma1_ini - SSE_curr = bf.f_SSEg_pw(uf.f_basistofunction(g_basis["x"],0,g_coef_ini,g_basis),q1,q2) - logl_curr = bf.f_logl_pw(uf.f_basistofunction(g_basis["x"],0,g_coef_ini,g_basis),q1,q2,sigma1_ini**2,SSE_curr) + SSE_curr = bf.f_SSEg_pw(uf.f_basistofunction(g_basis["x"], 0, + g_coef_ini, g_basis), + q1, q2) + logl_curr = bf.f_logl_pw(uf.f_basistofunction(g_basis["x"], 0, + g_coef_ini, g_basis), + q1, q2, sigma1_ini**2, + SSE_curr) - g_coef[0,:] = g_coef_ini + g_coef[0, :] = g_coef_ini sigma1[0] = sigma1_ini SSE[0] = SSE_curr logl[0] = logl_curr # update the chain for iter-1 times - for m in tqdm(range(1,iter)): + for m in tqdm(range(1, iter)): # update g g_coef_curr, tmp, SSE_curr, accepti, zpcnInd = bf.f_updateg_pw(g_coef_curr, g_basis, sigma1_curr**2, q1, q2, SSE_curr, propose_g_coef) @@ -825,8 +830,8 @@ def pairwise_align_bayes_infHMC(y1i, y2i, time, mcmcopts=None): This function aligns two functions using Bayesian framework. It uses a hierarchical Bayesian framework assuming mearsurement error error It will align f2 to f1. It is based on mapping warping functions to a hypersphere, - and a subsequent exponential mapping to a tangent space. In the tangent space, - the \infty-HMC algorithm is used to explore both local and global + and a subsequent exponential mapping to a tangent space. In the tangent + space, the \infty-HMC algorithm is used to explore both local and global structure in the posterior distribution. Usage: out = pairwise_align_bayes_infHMC(f1i, f2i, time) @@ -1065,7 +1070,6 @@ def pairwise_align_bayes_infHMC(y1i, y2i, time, mcmcopts=None): def run_mcmc(y1i, y2i, time, mcmcopts): # Number of sig figs to report in gamma_mat - SIG_GAM = 13 iter = mcmcopts["iter"] T = time.shape[0] @@ -1197,7 +1201,7 @@ def propose_v_coef(v_coef_curr): nll, g, SSE_curr = bf.f_dlogl_pw(v_coef_curr, v_basis, d_basis, sigma_curr, q1_curr, q2_curr) # update the chain for iter-1 times - for m in range(1,iter): + for m in range(1, iter): # update f1 f1_curr, q1_curr, f1_accept1 = bf.f_updatef1_pw(f1_curr,q1_curr, y1i, q2_curr,v_coef_curr, v_basis, diff --git a/fdasrsf/utility_functions.py b/fdasrsf/utility_functions.py index 921734c8..21db199b 100644 --- a/fdasrsf/utility_functions.py +++ b/fdasrsf/utility_functions.py @@ -21,6 +21,7 @@ import optimum_reparamN2 as orN2 import optimum_reparam_N as orN import cbayesian as bay +import crbfgs as cr import fdasrsf.geometry as geo from fdasrsf.rbfgs import rlbfgs import sys @@ -43,8 +44,8 @@ def smooth_data(f, sparam=1): fo = f.copy() for k in range(0, sparam): for r in range(0, N): - fo[1 : (M - 2), r] = ( - fo[0 : (M - 3), r] + 2 * fo[1 : (M - 2), r] + fo[2 : (M - 1), r] + fo[1: (M - 2), r] = ( + fo[0: (M - 3), r] + 2 * fo[1: (M - 2), r] + fo[2: (M - 1), r] ) / 4 return fo @@ -135,11 +136,13 @@ def optimum_reparam( :param q1: vector of size N or array of NxM samples of first SRSF :param time: vector of size N describing the sample points :param q2: vector of size N or array of NxM samples samples of second SRSF - :param method: method to apply optimization (default="DP2") options are "DP","DP2","RBFGS" + :param method: method to apply optimization (default="DP2") options are + "DP","DP2","RBFGS","cRBFGS" :param lam: controls the amount of elasticity (default = 0.0) - :param penalty: penalty type (default="roughness") options are "roughness", "l2gam", - "l2psi", "geodesic". Only roughness implemented in all methods. To use - others method needs to be "RBFGS" + :param penalty: penalty type (default="roughness") options are "roughness", + "l2gam", "l2psi", "geodesic". Only roughness implemented + in all methods. To use others method needs to be "RBFGS" + or "cRBFGS" :param grid_dim: size of the grid, for the DP2 method only (default = 7) :rtype: vector @@ -172,17 +175,20 @@ def optimum_reparam( elif method == "DP2": if q1.ndim == 1 and q2.ndim == 1: gam = orN2.coptimum_reparam( - ascontiguousarray(q1), time, ascontiguousarray(q2), lam, grid_dim + ascontiguousarray(q1), time, ascontiguousarray(q2), lam, + grid_dim ) if q1.ndim == 1 and q2.ndim == 2: gam = orN2.coptimum_reparamN( - ascontiguousarray(q1), time, ascontiguousarray(q2), lam, grid_dim + ascontiguousarray(q1), time, ascontiguousarray(q2), lam, + grid_dim ) if q1.ndim == 2 and q2.ndim == 2: gam = orN2.coptimum_reparamN2( - ascontiguousarray(q1), time, ascontiguousarray(q2), lam, grid_dim + ascontiguousarray(q1), time, ascontiguousarray(q2), lam, + grid_dim ) elif method == "RBFGS": if q1.ndim == 1 and q2.ndim == 1: @@ -206,6 +212,38 @@ def optimum_reparam( obj = rlbfgs(q1[:, i], q2[:, i], time) obj.solve(lam=lam, penalty=penalty) gam[:, i] = obj.gammaOpt + elif method == "cRBFGS": + if penalty == "roughness": + pen = 0 + elif penalty == "l2gam": + pen = 1 + elif penalty == "l2psi": + pen = 2 + elif penalty == "geodesic": + pen = 3 + else: + raise Exception("penalty not implemented") + + if q1.ndim == 1 and q2.ndim == 1: + time = linspace(0, 1, q1.shape[0]) + gam = cr.rlbfgs(ascontiguousarray(q1), ascontiguousarray(q2), + ascontiguousarray(time), 30, lam, pen) + + if q1.ndim == 1 and q2.ndim == 2: + gam = zeros(q2.shape) + time = linspace(0, 1, q1.shape[0]) + for i in range(0, q2.shape[1]): + gam[:, i] = cr.rlbfgs(ascontiguousarray(q1), + ascontiguousarray(q2[:, i]), + ascontiguousarray(time), 30, lam, pen) + + if q1.ndim == 2 and q2.ndim == 2: + gam = zeros(q2.shape) + time = linspace(0, 1, q1.shape[0]) + for i in range(0, q2.shape[1]): + gam[:, i] = cr.rlbfgs(ascontiguousarray(q1[:, i]), + ascontiguousarray(q2[:, i]), + ascontiguousarray(time), 30, lam, pen) else: raise Exception("Invalid Optimization Method") @@ -265,7 +303,8 @@ def elastic_depth(f, time, method="DP2", lam=0.0, parallel=True): :param f: matrix of size MxN (M time points for N functions) :param time: vector of size M describing the sample points - :param method: method to apply optimization (default="DP2") options are "DP","DP2","RBFGS" + :param method: method to apply optimization (default="DP2") + options are "DP","DP2","RBFGS","cRBFGS" :param lam: controls the elasticity (default = 0.0) :rtype: scalar @@ -288,7 +327,8 @@ def elastic_depth(f, time, method="DP2", lam=0.0, parallel=True): phs_dist[i, :] = out[i][1] else: for i in range(0, fns): - amp_dist[i, :], phs_dist[i, :] = distmat(f, f[:, i], time, i, method) + amp_dist[i, :], phs_dist[i, :] = distmat(f, f[:, i], time, i, + method) amp_dist = amp_dist + amp_dist.T phs_dist = phs_dist + phs_dist.T @@ -309,7 +349,8 @@ def elastic_distance(f1, f2, time, method="DP2", lam=0.0, alpha=None): :param f1: vector of size N :param f2: vector of size N :param time: vector of size N describing the sample points - :param method: method to apply optimization (default="DP2") options are "DP","DP2","RBFGS" + :param method: method to apply optimization (default="DP2") + options are "DP","DP2","RBFGS","cRBFGS" :param lam: controls the elasticity (default = 0.0) :param alpha: makes alpha * dx + (1-alpha) * dy @@ -340,7 +381,7 @@ def elastic_distance(f1, f2, time, method="DP2", lam=0.0, alpha=None): Dx = real(arccos(q1dotq2)) if alpha is not None: - Dt = alpha * Dx + (1-alpha) * Dy + Dt = alpha * Dx + (1 - alpha) * Dy return Dy, Dx, Dt else: return Dy, Dx @@ -423,7 +464,8 @@ def SqrtMeanInverse(gam): def SqrtMean(gam, parallel=False, cores=-1): """ - calculates the srsf of warping functions with corresponding shooting vectors + calculates the srsf of warping functions with corresponding shooting + vectors :param gam: numpy ndarray of shape (M,N) of M warping functions with N samples @@ -432,8 +474,10 @@ def SqrtMean(gam, parallel=False, cores=-1): :rtype: 2 numpy ndarray and vector :return mu: Karcher mean psi function - :return gam_mu: vector of dim N which is the Karcher mean warping function - :return psi: numpy ndarray of shape (M,N) of M SRSF of the warping functions + :return gam_mu: vector of dim N which is the Karcher mean warping + function + :return psi: numpy ndarray of shape (M,N) of M SRSF of the warping + functions :return vec: numpy ndarray of shape (M,N) of M shooting vectors """ @@ -462,7 +506,6 @@ def SqrtMean(gam, parallel=False, cores=-1): min_ind = dqq.argmin() mu = psi[:, min_ind] maxiter = 501 - tt = 1 lvm = zeros(maxiter) vec = zeros((T, n)) stp = 0.3 @@ -512,15 +555,18 @@ def inv_exp_map_sub(mu, psi): def SqrtMedian(gam): """ - calculates the median srsf of warping functions with corresponding shooting vectors + calculates the median srsf of warping functions with corresponding + shooting vectors :param gam: numpy ndarray of shape (M,N) of M warping functions with N samples :rtype: 2 numpy ndarray and vector :return gam_median: Karcher median warping function - :return psi_meidan: vector of dim N which is the Karcher median srsf function - :return psi: numpy ndarray of shape (M,N) of M SRSF of the warping functions + :return psi_meidan: vector of dim N which is the Karcher median srsf + function + :return psi: numpy ndarray of shape (M,N) of M SRSF of the warping + functions :return vec: numpy ndarray of shape (M,N) of M shooting vectors """ @@ -594,7 +640,7 @@ def cumtrapzmid(x, y, c, mid): fa[0:mid] = tmp[::-1] # case >= mid - fa[mid:a] = c + cumtrapz(y[mid - 1 : a - 1], x[mid - 1 : a - 1], initial=0) + fa[mid:a] = c + cumtrapz(y[mid - 1: a - 1], x[mid - 1: a - 1], initial=0) return fa @@ -782,7 +828,6 @@ def geigen(Amat, Bmat, Cmat): p = Bmat.shape[0] q = Cmat.shape[0] - s = min(p, q) tmp = fabs(Bmat - Bmat.transpose()) tmp1 = fabs(Bmat) if tmp.max() / tmp1.max() > 1e-10: @@ -866,7 +911,6 @@ def warp_f_gamma(time, f, gam): :return f_temp: warped srsf """ - M = gam.size f_temp = interp((time[-1] - time[0]) * gam + time[0], time, f) return f_temp diff --git a/setup.py b/setup.py index 3103ffe5..76747dd3 100644 --- a/setup.py +++ b/setup.py @@ -11,18 +11,22 @@ from distutils.sysconfig import get_config_var, get_python_inc from distutils.version import LooseVersion -sys.path.insert(1, 'src/') +sys.path.insert(1, "src/") import dp_build # Make sure I have the right Python version. if sys.version_info[:2] < (3, 6): - print(("fdasrsf requires Python 3.6 or newer. Python %d.%d detected" % sys.version_info[:2])) + print( + ( + "fdasrsf requires Python 3.6 or newer. Python %d.%d detected" + % sys.version_info[:2] + ) + ) sys.exit(-1) class build_docs(Command): - """Builds the documentation - """ + """Builds the documentation""" description = "builds the documentation" user_options = [] @@ -43,78 +47,104 @@ def run(self): os.system("latexmk -pdf fdasrsf.tex") os.chdir("../../../") -if (sys.platform == 'darwin'): - mac_ver = str(LooseVersion(get_config_var('MACOSX_DEPLOYMENT_TARGET'))) - os.environ['MACOSX_DEPLOYMENT_TARGET'] = mac_ver + +if sys.platform == "darwin": + mac_ver = str(LooseVersion(get_config_var("MACOSX_DEPLOYMENT_TARGET"))) + os.environ["MACOSX_DEPLOYMENT_TARGET"] = mac_ver extensions = [ - Extension(name="optimum_reparamN2", - sources=["src/optimum_reparamN2.pyx", "src/DynamicProgrammingQ2.c", - "src/dp_grid.c", "src/dp_nbhd.c"], - include_dirs=[numpy.get_include()], - language="c" - ), - Extension(name="fpls_warp", - sources=["src/fpls_warp.pyx", "src/fpls_warp_grad.c", "src/misc_funcs.c"], - include_dirs=[numpy.get_include()], - language="c" - ), - Extension(name="mlogit_warp", - sources=["src/mlogit_warp.pyx", "src/mlogit_warp_grad.c", "src/misc_funcs.c"], - include_dirs=[numpy.get_include()], - language="c" - ), - Extension(name="ocmlogit_warp", - sources=["src/ocmlogit_warp.pyx", "src/ocmlogit_warp_grad.c", "src/misc_funcs.c"], - include_dirs=[numpy.get_include()], - language="c" - ), - Extension(name="oclogit_warp", + Extension( + name="optimum_reparamN2", + sources=[ + "src/optimum_reparamN2.pyx", + "src/DynamicProgrammingQ2.c", + "src/dp_grid.c", + "src/dp_nbhd.c", + ], + include_dirs=[numpy.get_include()], + language="c", + ), + Extension( + name="fpls_warp", + sources=["src/fpls_warp.pyx", "src/fpls_warp_grad.c", "src/misc_funcs.c"], + include_dirs=[numpy.get_include()], + language="c", + ), + Extension( + name="mlogit_warp", + sources=["src/mlogit_warp.pyx", "src/mlogit_warp_grad.c", "src/misc_funcs.c"], + include_dirs=[numpy.get_include()], + language="c", + ), + Extension( + name="ocmlogit_warp", + sources=[ + "src/ocmlogit_warp.pyx", + "src/ocmlogit_warp_grad.c", + "src/misc_funcs.c", + ], + include_dirs=[numpy.get_include()], + language="c", + ), + Extension( + name="oclogit_warp", sources=["src/oclogit_warp.pyx", "src/oclogit_warp_grad.c", "src/misc_funcs.c"], include_dirs=[numpy.get_include()], - language="c" + language="c", ), - Extension(name="optimum_reparam_N", + Extension( + name="optimum_reparam_N", sources=["src/optimum_reparam_N.pyx", "src/DP.c"], include_dirs=[numpy.get_include()], - language="c" + language="c", ), - Extension(name="cbayesian", + Extension( + name="cbayesian", sources=["src/cbayesian.pyx", "src/bayesian.cpp"], include_dirs=[numpy.get_include()], - language="c++" + language="c++", + extra_compile_args=["-std=c++11"], + ), + Extension( + name="crbfgs", + sources=["src/crbfgs.pyx", "src/rbfgs.cpp"], + include_dirs=[numpy.get_include()], + language="c++", + libraries=['blas', 'lapack'], + extra_compile_args=["-std=c++11"], ), - Extension(name="cimage", + Extension( + name="cimage", sources=["src/imagecpp.pyx", "src/UnitSquareImage.cpp"], include_dirs=[numpy.get_include()], - language="c++" + language="c++", ), dp_build.ffibuilder.distutils_extension(), ] setup( - cmdclass={'build_ext': build_ext, 'build_docs': build_docs}, - ext_modules=extensions, - name='fdasrsf', - version='2.5.2', - packages=['fdasrsf'], - url='http://research.tetonedge.net', - license='LICENSE.txt', - author='J. Derek Tucker', - author_email='jdtuck@sandia.gov', - scripts=['bin/ex_srsf_align.py'], - keywords=['functional data analysis'], - description='functional data analysis using the square root slope framework', - long_description=open('README.md', encoding="utf8").read(), - data_files=[('share/man/man1', ['doc/build/man/fdasrsf.1'])], + cmdclass={"build_ext": build_ext, "build_docs": build_docs}, + ext_modules=extensions, + name="fdasrsf", + version="2.5.2", + packages=["fdasrsf"], + url="http://research.tetonedge.net", + license="LICENSE.txt", + author="J. Derek Tucker", + author_email="jdtuck@sandia.gov", + scripts=["bin/ex_srsf_align.py"], + keywords=["functional data analysis"], + description="functional data analysis using the square root slope framework", + long_description=open("README.md", encoding="utf8").read(), + data_files=[("share/man/man1", ["doc/build/man/fdasrsf.1"])], classifiers=[ - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.8', - ] + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + ], ) diff --git a/src/armadillo b/src/armadillo index bcf9402f..59694d1c 100644 --- a/src/armadillo +++ b/src/armadillo @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -17,12 +19,19 @@ #ifndef ARMA_INCLUDES #define ARMA_INCLUDES +// NOTE: functions that are designed to be user accessible are described in the documentation (docs.html). +// NOTE: all other functions and classes (ie. not explicitly described in the documentation) +// NOTE: are considered as internal implementation details, and may be changed or removed without notice. + +#include "armadillo_bits/config.hpp" +#include "armadillo_bits/compiler_check.hpp" #include #include #include #include #include +#include #include #include @@ -37,62 +46,72 @@ #include #include #include +#include +#include +#include +#include +#include - -#if ( defined(__unix__) || defined(__unix) || defined(_POSIX_C_SOURCE) || (defined(__APPLE__) && defined(__MACH__)) ) && !defined(_WIN32) - #include -#endif - - -#if (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - #include +#if !defined(ARMA_DONT_USE_STD_MUTEX) + #include #endif +// #if defined(ARMA_HAVE_CXX17) +// #include +// #include +// #endif -#include "armadillo_bits/compiler_extra.hpp" -#include "armadillo_bits/config.hpp" -#include "armadillo_bits/compiler_setup.hpp" - - -#if defined(ARMA_USE_CXX11) - #include - #include - #include - #include - #include - #include - #include +#if ( defined(__unix__) || defined(__unix) || defined(_POSIX_C_SOURCE) || (defined(__APPLE__) && defined(__MACH__)) ) && !defined(_WIN32) + #include #endif - #if defined(ARMA_USE_TBB_ALLOC) - #include + #if defined(__has_include) + #if __has_include() + #include + #else + #undef ARMA_USE_TBB_ALLOC + #pragma message ("WARNING: use of TBB alloc disabled; tbb/scalable_allocator.h header not found") + #endif + #else + #include + #endif #endif - #if defined(ARMA_USE_MKL_ALLOC) - #include -#endif - - -#if !defined(ARMA_USE_CXX11) - #if defined(ARMA_HAVE_TR1) - #include - #include + #if defined(__has_include) + #if __has_include() + #include + #else + #undef ARMA_USE_MKL_ALLOC + #pragma message ("WARNING: use of MKL alloc disabled; mkl_service.h header not found") + #endif + #else + #include #endif #endif -#include "armadillo_bits/include_atlas.hpp" -#include "armadillo_bits/include_hdf5.hpp" -#include "armadillo_bits/include_superlu.hpp" +#include "armadillo_bits/compiler_setup.hpp" #if defined(ARMA_USE_OPENMP) - #include + #if defined(__has_include) + #if __has_include() + #include + #else + #undef ARMA_USE_OPENMP + #pragma message ("WARNING: use of OpenMP disabled; omp.h header not found") + #endif + #else + #include + #endif #endif +#include "armadillo_bits/include_hdf5.hpp" +#include "armadillo_bits/include_superlu.hpp" + //! \namespace arma namespace for Armadillo classes and functions namespace arma @@ -119,14 +138,14 @@ namespace arma #include "armadillo_bits/constants_old.hpp" #include "armadillo_bits/mp_misc.hpp" #include "armadillo_bits/arma_rel_comparators.hpp" + #include "armadillo_bits/fill.hpp" - #ifdef ARMA_RNG_ALT + #if defined(ARMA_RNG_ALT) #include ARMA_INCFILE_WRAP(ARMA_RNG_ALT) #else - #include "armadillo_bits/arma_rng_cxx98.hpp" + #include "armadillo_bits/arma_rng_cxx03.hpp" #endif - #include "armadillo_bits/arma_rng_cxx11.hpp" #include "armadillo_bits/arma_rng.hpp" @@ -138,17 +157,18 @@ namespace arma #include "armadillo_bits/SpBase_bones.hpp" #include "armadillo_bits/def_blas.hpp" - #include "armadillo_bits/def_lapack.hpp" #include "armadillo_bits/def_atlas.hpp" + #include "armadillo_bits/def_lapack.hpp" #include "armadillo_bits/def_arpack.hpp" #include "armadillo_bits/def_superlu.hpp" - #include "armadillo_bits/def_hdf5.hpp" + #include "armadillo_bits/def_fftw3.hpp" #include "armadillo_bits/translate_blas.hpp" - #include "armadillo_bits/translate_lapack.hpp" #include "armadillo_bits/translate_atlas.hpp" + #include "armadillo_bits/translate_lapack.hpp" #include "armadillo_bits/translate_arpack.hpp" #include "armadillo_bits/translate_superlu.hpp" + #include "armadillo_bits/translate_fftw3.hpp" #include "armadillo_bits/cond_rel_bones.hpp" #include "armadillo_bits/arrayops_bones.hpp" @@ -172,6 +192,7 @@ namespace arma #include "armadillo_bits/SpCol_bones.hpp" #include "armadillo_bits/SpRow_bones.hpp" #include "armadillo_bits/SpSubview_bones.hpp" + #include "armadillo_bits/SpSubview_col_list_bones.hpp" #include "armadillo_bits/spdiagview_bones.hpp" #include "armadillo_bits/MapMat_bones.hpp" @@ -188,7 +209,8 @@ namespace arma #include "armadillo_bits/subview_cube_each_bones.hpp" #include "armadillo_bits/subview_cube_slices_bones.hpp" - + #include "armadillo_bits/hdf5_name.hpp" + #include "armadillo_bits/csv_name.hpp" #include "armadillo_bits/diskio_bones.hpp" #include "armadillo_bits/wall_clock_bones.hpp" #include "armadillo_bits/running_stat_bones.hpp" @@ -212,6 +234,7 @@ namespace arma #include "armadillo_bits/mtGlue_bones.hpp" #include "armadillo_bits/SpGlue_bones.hpp" #include "armadillo_bits/mtSpGlue_bones.hpp" + #include "armadillo_bits/SpToDGlue_bones.hpp" #include "armadillo_bits/GlueCube_bones.hpp" #include "armadillo_bits/eGlueCube_bones.hpp" @@ -220,14 +243,16 @@ namespace arma #include "armadillo_bits/eop_core_bones.hpp" #include "armadillo_bits/eglue_core_bones.hpp" - #include "armadillo_bits/GenSpecialiser.hpp" #include "armadillo_bits/Gen_bones.hpp" #include "armadillo_bits/GenCube_bones.hpp" #include "armadillo_bits/op_diagmat_bones.hpp" #include "armadillo_bits/op_diagvec_bones.hpp" #include "armadillo_bits/op_dot_bones.hpp" - #include "armadillo_bits/op_inv_bones.hpp" + #include "armadillo_bits/op_det_bones.hpp" + #include "armadillo_bits/op_log_det_bones.hpp" + #include "armadillo_bits/op_inv_gen_bones.hpp" + #include "armadillo_bits/op_inv_spd_bones.hpp" #include "armadillo_bits/op_htrans_bones.hpp" #include "armadillo_bits/op_max_bones.hpp" #include "armadillo_bits/op_min_bones.hpp" @@ -279,6 +304,8 @@ namespace arma #include "armadillo_bits/op_nonzeros_bones.hpp" #include "armadillo_bits/op_diff_bones.hpp" #include "armadillo_bits/op_norm_bones.hpp" + #include "armadillo_bits/op_vecnorm_bones.hpp" + #include "armadillo_bits/op_norm2est_bones.hpp" #include "armadillo_bits/op_sqrtmat_bones.hpp" #include "armadillo_bits/op_logmat_bones.hpp" #include "armadillo_bits/op_range_bones.hpp" @@ -286,10 +313,16 @@ namespace arma #include "armadillo_bits/op_wishrnd_bones.hpp" #include "armadillo_bits/op_roots_bones.hpp" #include "armadillo_bits/op_cond_bones.hpp" + #include "armadillo_bits/op_rcond_bones.hpp" #include "armadillo_bits/op_sp_plus_bones.hpp" #include "armadillo_bits/op_sp_minus_bones.hpp" + #include "armadillo_bits/op_powmat_bones.hpp" + #include "armadillo_bits/op_rank_bones.hpp" + #include "armadillo_bits/op_row_as_mat_bones.hpp" + #include "armadillo_bits/op_col_as_mat_bones.hpp" #include "armadillo_bits/glue_times_bones.hpp" + #include "armadillo_bits/glue_times_misc_bones.hpp" #include "armadillo_bits/glue_mixed_bones.hpp" #include "armadillo_bits/glue_cov_bones.hpp" #include "armadillo_bits/glue_cor_bones.hpp" @@ -312,6 +345,8 @@ namespace arma #include "armadillo_bits/glue_intersect_bones.hpp" #include "armadillo_bits/glue_affmul_bones.hpp" #include "armadillo_bits/glue_mvnrnd_bones.hpp" + #include "armadillo_bits/glue_quantile_bones.hpp" + #include "armadillo_bits/glue_powext_bones.hpp" #include "armadillo_bits/gmm_misc_bones.hpp" #include "armadillo_bits/gmm_diag_bones.hpp" @@ -332,8 +367,9 @@ namespace arma #include "armadillo_bits/spop_reverse_bones.hpp" #include "armadillo_bits/spop_repmat_bones.hpp" #include "armadillo_bits/spop_vectorise_bones.hpp" + #include "armadillo_bits/spop_norm_bones.hpp" + #include "armadillo_bits/spop_vecnorm_bones.hpp" - #include "armadillo_bits/spglue_elem_helper_bones.hpp" #include "armadillo_bits/spglue_plus_bones.hpp" #include "armadillo_bits/spglue_minus_bones.hpp" #include "armadillo_bits/spglue_schur_bones.hpp" @@ -345,13 +381,17 @@ namespace arma #include "armadillo_bits/spglue_merge_bones.hpp" #include "armadillo_bits/spglue_relational_bones.hpp" + #include "armadillo_bits/spsolve_factoriser_bones.hpp" + #if defined(ARMA_USE_NEWARP) #include "armadillo_bits/newarp_EigsSelect.hpp" #include "armadillo_bits/newarp_DenseGenMatProd_bones.hpp" #include "armadillo_bits/newarp_SparseGenMatProd_bones.hpp" + #include "armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp" #include "armadillo_bits/newarp_DoubleShiftQR_bones.hpp" #include "armadillo_bits/newarp_GenEigsSolver_bones.hpp" #include "armadillo_bits/newarp_SymEigsSolver_bones.hpp" + #include "armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp" #include "armadillo_bits/newarp_TridiagEigen_bones.hpp" #include "armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp" #include "armadillo_bits/newarp_UpperHessenbergQR_bones.hpp" @@ -432,6 +472,7 @@ namespace arma #include "armadillo_bits/fn_diagmat.hpp" #include "armadillo_bits/fn_diagvec.hpp" #include "armadillo_bits/fn_inv.hpp" + #include "armadillo_bits/fn_inv_sympd.hpp" #include "armadillo_bits/fn_trace.hpp" #include "armadillo_bits/fn_trans.hpp" #include "armadillo_bits/fn_det.hpp" @@ -451,6 +492,7 @@ namespace arma #include "armadillo_bits/fn_elem.hpp" #include "armadillo_bits/fn_approx_equal.hpp" #include "armadillo_bits/fn_norm.hpp" + #include "armadillo_bits/fn_vecnorm.hpp" #include "armadillo_bits/fn_dot.hpp" #include "armadillo_bits/fn_randu.hpp" #include "armadillo_bits/fn_randn.hpp" @@ -491,10 +533,11 @@ namespace arma #include "armadillo_bits/fn_trunc_log.hpp" #include "armadillo_bits/fn_toeplitz.hpp" #include "armadillo_bits/fn_trimat.hpp" + #include "armadillo_bits/fn_trimat_ind.hpp" #include "armadillo_bits/fn_cumsum.hpp" #include "armadillo_bits/fn_cumprod.hpp" #include "armadillo_bits/fn_symmat.hpp" - #include "armadillo_bits/fn_syl_lyap.hpp" + #include "armadillo_bits/fn_sylvester.hpp" #include "armadillo_bits/fn_hist.hpp" #include "armadillo_bits/fn_histc.hpp" #include "armadillo_bits/fn_unique.hpp" @@ -508,7 +551,7 @@ namespace arma #include "armadillo_bits/fn_inplace_trans.hpp" #include "armadillo_bits/fn_randi.hpp" #include "armadillo_bits/fn_randg.hpp" - #include "armadillo_bits/fn_cond.hpp" + #include "armadillo_bits/fn_cond_rcond.hpp" #include "armadillo_bits/fn_normalise.hpp" #include "armadillo_bits/fn_clamp.hpp" #include "armadillo_bits/fn_expmat.hpp" @@ -528,12 +571,17 @@ namespace arma #include "armadillo_bits/fn_polyval.hpp" #include "armadillo_bits/fn_intersect.hpp" #include "armadillo_bits/fn_normpdf.hpp" + #include "armadillo_bits/fn_log_normpdf.hpp" #include "armadillo_bits/fn_normcdf.hpp" #include "armadillo_bits/fn_mvnrnd.hpp" #include "armadillo_bits/fn_chi2rnd.hpp" #include "armadillo_bits/fn_wishrnd.hpp" #include "armadillo_bits/fn_roots.hpp" #include "armadillo_bits/fn_randperm.hpp" + #include "armadillo_bits/fn_quantile.hpp" + #include "armadillo_bits/fn_powmat.hpp" + #include "armadillo_bits/fn_powext.hpp" + #include "armadillo_bits/fn_diags_spdiags.hpp" #include "armadillo_bits/fn_speye.hpp" #include "armadillo_bits/fn_spones.hpp" @@ -548,9 +596,10 @@ namespace arma // misc stuff #include "armadillo_bits/hdf5_misc.hpp" - #include "armadillo_bits/fft_engine.hpp" + #include "armadillo_bits/fft_engine_kissfft.hpp" + #include "armadillo_bits/fft_engine_fftw3.hpp" #include "armadillo_bits/band_helper.hpp" - #include "armadillo_bits/sympd_helper.hpp" + #include "armadillo_bits/sym_helper.hpp" #include "armadillo_bits/trimat_helper.hpp" // @@ -579,6 +628,7 @@ namespace arma #include "armadillo_bits/GlueCube_meat.hpp" #include "armadillo_bits/SpGlue_meat.hpp" #include "armadillo_bits/mtSpGlue_meat.hpp" + #include "armadillo_bits/SpToDGlue_meat.hpp" #include "armadillo_bits/eOp_meat.hpp" #include "armadillo_bits/eOpCube_meat.hpp" @@ -634,6 +684,7 @@ namespace arma #include "armadillo_bits/SpRow_meat.hpp" #include "armadillo_bits/SpSubview_meat.hpp" #include "armadillo_bits/SpSubview_iterators_meat.hpp" + #include "armadillo_bits/SpSubview_col_list_meat.hpp" #include "armadillo_bits/spdiagview_meat.hpp" #include "armadillo_bits/MapMat_meat.hpp" @@ -645,7 +696,10 @@ namespace arma #include "armadillo_bits/op_diagmat_meat.hpp" #include "armadillo_bits/op_diagvec_meat.hpp" #include "armadillo_bits/op_dot_meat.hpp" - #include "armadillo_bits/op_inv_meat.hpp" + #include "armadillo_bits/op_det_meat.hpp" + #include "armadillo_bits/op_log_det_meat.hpp" + #include "armadillo_bits/op_inv_gen_meat.hpp" + #include "armadillo_bits/op_inv_spd_meat.hpp" #include "armadillo_bits/op_htrans_meat.hpp" #include "armadillo_bits/op_max_meat.hpp" #include "armadillo_bits/op_index_max_meat.hpp" @@ -697,6 +751,8 @@ namespace arma #include "armadillo_bits/op_nonzeros_meat.hpp" #include "armadillo_bits/op_diff_meat.hpp" #include "armadillo_bits/op_norm_meat.hpp" + #include "armadillo_bits/op_vecnorm_meat.hpp" + #include "armadillo_bits/op_norm2est_meat.hpp" #include "armadillo_bits/op_sqrtmat_meat.hpp" #include "armadillo_bits/op_logmat_meat.hpp" #include "armadillo_bits/op_range_meat.hpp" @@ -704,10 +760,16 @@ namespace arma #include "armadillo_bits/op_wishrnd_meat.hpp" #include "armadillo_bits/op_roots_meat.hpp" #include "armadillo_bits/op_cond_meat.hpp" + #include "armadillo_bits/op_rcond_meat.hpp" #include "armadillo_bits/op_sp_plus_meat.hpp" #include "armadillo_bits/op_sp_minus_meat.hpp" + #include "armadillo_bits/op_powmat_meat.hpp" + #include "armadillo_bits/op_rank_meat.hpp" + #include "armadillo_bits/op_row_as_mat_meat.hpp" + #include "armadillo_bits/op_col_as_mat_meat.hpp" #include "armadillo_bits/glue_times_meat.hpp" + #include "armadillo_bits/glue_times_misc_meat.hpp" #include "armadillo_bits/glue_mixed_meat.hpp" #include "armadillo_bits/glue_cov_meat.hpp" #include "armadillo_bits/glue_cor_meat.hpp" @@ -730,6 +792,8 @@ namespace arma #include "armadillo_bits/glue_intersect_meat.hpp" #include "armadillo_bits/glue_affmul_meat.hpp" #include "armadillo_bits/glue_mvnrnd_meat.hpp" + #include "armadillo_bits/glue_quantile_meat.hpp" + #include "armadillo_bits/glue_powext_meat.hpp" #include "armadillo_bits/gmm_misc_meat.hpp" #include "armadillo_bits/gmm_diag_meat.hpp" @@ -750,8 +814,9 @@ namespace arma #include "armadillo_bits/spop_reverse_meat.hpp" #include "armadillo_bits/spop_repmat_meat.hpp" #include "armadillo_bits/spop_vectorise_meat.hpp" + #include "armadillo_bits/spop_norm_meat.hpp" + #include "armadillo_bits/spop_vecnorm_meat.hpp" - #include "armadillo_bits/spglue_elem_helper_meat.hpp" #include "armadillo_bits/spglue_plus_meat.hpp" #include "armadillo_bits/spglue_minus_meat.hpp" #include "armadillo_bits/spglue_schur_meat.hpp" @@ -763,14 +828,18 @@ namespace arma #include "armadillo_bits/spglue_merge_meat.hpp" #include "armadillo_bits/spglue_relational_meat.hpp" + #include "armadillo_bits/spsolve_factoriser_meat.hpp" + #if defined(ARMA_USE_NEWARP) #include "armadillo_bits/newarp_cx_attrib.hpp" #include "armadillo_bits/newarp_SortEigenvalue.hpp" #include "armadillo_bits/newarp_DenseGenMatProd_meat.hpp" #include "armadillo_bits/newarp_SparseGenMatProd_meat.hpp" + #include "armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp" #include "armadillo_bits/newarp_DoubleShiftQR_meat.hpp" #include "armadillo_bits/newarp_GenEigsSolver_meat.hpp" #include "armadillo_bits/newarp_SymEigsSolver_meat.hpp" + #include "armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp" #include "armadillo_bits/newarp_TridiagEigen_meat.hpp" #include "armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp" #include "armadillo_bits/newarp_UpperHessenbergQR_meat.hpp" diff --git a/src/armadillo_bits/BaseCube_bones.hpp b/src/armadillo_bits/BaseCube_bones.hpp index d5452d1d..15d6a4c8 100644 --- a/src/armadillo_bits/BaseCube_bones.hpp +++ b/src/armadillo_bits/BaseCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,14 +24,14 @@ template struct BaseCube_eval_Cube { - arma_inline const derived& eval() const; + arma_warn_unused arma_inline const derived& eval() const; }; template struct BaseCube_eval_expr { - arma_inline Cube eval() const; //!< force the immediate evaluation of a delayed expression + arma_warn_unused inline Cube eval() const; //!< force the immediate evaluation of a delayed expression }; @@ -37,7 +39,7 @@ template struct BaseCube_eval {}; template -struct BaseCube_eval { typedef BaseCube_eval_Cube result; }; +struct BaseCube_eval { typedef BaseCube_eval_Cube result; }; template struct BaseCube_eval { typedef BaseCube_eval_expr result; }; @@ -57,16 +59,26 @@ struct BaseCube arma_cold inline void raw_print( const std::string extra_text = "") const; arma_cold inline void raw_print(std::ostream& user_stream, const std::string extra_text = "") const; - inline arma_warn_unused elem_type min() const; - inline arma_warn_unused elem_type max() const; + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; + + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; + + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_finite() const; - inline arma_warn_unused uword index_min() const; - inline arma_warn_unused uword index_max() const; + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; - inline arma_warn_unused bool is_empty() const; - inline arma_warn_unused bool is_finite() const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline const CubeToMatOp row_as_mat(const uword in_row) const; + arma_warn_unused inline const CubeToMatOp col_as_mat(const uword in_col) const; }; diff --git a/src/armadillo_bits/BaseCube_meat.hpp b/src/armadillo_bits/BaseCube_meat.hpp index 8590f000..2d0df91e 100644 --- a/src/armadillo_bits/BaseCube_meat.hpp +++ b/src/armadillo_bits/BaseCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,60 +32,145 @@ BaseCube::get_ref() const template -arma_cold inline void BaseCube::print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_cube tmp( (*this).get_ref() ); - tmp.M.impl_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); } template -arma_cold inline void BaseCube::print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_cube tmp( (*this).get_ref() ); - tmp.M.impl_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); } template -arma_cold inline void BaseCube::raw_print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_cube tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); } template -arma_cold inline void BaseCube::raw_print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_cube tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); + } + + + +template +inline +void +BaseCube::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); + } + + + +template +inline +void +BaseCube::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_cube tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); } template inline -arma_warn_unused elem_type BaseCube::min() const { @@ -94,7 +181,6 @@ BaseCube::min() const template inline -arma_warn_unused elem_type BaseCube::max() const { @@ -105,7 +191,6 @@ BaseCube::max() const template inline -arma_warn_unused uword BaseCube::index_min() const { @@ -129,7 +214,6 @@ BaseCube::index_min() const template inline -arma_warn_unused uword BaseCube::index_max() const { @@ -153,7 +237,58 @@ BaseCube::index_max() const template inline -arma_warn_unused +bool +BaseCube::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + if(ProxyCube::use_at || is_Cube::stored_type>::value) + { + const unwrap_cube U( (*this).get_ref() ); + + return arrayops::is_zero( U.M.memptr(), U.M.n_elem, tol ); + } + + const ProxyCube P( (*this).get_ref() ); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { return false; } + + const typename ProxyCube::ea_type Pea = P.get_ea(); + + if(is_cx::yes) + { + for(uword i=0; i tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + } + } + else // not complex + { + for(uword i=0; i < n_elem; ++i) + { + if(eop_aux::arma_abs(Pea[i]) > tol) { return false; } + } + } + + return true; + } + + + +template +inline bool BaseCube::is_empty() const { @@ -168,30 +303,33 @@ BaseCube::is_empty() const template inline -arma_warn_unused bool BaseCube::is_finite() const { arma_extra_debug_sigprint(); - const ProxyCube P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } if(is_Cube::stored_type>::value) { - const unwrap_cube::stored_type> U(P.Q); + const unwrap_cube U( (*this).get_ref() ); return arrayops::is_finite( U.M.memptr(), U.M.n_elem ); } - - const uword n_r = P.get_n_rows(); - const uword n_c = P.get_n_cols(); - const uword n_s = P.get_n_slices(); - - for(uword s=0; s P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s::is_finite() const template inline -arma_warn_unused bool BaseCube::has_inf() const { arma_extra_debug_sigprint(); - const ProxyCube P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } if(is_Cube::stored_type>::value) { - const unwrap_cube::stored_type> U(P.Q); + const unwrap_cube U( (*this).get_ref() ); return arrayops::has_inf( U.M.memptr(), U.M.n_elem ); } - - const uword n_r = P.get_n_rows(); - const uword n_c = P.get_n_cols(); - const uword n_s = P.get_n_slices(); - - for(uword s=0; s P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s::has_inf() const template inline -arma_warn_unused bool BaseCube::has_nan() const { arma_extra_debug_sigprint(); - const ProxyCube P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } if(is_Cube::stored_type>::value) { - const unwrap_cube::stored_type> U(P.Q); + const unwrap_cube U( (*this).get_ref() ); return arrayops::has_nan( U.M.memptr(), U.M.n_elem ); } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s +inline +bool +BaseCube::has_nonfinite() const + { + arma_extra_debug_sigprint(); - const uword n_r = P.get_n_rows(); - const uword n_c = P.get_n_cols(); - const uword n_s = P.get_n_slices(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } - for(uword s=0; s::stored_type>::value) { - if(arma_isnan(P.at(r,c,s))) { return true; } + const unwrap_cube U( (*this).get_ref() ); + + return (arrayops::is_finite( U.M.memptr(), U.M.n_elem ) == false); + } + else + { + const ProxyCube P( (*this).get_ref() ); + + const uword n_r = P.get_n_rows(); + const uword n_c = P.get_n_cols(); + const uword n_s = P.get_n_slices(); + + for(uword s=0; s::has_nan() const +template +inline +const CubeToMatOp +BaseCube::row_as_mat(const uword in_row) const + { + return CubeToMatOp( (*this).get_ref(), in_row ); + } + + + +template +inline +const CubeToMatOp +BaseCube::col_as_mat(const uword in_col) const + { + return CubeToMatOp( (*this).get_ref(), in_col ); + } + + + // // extra functions defined in BaseCube_eval_Cube @@ -284,7 +484,7 @@ BaseCube_eval_Cube::eval() const // extra functions defined in BaseCube_eval_expr template -arma_inline +inline Cube BaseCube_eval_expr::eval() const { diff --git a/src/armadillo_bits/Base_bones.hpp b/src/armadillo_bits/Base_bones.hpp index 66f4fcf5..ac947856 100644 --- a/src/armadillo_bits/Base_bones.hpp +++ b/src/armadillo_bits/Base_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,13 +24,10 @@ template struct Base_extra_yes { - arma_inline const Op i() const; //!< matrix inverse - - arma_deprecated inline const Op i(const bool ) const; //!< kept only for compatibility with old user code - arma_deprecated inline const Op i(const char*) const; //!< kept only for compatibility with old user code + arma_warn_unused inline const Op i() const; //!< matrix inverse - inline arma_warn_unused bool is_sympd() const; - inline arma_warn_unused bool is_sympd(typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_sympd() const; + arma_warn_unused inline bool is_sympd(typename get_pod_type::result tol) const; }; @@ -52,14 +51,14 @@ struct Base_extra { typedef Base_extra_no struct Base_eval_Mat { - arma_inline const derived& eval() const; + arma_warn_unused arma_inline const derived& eval() const; }; template struct Base_eval_expr { - arma_inline Mat eval() const; //!< force the immediate evaluation of a delayed expression + arma_warn_unused inline Mat eval() const; //!< force the immediate evaluation of a delayed expression }; @@ -77,18 +76,18 @@ struct Base_eval { typedef Base_eval_expr struct Base_trans_cx { - arma_inline const Op t() const; - arma_inline const Op ht() const; - arma_inline const Op st() const; // simple transpose: no complex conjugates + arma_warn_unused arma_inline const Op t() const; + arma_warn_unused arma_inline const Op ht() const; + arma_warn_unused arma_inline const Op st() const; // simple transpose: no complex conjugates }; template struct Base_trans_default { - arma_inline const Op t() const; - arma_inline const Op ht() const; - arma_inline const Op st() const; // return op_htrans instead of op_strans, as it's handled better by matrix multiplication code + arma_warn_unused arma_inline const Op t() const; + arma_warn_unused arma_inline const Op ht() const; + arma_warn_unused arma_inline const Op st() const; // return op_htrans instead of op_strans, as it's handled better by matrix multiplication code }; @@ -105,7 +104,7 @@ struct Base_trans { typedef Base_trans_default result; //! Class for static polymorphism, modelled after the "Curiously Recurring Template Pattern" (CRTP). //! Used for type-safe downcasting in functions that restrict their input(s) to be classes that are -//! derived from Base (e.g. Mat, Op, Glue, diagview, subview). +//! derived from Base (eg. Mat, Op, Glue, diagview, subview). //! A Base object can be converted to a Mat object by the unwrap class. template @@ -122,8 +121,11 @@ struct Base arma_cold inline void raw_print( const std::string extra_text = "") const; arma_cold inline void raw_print(std::ostream& user_stream, const std::string extra_text = "") const; - inline arma_warn_unused elem_type min() const; - inline arma_warn_unused elem_type max() const; + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; inline elem_type min(uword& index_of_min_val) const; inline elem_type max(uword& index_of_max_val) const; @@ -131,29 +133,33 @@ struct Base inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; - inline arma_warn_unused uword index_min() const; - inline arma_warn_unused uword index_max() const; + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; + + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; - inline arma_warn_unused bool is_symmetric() const; - inline arma_warn_unused bool is_symmetric(const typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; - inline arma_warn_unused bool is_hermitian() const; - inline arma_warn_unused bool is_hermitian(const typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_trimatu() const; + arma_warn_unused inline bool is_trimatl() const; + arma_warn_unused inline bool is_diagmat() const; + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_square() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_finite() const; - inline arma_warn_unused bool is_trimatu() const; - inline arma_warn_unused bool is_trimatl() const; - inline arma_warn_unused bool is_diagmat() const; - inline arma_warn_unused bool is_empty() const; - inline arma_warn_unused bool is_square() const; - inline arma_warn_unused bool is_vec() const; - inline arma_warn_unused bool is_colvec() const; - inline arma_warn_unused bool is_rowvec() const; - inline arma_warn_unused bool is_finite() const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; - arma_inline const Op as_col() const; - arma_inline const Op as_row() const; + arma_warn_unused inline const Op as_col() const; + arma_warn_unused inline const Op as_row() const; }; diff --git a/src/armadillo_bits/Base_meat.hpp b/src/armadillo_bits/Base_meat.hpp index d2eb000b..646f33aa 100644 --- a/src/armadillo_bits/Base_meat.hpp +++ b/src/armadillo_bits/Base_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,60 +32,145 @@ Base::get_ref() const template -arma_cold inline void Base::print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const quasi_unwrap tmp( (*this).get_ref() ); - tmp.M.impl_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); } template -arma_cold inline void Base::print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const quasi_unwrap tmp( (*this).get_ref() ); - tmp.M.impl_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); } template -arma_cold inline void Base::raw_print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const quasi_unwrap tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); } template -arma_cold inline void Base::raw_print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); + } + + + +template +inline +void +Base::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + const quasi_unwrap tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); + } + + + +template +inline +void +Base::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const quasi_unwrap tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); } template inline -arma_warn_unused elem_type Base::min() const { @@ -94,7 +181,6 @@ Base::min() const template inline -arma_warn_unused elem_type Base::max() const { @@ -171,7 +257,6 @@ Base::max(uword& row_of_max_val, uword& col_of_max_val) const template inline -arma_warn_unused uword Base::index_min() const { @@ -195,7 +280,6 @@ Base::index_min() const template inline -arma_warn_unused uword Base::index_max() const { @@ -219,7 +303,6 @@ Base::index_max() const template inline -arma_warn_unused bool Base::is_symmetric() const { @@ -260,7 +343,6 @@ Base::is_symmetric() const template inline -arma_warn_unused bool Base::is_symmetric(const typename get_pod_type::result tol) const { @@ -292,7 +374,6 @@ Base::is_symmetric(const typename get_pod_type::re template inline -arma_warn_unused bool Base::is_hermitian() const { @@ -345,7 +426,6 @@ Base::is_hermitian() const template inline -arma_warn_unused bool Base::is_hermitian(const typename get_pod_type::result tol) const { @@ -377,7 +457,58 @@ Base::is_hermitian(const typename get_pod_type::re template inline -arma_warn_unused +bool +Base::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + if(Proxy::use_at || is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return arrayops::is_zero( U.M.memptr(), U.M.n_elem, tol ); + } + + const Proxy P( (*this).get_ref() ); + + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { return false; } + + const typename Proxy::ea_type Pea = P.get_ea(); + + if(is_cx::yes) + { + for(uword i=0; i tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + } + } + else // not complex + { + for(uword i=0; i tol) { return false; } + } + } + + return true; + } + + + +template +inline bool Base::is_trimatu() const { @@ -396,7 +527,6 @@ Base::is_trimatu() const template inline -arma_warn_unused bool Base::is_trimatl() const { @@ -415,7 +545,6 @@ Base::is_trimatl() const template inline -arma_warn_unused bool Base::is_diagmat() const { @@ -427,19 +556,25 @@ Base::is_diagmat() const if(A.n_elem <= 1) { return true; } + // NOTE: we're NOT assuming the matrix has a square size + const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; - const elem_type* A_colmem = A.memptr(); + const elem_type* A_mem = A.memptr(); + + if(A_mem[1] != elem_type(0)) { return false; } + + // if we got to this point, do a thorough check for(uword A_col=0; A_col < A_n_cols; ++A_col) { for(uword A_row=0; A_row < A_n_rows; ++A_row) { - if( (A_colmem[A_row] != elem_type(0)) && (A_row != A_col) ) { return false; } + if( (A_mem[A_row] != elem_type(0)) && (A_row != A_col) ) { return false; } } - A_colmem += A_n_rows; + A_mem += A_n_rows; } return true; @@ -449,7 +584,6 @@ Base::is_diagmat() const template inline -arma_warn_unused bool Base::is_empty() const { @@ -464,7 +598,6 @@ Base::is_empty() const template inline -arma_warn_unused bool Base::is_square() const { @@ -479,7 +612,6 @@ Base::is_square() const template inline -arma_warn_unused bool Base::is_vec() const { @@ -496,7 +628,6 @@ Base::is_vec() const template inline -arma_warn_unused bool Base::is_colvec() const { @@ -513,7 +644,6 @@ Base::is_colvec() const template inline -arma_warn_unused bool Base::is_rowvec() const { @@ -530,41 +660,44 @@ Base::is_rowvec() const template inline -arma_warn_unused bool Base::is_finite() const { arma_extra_debug_sigprint(); - const Proxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } if(is_Mat::stored_type>::value) { - const quasi_unwrap::stored_type> U(P.Q); + const quasi_unwrap U( (*this).get_ref() ); return arrayops::is_finite( U.M.memptr(), U.M.n_elem ); } - - if(Proxy::use_at == false) + else { - const typename Proxy::ea_type Pea = P.get_ea(); - - const uword n_elem = P.get_n_elem(); + const Proxy P( (*this).get_ref() ); - for(uword i=0; i::use_at == false) { - if(arma_isfinite(Pea[i]) == false) { return false; } + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i::is_finite() const template inline -arma_warn_unused bool Base::has_inf() const { arma_extra_debug_sigprint(); - const Proxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } if(is_Mat::stored_type>::value) { - const quasi_unwrap::stored_type> U(P.Q); + const quasi_unwrap U( (*this).get_ref() ); return arrayops::has_inf( U.M.memptr(), U.M.n_elem ); } - - if(Proxy::use_at == false) + else { - const typename Proxy::ea_type Pea = P.get_ea(); + const Proxy P( (*this).get_ref() ); - const uword n_elem = P.get_n_elem(); - - for(uword i=0; i::use_at == false) { - if(arma_isinf(Pea[i])) { return true; } + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i::has_inf() const template inline -arma_warn_unused bool Base::has_nan() const { arma_extra_debug_sigprint(); - const Proxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } if(is_Mat::stored_type>::value) { - const quasi_unwrap::stored_type> U(P.Q); + const quasi_unwrap U( (*this).get_ref() ); return arrayops::has_nan( U.M.memptr(), U.M.n_elem ); } - - if(Proxy::use_at == false) + else { - const typename Proxy::ea_type Pea = P.get_ea(); + const Proxy P( (*this).get_ref() ); - const uword n_elem = P.get_n_elem(); - - for(uword i=0; i::use_at == false) { - if(arma_isnan(Pea[i])) { return true; } + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i +inline +bool +Base::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap U( (*this).get_ref() ); + + return (arrayops::is_finite( U.M.memptr(), U.M.n_elem ) == false); + } else { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + const Proxy P( (*this).get_ref() ); - for(uword col=0; col::use_at == false) { - if(arma_isnan(P.at(row,col))) { return true; } + const typename Proxy::ea_type Pea = P.get_ea(); + + const uword n_elem = P.get_n_elem(); + + for(uword i=0; i::has_nan() const template -arma_inline +inline const Op Base::as_col() const { @@ -674,7 +861,7 @@ Base::as_col() const template -arma_inline +inline const Op Base::as_row() const { @@ -687,44 +874,17 @@ Base::as_row() const // extra functions defined in Base_extra_yes template -arma_inline -const Op -Base_extra_yes::i() const - { - return Op(static_cast(*this)); - } - - - -template -arma_deprecated -inline -const Op -Base_extra_yes::i(const bool) const // argument kept only for compatibility with old user code - { - // arma_debug_warn(".i(bool) is deprecated and will be removed; change to .i()"); - - return Op(static_cast(*this)); - } - - - -template -arma_deprecated inline -const Op -Base_extra_yes::i(const char*) const // argument kept only for compatibility with old user code +const Op +Base_extra_yes::i() const { - // arma_debug_warn(".i(char*) is deprecated and will be removed; change to .i()"); - - return Op(static_cast(*this)); + return Op(static_cast(*this)); } template inline -arma_warn_unused bool Base_extra_yes::is_sympd() const { @@ -750,7 +910,6 @@ Base_extra_yes::is_sympd() const template inline -arma_warn_unused bool Base_extra_yes::is_sympd(typename get_pod_type::result tol) const { @@ -792,7 +951,7 @@ Base_eval_Mat::eval() const // extra functions defined in Base_eval_expr template -arma_inline +inline Mat Base_eval_expr::eval() const { diff --git a/src/armadillo_bits/Col_bones.hpp b/src/armadillo_bits/Col_bones.hpp index 1bc526a2..b3f0ab68 100644 --- a/src/armadillo_bits/Col_bones.hpp +++ b/src/armadillo_bits/Col_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,20 +29,29 @@ class Col : public Mat typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_col = true; - static const bool is_row = false; - static const bool is_xvec = false; + static constexpr bool is_col = true; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; + + inline Col(); + inline Col(const Col& X); - inline Col(); - inline Col(const Col& X); inline explicit Col(const uword n_elem); inline explicit Col(const uword in_rows, const uword in_cols); inline explicit Col(const SizeMat& s); + template inline explicit Col(const uword n_elem, const arma_initmode_indicator&); + template inline explicit Col(const uword in_rows, const uword in_cols, const arma_initmode_indicator&); + template inline explicit Col(const SizeMat& s, const arma_initmode_indicator&); + template inline Col(const uword n_elem, const fill::fill_class& f); template inline Col(const uword in_rows, const uword in_cols, const fill::fill_class& f); template inline Col(const SizeMat& s, const fill::fill_class& f); + inline Col(const uword N, const fill::scalar_holder f); + inline Col(const uword in_rows, const uword in_cols, const fill::scalar_holder f); + inline Col(const SizeMat& s, const fill::scalar_holder f); + inline Col(const char* text); inline Col& operator=(const char* text); @@ -50,13 +61,14 @@ class Col : public Mat inline Col(const std::vector& x); inline Col& operator=(const std::vector& x); - #if defined(ARMA_USE_CXX11) inline Col(const std::initializer_list& list); inline Col& operator=(const std::initializer_list& list); inline Col(Col&& m); inline Col& operator=(Col&& m); - #endif + + // inline Col(Mat&& m); + // inline Col& operator=(Mat&& m); inline Col& operator=(const eT val); inline Col& operator=(const Col& m); @@ -79,13 +91,13 @@ class Col : public Mat inline Col(const subview_cube& X); inline Col& operator=(const subview_cube& X); - inline mat_injector operator<<(const eT val); + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); - arma_inline const Op,op_htrans> t() const; - arma_inline const Op,op_htrans> ht() const; - arma_inline const Op,op_strans> st() const; + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; - arma_inline const Op,op_strans> as_row() const; + arma_warn_unused arma_inline const Op,op_strans> as_row() const; arma_inline subview_col row(const uword row_num); arma_inline const subview_col row(const uword row_num) const; @@ -129,15 +141,17 @@ class Col : public Mat template inline void shed_rows(const Base& indices); - inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero = true); + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + inline void insert_rows(const uword row_num, const uword N); + template inline void insert_rows(const uword row_num, const Base& X); - arma_inline arma_warn_unused eT& at(const uword i); - arma_inline arma_warn_unused const eT& at(const uword i) const; + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; - arma_inline arma_warn_unused eT& at(const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at(const uword in_row, const uword in_col) const; typedef eT* row_iterator; @@ -160,7 +174,7 @@ class Col : public Mat public: - #ifdef ARMA_EXTRA_COL_PROTO + #if defined(ARMA_EXTRA_COL_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_COL_PROTO) #endif }; @@ -173,7 +187,7 @@ class Col::fixed : public Col { private: - static const bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; @@ -185,9 +199,9 @@ class Col::fixed : public Col typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_col = true; - static const bool is_row = false; - static const bool is_xvec = false; + static constexpr bool is_col = true; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; static const uword n_rows; // value provided below the class definition static const uword n_cols; // value provided below the class definition @@ -197,6 +211,7 @@ class Col::fixed : public Col arma_inline fixed(const fixed& X); inline fixed(const subview_cube& X); + inline fixed(const fill::scalar_holder f); template inline fixed(const fill::fill_class& f); template inline fixed(const Base& A); template inline fixed(const Base& A, const Base& B); @@ -215,10 +230,8 @@ class Col::fixed : public Col using Col::operator(); - #if defined(ARMA_USE_CXX11) - inline fixed(const std::initializer_list& list); - inline Col& operator=(const std::initializer_list& list); - #endif + inline fixed(const std::initializer_list& list); + inline Col& operator=(const std::initializer_list& list); arma_inline Col& operator=(const fixed& X); @@ -227,30 +240,30 @@ class Col::fixed : public Col template inline Col& operator=(const eGlue& X); #endif - arma_inline const Op< Col_fixed_type, op_htrans > t() const; - arma_inline const Op< Col_fixed_type, op_htrans > ht() const; - arma_inline const Op< Col_fixed_type, op_strans > st() const; + arma_warn_unused arma_inline const Op< Col_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Col_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Col_fixed_type, op_strans > st() const; - arma_inline arma_warn_unused const eT& at_alt (const uword i) const; + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; - arma_inline arma_warn_unused eT& operator[] (const uword i); - arma_inline arma_warn_unused const eT& operator[] (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword i); - arma_inline arma_warn_unused const eT& at (const uword i) const; - arma_inline arma_warn_unused eT& operator() (const uword i); - arma_inline arma_warn_unused const eT& operator() (const uword i) const; + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT* memptr(); - arma_inline arma_warn_unused const eT* memptr() const; + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; - arma_hot inline const Col& fill(const eT val); - arma_hot inline const Col& zeros(); - arma_hot inline const Col& ones(); + inline const Col& fill(const eT val); + inline const Col& zeros(); + inline const Col& ones(); }; diff --git a/src/armadillo_bits/Col_meat.hpp b/src/armadillo_bits/Col_meat.hpp index 48920a6f..a6a1945f 100644 --- a/src/armadillo_bits/Col_meat.hpp +++ b/src/armadillo_bits/Col_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -48,6 +50,12 @@ Col::Col(const uword in_n_elem) : Mat(arma_vec_indicator(), in_n_elem, 1, 1) { arma_extra_debug_sigprint(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -60,6 +68,12 @@ Col::Col(const uword in_n_rows, const uword in_n_cols) arma_extra_debug_sigprint(); Mat::init_warm(in_n_rows, in_n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -72,6 +86,70 @@ Col::Col(const SizeMat& s) arma_extra_debug_sigprint(); Mat::init_warm(s.n_rows, s.n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const uword in_n_elem, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Col::Col(const SizeMat& s, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Col::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -119,6 +197,46 @@ Col::Col(const SizeMat& s, const fill::fill_class& f) +template +inline +Col::Col(const uword in_n_elem, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), in_n_elem, 1, 1) + { + arma_extra_debug_sigprint(); + + (*this).fill(f.scalar); + } + + + +template +inline +Col::Col(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Col::Col(const SizeMat& s, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 1) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f.scalar); + } + + + template inline Col::Col(const char* text) @@ -193,10 +311,9 @@ Col::Col(const std::vector& x) { arma_extra_debug_sigprint_this(this); - if(x.size() > 0) - { - arrayops::copy( Mat::memptr(), &(x[0]), uword(x.size()) ); - } + const uword N = uword(x.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } } @@ -209,114 +326,160 @@ Col::operator=(const std::vector& x) { arma_extra_debug_sigprint(); - Mat::init_warm(uword(x.size()), 1); + const uword N = uword(x.size()); - if(x.size() > 0) - { - arrayops::copy( Mat::memptr(), &(x[0]), uword(x.size()) ); - } + Mat::init_warm(N, 1); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } return *this; } -#if defined(ARMA_USE_CXX11) +template +inline +Col::Col(const std::initializer_list& list) + : Mat(arma_vec_indicator(), uword(list.size()), 1, 1) + { + arma_extra_debug_sigprint_this(this); - template - inline - Col::Col(const std::initializer_list& list) - : Mat(arma_vec_indicator(), 1) - { - arma_extra_debug_sigprint(); - - (*this).operator=(list); - } + const uword N = uword(list.size()); + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + } + + + +template +inline +Col& +Col::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + const uword N = uword(list.size()); - template - inline - Col& - Col::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - Mat tmp(list); - - arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with column vector layout" ); - - access::rw(tmp.n_rows) = tmp.n_elem; - access::rw(tmp.n_cols) = 1; - - (*this).steal_mem(tmp); - - return *this; - } + Mat::init_warm(N, 1); + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + return *this; + } + + + +template +inline +Col::Col(Col&& X) + : Mat(arma_vec_indicator(), 1) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - template - inline - Col::Col(Col&& X) - : Mat(arma_vec_indicator(), 1) + access::rw(Mat::n_rows) = X.n_rows; + access::rw(Mat::n_cols) = 1; + access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::n_alloc) = X.n_alloc; + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - - access::rw(Mat::n_rows) = X.n_rows; - access::rw(Mat::n_cols) = 1; - access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::mem_state) = X.mem_state; + access::rw(Mat::mem) = X.mem; - if( ((X.mem_state == 0) && (X.n_elem > arma_config::mat_prealloc)) || (X.mem_state == 1) || (X.mem_state == 2) ) - { - access::rw(Mat::mem_state) = X.mem_state; - access::rw(Mat::mem) = X.mem; - - access::rw(X.n_rows) = 0; - access::rw(X.n_cols) = 1; - access::rw(X.n_elem) = 0; - access::rw(X.mem_state) = 0; - access::rw(X.mem) = 0; - } - else - { - (*this).init_cold(); - - arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); - - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) ) - { - access::rw(X.n_rows) = 0; - access::rw(X.n_cols) = 1; - access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; - } - } + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 1; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; } - - - - template - inline - Col& - Col::operator=(Col&& X) + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + (*this).init_cold(); - (*this).steal_mem(X); + arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) && (this != &X) ) + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) { - access::rw(X.n_rows) = 0; - access::rw(X.n_cols) = 1; - access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 1; + access::rw(X.n_elem) = 0; + access::rw(X.mem) = nullptr; } - - return *this; } + } + + + +template +inline +Col& +Col::operator=(Col&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); -#endif + (*this).steal_mem(X, true); + + return *this; + } + + + +// template +// inline +// Col::Col(Mat&& X) +// : Mat(arma_vec_indicator(), 1) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_cols != 1) { const Mat& XX = X; Mat::operator=(XX); return; } +// +// access::rw(Mat::n_rows) = X.n_rows; +// access::rw(Mat::n_cols) = 1; +// access::rw(Mat::n_elem) = X.n_elem; +// access::rw(Mat::n_alloc) = X.n_alloc; +// +// if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) +// { +// access::rw(Mat::mem_state) = X.mem_state; +// access::rw(Mat::mem) = X.mem; +// +// access::rw(X.n_rows) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.n_alloc) = 0; +// access::rw(X.mem_state) = 0; +// access::rw(X.mem) = nullptr; +// } +// else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) +// { +// (*this).init_cold(); +// +// arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); +// +// if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) +// { +// access::rw(X.n_rows) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.mem) = nullptr; +// } +// } +// } +// +// +// +// template +// inline +// Col& +// Col::operator=(Mat&& X) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_cols != 1) { const Mat& XX = X; Mat::operator=(XX); return *this; } +// +// (*this).steal_mem(X, true); +// +// return *this; +// } @@ -561,7 +724,7 @@ Col::row(const uword in_row1) { arma_extra_debug_sigprint(); - arma_debug_check( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used" ); return subview_col(*this, 0, in_row1, 1); } @@ -575,7 +738,7 @@ Col::row(const uword in_row1) const { arma_extra_debug_sigprint(); - arma_debug_check( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( (in_row1 >= Mat::n_rows), "Col::row(): indices out of bounds or incorrectly used" ); return subview_col(*this, 0, in_row1, 1); } @@ -589,7 +752,7 @@ Col::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -605,7 +768,7 @@ Col::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::rows(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -621,7 +784,7 @@ Col::subvec(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -637,7 +800,7 @@ Col::subvec(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows) ), "Col::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -685,7 +848,7 @@ Col::subvec(const span& row_span) const uword in_row2 = row_span.b; const uword subvec_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used" ); return subview_col(*this, 0, in_row1, subvec_n_rows); } @@ -707,7 +870,7 @@ Col::subvec(const span& row_span) const const uword in_row2 = row_span.b; const uword subvec_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ), "Col::subvec(): indices out of bounds or incorrectly used" ); return subview_col(*this, 0, in_row1, subvec_n_rows); } @@ -747,7 +910,7 @@ Col::subvec(const uword start_row, const SizeMat& s) arma_debug_check( (s.n_cols != 1), "Col::subvec(): given size does not specify a column vector" ); - arma_debug_check( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); return subview_col(*this, 0, start_row, s.n_rows); } @@ -763,7 +926,7 @@ Col::subvec(const uword start_row, const SizeMat& s) const arma_debug_check( (s.n_cols != 1), "Col::subvec(): given size does not specify a column vector" ); - arma_debug_check( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_row >= Mat::n_rows) || ((start_row + s.n_rows) > Mat::n_rows) ), "Col::subvec(): size out of bounds" ); return subview_col(*this, 0, start_row, s.n_rows); } @@ -777,7 +940,7 @@ Col::head(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_rows), "Col::head(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_rows), "Col::head(): size out of bounds" ); return subview_col(*this, 0, 0, N); } @@ -791,7 +954,7 @@ Col::head(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_rows), "Col::head(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_rows), "Col::head(): size out of bounds" ); return subview_col(*this, 0, 0, N); } @@ -805,7 +968,7 @@ Col::tail(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_rows), "Col::tail(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_rows), "Col::tail(): size out of bounds" ); const uword start_row = Mat::n_rows - N; @@ -821,7 +984,7 @@ Col::tail(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_rows), "Col::tail(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_rows), "Col::tail(): size out of bounds" ); const uword start_row = Mat::n_rows - N; @@ -886,7 +1049,7 @@ Col::shed_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= Mat::n_rows, "Col::shed_row(): index out of bounds"); + arma_debug_check_bounds( row_num >= Mat::n_rows, "Col::shed_row(): index out of bounds" ); shed_rows(row_num, row_num); } @@ -901,7 +1064,7 @@ Col::shed_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= Mat::n_rows), "Col::shed_rows(): indices out of bounds or incorrectly used" @@ -910,7 +1073,7 @@ Col::shed_rows(const uword in_row1, const uword in_row2) const uword n_keep_front = in_row1; const uword n_keep_back = Mat::n_rows - (in_row2 + 1); - Col X(n_keep_front + n_keep_back); + Col X(n_keep_front + n_keep_back, arma_nozeros_indicator()); eT* X_mem = X.memptr(); const eT* t_mem = (*this).memptr(); @@ -944,8 +1107,6 @@ Col::shed_rows(const Base& indices) -//! insert N rows at the specified row position, -//! optionally setting the elements of the inserted rows to zero template inline void @@ -953,38 +1114,48 @@ Col::insert_rows(const uword row_num, const uword N, const bool set_to_zero) { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Col::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_rows = Mat::n_rows; const uword A_n_rows = row_num; const uword B_n_rows = t_n_rows - row_num; // insertion at row_num == n_rows is in effect an append operation - arma_debug_check( (row_num > t_n_rows), "Col::insert_rows(): index out of bounds"); + arma_debug_check_bounds( (row_num > t_n_rows), "Col::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } - if(N > 0) + Col out(t_n_rows + N, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + const eT* t_mem = (*this).memptr(); + + if(A_n_rows > 0) { - Col out(t_n_rows + N); - - eT* out_mem = out.memptr(); - const eT* t_mem = (*this).memptr(); - - if(A_n_rows > 0) - { - arrayops::copy( out_mem, t_mem, A_n_rows ); - } - - if(B_n_rows > 0) - { - arrayops::copy( &(out_mem[row_num + N]), &(t_mem[row_num]), B_n_rows ); - } - - if(set_to_zero) - { - arrayops::inplace_set( &(out_mem[row_num]), eT(0), N ); - } - - Mat::steal_mem(out); + arrayops::copy( out_mem, t_mem, A_n_rows ); + } + + if(B_n_rows > 0) + { + arrayops::copy( &(out_mem[row_num + N]), &(t_mem[row_num]), B_n_rows ); } + + arrayops::fill_zeros( &(out_mem[row_num]), N ); + + Mat::steal_mem(out); } @@ -1006,7 +1177,6 @@ Col::insert_rows(const uword row_num, const Base& X) template arma_inline -arma_warn_unused eT& Col::at(const uword i) { @@ -1017,7 +1187,6 @@ Col::at(const uword i) template arma_inline -arma_warn_unused const eT& Col::at(const uword i) const { @@ -1028,7 +1197,6 @@ Col::at(const uword i) const template arma_inline -arma_warn_unused eT& Col::at(const uword in_row, const uword) { @@ -1039,7 +1207,6 @@ Col::at(const uword in_row, const uword) template arma_inline -arma_warn_unused const eT& Col::at(const uword in_row, const uword) const { @@ -1055,7 +1222,7 @@ Col::begin_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds" ); return Mat::memptr() + row_num; } @@ -1069,7 +1236,7 @@ Col::begin_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::begin_row(): index out of bounds" ); return Mat::memptr() + row_num; } @@ -1083,7 +1250,7 @@ Col::end_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds" ); return Mat::memptr() + row_num + 1; } @@ -1097,7 +1264,7 @@ Col::end_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Col::end_row(): index out of bounds" ); return Mat::memptr() + row_num + 1; } @@ -1111,6 +1278,15 @@ Col::fixed::fixed() : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) { arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Col::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } } @@ -1144,6 +1320,19 @@ Col::fixed::fixed(const subview_cube& X) +template +template +inline +Col::fixed::fixed(const fill::scalar_holder f) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + template template template @@ -1153,11 +1342,11 @@ Col::fixed::fixed(const fill::fill_class&) { arma_extra_debug_sigprint_this(this); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).eye(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } } @@ -1302,43 +1491,39 @@ Col::fixed::operator=(const subview_cube& X) -#if defined(ARMA_USE_CXX11) +template +template +inline +Col::fixed::fixed(const std::initializer_list& list) + : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); - template - template - inline - Col::fixed::fixed(const std::initializer_list& list) - : Col( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) - { - arma_extra_debug_sigprint_this(this); - - (*this).operator=(list); - } + (*this).operator=(list); + } + + + +template +template +inline +Col& +Col::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + const uword N = uword(list.size()); + arma_debug_check( (N > fixed_n_elem), "Col::fixed: initialiser list is too long" ); - template - template - inline - Col& - Col::fixed::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - const uword N = uword(list.size()); - - arma_debug_check( (N > fixed_n_elem), "Col::fixed: initialiser list is too long" ); - - eT* this_mem = (*this).memptr(); - - arrayops::copy( this_mem, list.begin(), N ); - - for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } - - return *this; - } + eT* this_mem = (*this).memptr(); -#endif + arrayops::copy( this_mem, list.begin(), N ); + + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + + return *this; + } @@ -1475,7 +1660,6 @@ Col::fixed::st() const template template arma_inline -arma_warn_unused const eT& Col::fixed::at_alt(const uword ii) const { @@ -1497,7 +1681,6 @@ Col::fixed::at_alt(const uword ii) const template template arma_inline -arma_warn_unused eT& Col::fixed::operator[] (const uword ii) { @@ -1509,7 +1692,6 @@ Col::fixed::operator[] (const uword ii) template template arma_inline -arma_warn_unused const eT& Col::fixed::operator[] (const uword ii) const { @@ -1521,7 +1703,6 @@ Col::fixed::operator[] (const uword ii) const template template arma_inline -arma_warn_unused eT& Col::fixed::at(const uword ii) { @@ -1533,7 +1714,6 @@ Col::fixed::at(const uword ii) template template arma_inline -arma_warn_unused const eT& Col::fixed::at(const uword ii) const { @@ -1545,11 +1725,10 @@ Col::fixed::at(const uword ii) const template template arma_inline -arma_warn_unused eT& Col::fixed::operator() (const uword ii) { - arma_debug_check( (ii >= fixed_n_elem), "Col::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Col::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; } @@ -1559,11 +1738,10 @@ Col::fixed::operator() (const uword ii) template template arma_inline -arma_warn_unused const eT& Col::fixed::operator() (const uword ii) const { - arma_debug_check( (ii >= fixed_n_elem), "Col::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Col::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; } @@ -1573,7 +1751,6 @@ Col::fixed::operator() (const uword ii) const template template arma_inline -arma_warn_unused eT& Col::fixed::at(const uword in_row, const uword) { @@ -1585,7 +1762,6 @@ Col::fixed::at(const uword in_row, const uword) template template arma_inline -arma_warn_unused const eT& Col::fixed::at(const uword in_row, const uword) const { @@ -1597,11 +1773,10 @@ Col::fixed::at(const uword in_row, const uword) const template template arma_inline -arma_warn_unused eT& Col::fixed::operator() (const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; } @@ -1611,11 +1786,10 @@ Col::fixed::operator() (const uword in_row, const uword in_col template template arma_inline -arma_warn_unused const eT& Col::fixed::operator() (const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= fixed_n_elem) || (in_col > 0)), "Col::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[in_row] : Mat::mem_local[in_row]; } @@ -1625,7 +1799,6 @@ Col::fixed::operator() (const uword in_row, const uword in_col template template arma_inline -arma_warn_unused eT* Col::fixed::memptr() { @@ -1637,7 +1810,6 @@ Col::fixed::memptr() template template arma_inline -arma_warn_unused const eT* Col::fixed::memptr() const { @@ -1648,7 +1820,6 @@ Col::fixed::memptr() const template template -arma_hot inline const Col& Col::fixed::fill(const eT val) @@ -1666,7 +1837,6 @@ Col::fixed::fill(const eT val) template template -arma_hot inline const Col& Col::fixed::zeros() @@ -1684,7 +1854,6 @@ Col::fixed::zeros() template template -arma_hot inline const Col& Col::fixed::ones() @@ -1710,7 +1879,7 @@ Col::Col(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_me -#ifdef ARMA_EXTRA_COL_MEAT +#if defined(ARMA_EXTRA_COL_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_COL_MEAT) #endif diff --git a/src/armadillo_bits/CubeToMatOp_bones.hpp b/src/armadillo_bits/CubeToMatOp_bones.hpp index 3ff89147..cd2ba599 100644 --- a/src/armadillo_bits/CubeToMatOp_bones.hpp +++ b/src/armadillo_bits/CubeToMatOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,7 @@ template -class CubeToMatOp : public Base > +class CubeToMatOp : public Base< typename T1::elem_type, CubeToMatOp > { public: @@ -28,15 +30,15 @@ class CubeToMatOp : public Base typedef typename get_pod_type::result pod_type; inline explicit CubeToMatOp(const T1& in_m); - inline CubeToMatOp(const T1& in_m, const elem_type in_aux); + inline CubeToMatOp(const T1& in_m, const uword in_aux_uword); inline ~CubeToMatOp(); - arma_aligned const T1& m; //!< the operand; must be derived from BaseCube - arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 + arma_aligned const T1& m; //!< the operand; must be derived from BaseCube + arma_aligned uword aux_uword; //!< auxiliary data, uword format - static const bool is_row = op_type::template traits::is_row; - static const bool is_col = op_type::template traits::is_col; - static const bool is_xvec = op_type::template traits::is_xvec; + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; }; diff --git a/src/armadillo_bits/CubeToMatOp_meat.hpp b/src/armadillo_bits/CubeToMatOp_meat.hpp index 6032f0d7..abe83e81 100644 --- a/src/armadillo_bits/CubeToMatOp_meat.hpp +++ b/src/armadillo_bits/CubeToMatOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -31,9 +33,9 @@ CubeToMatOp::CubeToMatOp(const T1& in_m) template inline -CubeToMatOp::CubeToMatOp(const T1& in_m, const typename T1::elem_type in_aux) +CubeToMatOp::CubeToMatOp(const T1& in_m, const uword in_aux_uword) : m(in_m) - , aux(in_aux) + , aux_uword(in_aux_uword) { arma_extra_debug_sigprint(); } diff --git a/src/armadillo_bits/Cube_bones.hpp b/src/armadillo_bits/Cube_bones.hpp index 835cbddb..5cf364aa 100644 --- a/src/armadillo_bits/Cube_bones.hpp +++ b/src/armadillo_bits/Cube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,8 +23,8 @@ struct Cube_prealloc { - static const uword mat_ptrs_size = 4; - static const uword mem_n_elem = 64; + static constexpr uword mat_ptrs_size = 4; + static constexpr uword mem_n_elem = 64; }; @@ -37,12 +39,13 @@ class Cube : public BaseCube< eT, Cube > typedef eT elem_type; //!< the type of elements stored in the cube typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT - const uword n_rows; //!< number of rows in each slice (read-only) - const uword n_cols; //!< number of columns in each slice (read-only) - const uword n_elem_slice; //!< number of elements in each slice (read-only) - const uword n_slices; //!< number of slices in the cube (read-only) - const uword n_elem; //!< number of elements in the cube (read-only) - const uword mem_state; + const uword n_rows; //!< number of rows in each slice (read-only) + const uword n_cols; //!< number of columns in each slice (read-only) + const uword n_elem_slice; //!< number of elements in each slice (read-only) + const uword n_slices; //!< number of slices in the cube (read-only) + const uword n_elem; //!< number of elements in the cube (read-only) + const uword n_alloc; //!< number of allocated elements (read-only); NOTE: n_alloc can be 0, even if n_elem > 0 + const uword mem_state; // mem_state = 0: normal cube which manages its own memory // mem_state = 1: use auxiliary memory until a size change @@ -54,10 +57,27 @@ class Cube : public BaseCube< eT, Cube > protected: - arma_aligned const Mat** const mat_ptrs; + using mat_type = Mat; + + #if defined(ARMA_USE_OPENMP) + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = mat_type*; + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = std::atomic; + #else + using raw_mat_ptr_type = mat_type*; + using atomic_mat_ptr_type = mat_type*; + #endif + + atomic_mat_ptr_type* mat_ptrs = nullptr; - arma_align_mem Mat* mat_ptrs_local[ Cube_prealloc::mat_ptrs_size ]; - arma_align_mem eT mem_local[ Cube_prealloc::mem_n_elem ]; // local storage, for small cubes + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) + mutable std::mutex mat_mutex; // required for slice() + #endif + + arma_aligned atomic_mat_ptr_type mat_ptrs_local[ Cube_prealloc::mat_ptrs_size ]; + arma_align_mem eT mem_local[ Cube_prealloc::mem_n_elem ]; // local storage, for small cubes public: @@ -65,28 +85,32 @@ class Cube : public BaseCube< eT, Cube > inline ~Cube(); inline Cube(); - inline explicit Cube(const uword in_rows, const uword in_cols, const uword in_slices); + inline explicit Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); inline explicit Cube(const SizeCube& s); - template inline Cube(const uword in_rows, const uword in_cols, const uword in_slices, const fill::fill_class& f); - template inline Cube(const SizeCube& s, const fill::fill_class& f); + template inline explicit Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const arma_initmode_indicator&); + template inline explicit Cube(const SizeCube& s, const arma_initmode_indicator&); + + template inline Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const fill::fill_class& f); + template inline Cube(const SizeCube& s, const fill::fill_class& f); + + inline Cube(const uword in_rows, const uword in_cols, const uword in_slices, const fill::scalar_holder f); + inline Cube(const SizeCube& s, const fill::scalar_holder f); - #if defined(ARMA_USE_CXX11) inline Cube(Cube&& m); inline Cube& operator=(Cube&& m); - #endif inline Cube( eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices, const bool copy_aux_mem = true, const bool strict = false, const bool prealloc_mat = false); inline Cube(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const uword aux_n_slices); - inline Cube& operator=(const eT val); + inline Cube& operator= (const eT val); inline Cube& operator+=(const eT val); inline Cube& operator-=(const eT val); inline Cube& operator*=(const eT val); inline Cube& operator/=(const eT val); inline Cube(const Cube& m); - inline Cube& operator=(const Cube& m); + inline Cube& operator= (const Cube& m); inline Cube& operator+=(const Cube& m); inline Cube& operator-=(const Cube& m); inline Cube& operator%=(const Cube& m); @@ -96,14 +120,14 @@ class Cube : public BaseCube< eT, Cube > inline explicit Cube(const BaseCube& A, const BaseCube& B); inline Cube(const subview_cube& X); - inline Cube& operator=(const subview_cube& X); + inline Cube& operator= (const subview_cube& X); inline Cube& operator+=(const subview_cube& X); inline Cube& operator-=(const subview_cube& X); inline Cube& operator%=(const subview_cube& X); inline Cube& operator/=(const subview_cube& X); template inline Cube(const subview_cube_slices& X); - template inline Cube& operator=(const subview_cube_slices& X); + template inline Cube& operator= (const subview_cube_slices& X); template inline Cube& operator+=(const subview_cube_slices& X); template inline Cube& operator-=(const subview_cube_slices& X); template inline Cube& operator%=(const subview_cube_slices& X); @@ -117,7 +141,7 @@ class Cube : public BaseCube< eT, Cube > inline Mat& slice(const uword in_slice); inline const Mat& slice(const uword in_slice) const; - + arma_inline subview_cube rows(const uword in_row1, const uword in_row2); arma_inline const subview_cube rows(const uword in_row1, const uword in_row2) const; @@ -173,13 +197,11 @@ class Cube : public BaseCube< eT, Cube > template inline subview_cube_each2 each_slice(const Base& indices); template inline const subview_cube_each2 each_slice(const Base& indices) const; - #if defined(ARMA_USE_CXX11) - inline const Cube& each_slice(const std::function< void( Mat&) >& F); + inline Cube& each_slice(const std::function< void( Mat&) >& F); inline const Cube& each_slice(const std::function< void(const Mat&) >& F) const; - inline const Cube& each_slice(const std::function< void( Mat&) >& F, const bool use_mp); + inline Cube& each_slice(const std::function< void( Mat&) >& F, const bool use_mp); inline const Cube& each_slice(const std::function< void(const Mat&) >& F, const bool use_mp) const; - #endif template arma_inline subview_cube_slices slices(const Base& indices); @@ -196,81 +218,91 @@ class Cube : public BaseCube< eT, Cube > template inline void shed_slices(const Base& indices); - inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero = true); - inline void insert_cols(const uword row_num, const uword N, const bool set_to_zero = true); - inline void insert_slices(const uword slice_num, const uword N, const bool set_to_zero = true); + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_cols(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_slices(const uword slice_num, const uword N, const bool set_to_zero); + + inline void insert_rows(const uword row_num, const uword N); + inline void insert_cols(const uword row_num, const uword N); + inline void insert_slices(const uword slice_num, const uword N); template inline void insert_rows(const uword row_num, const BaseCube& X); template inline void insert_cols(const uword col_num, const BaseCube& X); template inline void insert_slices(const uword slice_num, const BaseCube& X); + template inline void insert_slices(const uword slice_num, const Base& X); template inline Cube(const GenCube& X); - template inline Cube& operator=(const GenCube& X); + template inline Cube& operator= (const GenCube& X); template inline Cube& operator+=(const GenCube& X); template inline Cube& operator-=(const GenCube& X); template inline Cube& operator%=(const GenCube& X); template inline Cube& operator/=(const GenCube& X); template inline Cube(const OpCube& X); - template inline Cube& operator=(const OpCube& X); + template inline Cube& operator= (const OpCube& X); template inline Cube& operator+=(const OpCube& X); template inline Cube& operator-=(const OpCube& X); template inline Cube& operator%=(const OpCube& X); template inline Cube& operator/=(const OpCube& X); template inline Cube(const eOpCube& X); - template inline Cube& operator=(const eOpCube& X); + template inline Cube& operator= (const eOpCube& X); template inline Cube& operator+=(const eOpCube& X); template inline Cube& operator-=(const eOpCube& X); template inline Cube& operator%=(const eOpCube& X); template inline Cube& operator/=(const eOpCube& X); template inline Cube(const mtOpCube& X); - template inline Cube& operator=(const mtOpCube& X); + template inline Cube& operator= (const mtOpCube& X); template inline Cube& operator+=(const mtOpCube& X); template inline Cube& operator-=(const mtOpCube& X); template inline Cube& operator%=(const mtOpCube& X); template inline Cube& operator/=(const mtOpCube& X); template inline Cube(const GlueCube& X); - template inline Cube& operator=(const GlueCube& X); + template inline Cube& operator= (const GlueCube& X); template inline Cube& operator+=(const GlueCube& X); template inline Cube& operator-=(const GlueCube& X); template inline Cube& operator%=(const GlueCube& X); template inline Cube& operator/=(const GlueCube& X); template inline Cube(const eGlueCube& X); - template inline Cube& operator=(const eGlueCube& X); + template inline Cube& operator= (const eGlueCube& X); template inline Cube& operator+=(const eGlueCube& X); template inline Cube& operator-=(const eGlueCube& X); template inline Cube& operator%=(const eGlueCube& X); template inline Cube& operator/=(const eGlueCube& X); template inline Cube(const mtGlueCube& X); - template inline Cube& operator=(const mtGlueCube& X); + template inline Cube& operator= (const mtGlueCube& X); template inline Cube& operator+=(const mtGlueCube& X); template inline Cube& operator-=(const mtGlueCube& X); template inline Cube& operator%=(const mtGlueCube& X); template inline Cube& operator/=(const mtGlueCube& X); - arma_inline arma_warn_unused const eT& at_alt (const uword i) const; + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; - arma_inline arma_warn_unused eT& operator[] (const uword i); - arma_inline arma_warn_unused const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; - arma_inline arma_warn_unused eT& at(const uword i); - arma_inline arma_warn_unused const eT& at(const uword i) const; + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; - arma_inline arma_warn_unused eT& operator() (const uword i); - arma_inline arma_warn_unused const eT& operator() (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col, const uword in_slice); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col, const uword in_slice) const; + #endif + + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col, const uword in_slice); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; arma_inline const Cube& operator++(); arma_inline void operator++(int); @@ -278,76 +310,71 @@ class Cube : public BaseCube< eT, Cube > arma_inline const Cube& operator--(); arma_inline void operator--(int); - inline arma_warn_unused bool is_finite() const; - arma_inline arma_warn_unused bool is_empty() const; - - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused arma_inline bool is_empty() const; - arma_inline arma_warn_unused bool in_range(const uword i) const; - arma_inline arma_warn_unused bool in_range(const span& x) const; + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; - inline arma_warn_unused bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; - inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused inline bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; - arma_inline arma_warn_unused eT* memptr(); - arma_inline arma_warn_unused const eT* memptr() const; + arma_warn_unused inline bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; - arma_inline arma_warn_unused eT* slice_memptr(const uword slice); - arma_inline arma_warn_unused const eT* slice_memptr(const uword slice) const; + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; - arma_inline arma_warn_unused eT* slice_colptr(const uword in_slice, const uword in_col); - arma_inline arma_warn_unused const eT* slice_colptr(const uword in_slice, const uword in_col) const; + arma_warn_unused arma_inline eT* slice_memptr(const uword slice); + arma_warn_unused arma_inline const eT* slice_memptr(const uword slice) const; - arma_cold inline void impl_print( const std::string& extra_text) const; - arma_cold inline void impl_print(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused arma_inline eT* slice_colptr(const uword in_slice, const uword in_col); + arma_warn_unused arma_inline const eT* slice_colptr(const uword in_slice, const uword in_col) const; - arma_cold inline void impl_raw_print( const std::string& extra_text) const; - arma_cold inline void impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const; + inline Cube& set_size(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& set_size(const SizeCube& s); - inline void set_size(const uword in_rows, const uword in_cols, const uword in_slices); - inline void set_size(const SizeCube& s); - - inline void reshape(const uword in_rows, const uword in_cols, const uword in_slices); - inline void reshape(const SizeCube& s); + inline Cube& reshape(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& reshape(const SizeCube& s); - inline void resize(const uword in_rows, const uword in_cols, const uword in_slices); - inline void resize(const SizeCube& s); - - arma_deprecated inline void reshape(const uword in_rows, const uword in_cols, const uword in_slices, const uword dim); //!< NOTE: don't use this form: it will be removed + inline Cube& resize(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& resize(const SizeCube& s); - template inline void copy_size(const Cube& m); + template inline Cube& copy_size(const Cube& m); - template inline const Cube& for_each(functor F); + template inline Cube& for_each(functor F); template inline const Cube& for_each(functor F) const; - template inline const Cube& transform(functor F); - template inline const Cube& imbue(functor F); + template inline Cube& transform(functor F); + template inline Cube& imbue(functor F); + + inline Cube& replace(const eT old_val, const eT new_val); - inline const Cube& replace(const eT old_val, const eT new_val); + inline Cube& clean(const pod_type threshold); - inline const Cube& clean(const pod_type threshold); + inline Cube& clamp(const eT min_val, const eT max_val); - inline const Cube& fill(const eT val); + inline Cube& fill(const eT val); - inline const Cube& zeros(); - inline const Cube& zeros(const uword in_rows, const uword in_cols, const uword in_slices); - inline const Cube& zeros(const SizeCube& s); + inline Cube& zeros(); + inline Cube& zeros(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& zeros(const SizeCube& s); - inline const Cube& ones(); - inline const Cube& ones(const uword in_rows, const uword in_cols, const uword in_slices); - inline const Cube& ones(const SizeCube& s); + inline Cube& ones(); + inline Cube& ones(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& ones(const SizeCube& s); - inline const Cube& randu(); - inline const Cube& randu(const uword in_rows, const uword in_cols, const uword in_slices); - inline const Cube& randu(const SizeCube& s); + inline Cube& randu(); + inline Cube& randu(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& randu(const SizeCube& s); - inline const Cube& randn(); - inline const Cube& randn(const uword in_rows, const uword in_cols, const uword in_slices); - inline const Cube& randn(const SizeCube& s); + inline Cube& randn(); + inline Cube& randn(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + inline Cube& randn(const SizeCube& s); inline void reset(); inline void soft_reset(); @@ -357,8 +384,8 @@ class Cube : public BaseCube< eT, Cube > template inline void set_imag(const BaseCube& X); - inline arma_warn_unused eT min() const; - inline arma_warn_unused eT max() const; + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; inline eT min(uword& index_of_min_val) const; inline eT max(uword& index_of_max_val) const; @@ -367,21 +394,21 @@ class Cube : public BaseCube< eT, Cube > inline eT max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_val) const; - inline arma_cold bool save(const std::string name, const file_type type = arma_binary, const bool print_status = true) const; - inline arma_cold bool save(const hdf5_name& spec, const file_type type = hdf5_binary, const bool print_status = true) const; - inline arma_cold bool save( std::ostream& os, const file_type type = arma_binary, const bool print_status = true) const; + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool load(const std::string name, const file_type type = auto_detect, const bool print_status = true); - inline arma_cold bool load(const hdf5_name& spec, const file_type type = hdf5_binary, const bool print_status = true); - inline arma_cold bool load( std::istream& is, const file_type type = auto_detect, const bool print_status = true); + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); - inline arma_cold bool quiet_save(const std::string name, const file_type type = arma_binary) const; - inline arma_cold bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; - inline arma_cold bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool quiet_load(const std::string name, const file_type type = auto_detect); - inline arma_cold bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); - inline arma_cold bool quiet_load( std::istream& is, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); // iterators @@ -410,9 +437,16 @@ class Cube : public BaseCube< eT, Cube > inline bool empty() const; inline uword size() const; + arma_warn_unused inline eT& front(); + arma_warn_unused inline const eT& front() const; + + arma_warn_unused inline eT& back(); + arma_warn_unused inline const eT& back() const; + inline void swap(Cube& B); - inline void steal_mem(Cube& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Cube& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Cube& X, const bool is_move); //!< don't use this unless you're writing code internal to Armadillo template class fixed; @@ -420,7 +454,7 @@ class Cube : public BaseCube< eT, Cube > protected: inline void init_cold(); - inline void init_warm(const uword in_rows, const uword in_cols, const uword in_slices); + inline void init_warm(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); template inline void init(const BaseCube& A, const BaseCube& B); @@ -428,6 +462,9 @@ class Cube : public BaseCube< eT, Cube > inline void delete_mat(); inline void create_mat(); + inline Mat* create_mat_ptr(const uword in_slice) const; + inline Mat* get_mat_ptr(const uword in_slice) const; + friend class glue_join; friend class op_reshape; friend class op_resize; @@ -436,7 +473,7 @@ class Cube : public BaseCube< eT, Cube > public: - #ifdef ARMA_EXTRA_CUBE_PROTO + #if defined(ARMA_EXTRA_CUBE_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_CUBE_PROTO) #endif }; @@ -449,13 +486,13 @@ class Cube::fixed : public Cube { private: - static const uword fixed_n_elem = fixed_n_rows * fixed_n_cols * fixed_n_slices; - static const uword fixed_n_elem_slice = fixed_n_rows * fixed_n_cols; + static constexpr uword fixed_n_elem = fixed_n_rows * fixed_n_cols * fixed_n_slices; + static constexpr uword fixed_n_elem_slice = fixed_n_rows * fixed_n_cols; - static const bool use_extra = (fixed_n_elem > Cube_prealloc::mem_n_elem); + static constexpr bool use_extra = (fixed_n_elem > Cube_prealloc::mem_n_elem); - arma_aligned Mat* mat_ptrs_local_extra[ (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? fixed_n_slices : 1 ]; - arma_align_mem eT mem_local_extra [ use_extra ? fixed_n_elem : 1 ]; + arma_aligned atomic_mat_ptr_type mat_ptrs_local_extra[ (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? fixed_n_slices : 1 ]; + arma_align_mem eT mem_local_extra[ use_extra ? fixed_n_elem : 1 ]; arma_inline void mem_setup(); @@ -465,6 +502,7 @@ class Cube::fixed : public Cube inline fixed(); inline fixed(const fixed& X); + inline fixed(const fill::scalar_holder f); template inline fixed(const fill::fill_class& f); template inline fixed(const BaseCube& A); template inline fixed(const BaseCube& A, const BaseCube& B); @@ -475,20 +513,25 @@ class Cube::fixed : public Cube inline Cube& operator=(const fixed& X); - arma_inline arma_warn_unused eT& operator[] (const uword i); - arma_inline arma_warn_unused const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword i); - arma_inline arma_warn_unused const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; - arma_inline arma_warn_unused eT& operator() (const uword i); - arma_inline arma_warn_unused const eT& operator() (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col, const uword in_slice) const; + #endif - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col, const uword in_slice); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col, const uword in_slice) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col, const uword in_slice); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col, const uword in_slice); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col, const uword in_slice) const; }; diff --git a/src/armadillo_bits/Cube_meat.hpp b/src/armadillo_bits/Cube_meat.hpp index 07e53147..265dcf47 100644 --- a/src/armadillo_bits/Cube_meat.hpp +++ b/src/armadillo_bits/Cube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,17 +28,14 @@ Cube::~Cube() delete_mat(); - if( (mem_state == 0) && (n_elem > Cube_prealloc::mem_n_elem) ) + if( (mem_state == 0) && (n_alloc > 0) ) { + arma_extra_debug_print("Cube::destructor: releasing memory"); memory::release( access::rw(mem) ); } // try to expose buggy user code that accesses deleted objects - if(arma_config::debug) - { - access::rw(mem) = 0; - access::rw(mat_ptrs) = 0; - } + if(arma_config::debug) { access::rw(mem) = nullptr; } arma_type_check(( is_supported_elem_type::value == false )); } @@ -51,9 +50,9 @@ Cube::Cube() , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); } @@ -69,13 +68,19 @@ Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_sl , n_elem_slice(in_n_rows*in_n_cols) , n_slices(in_n_slices) , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } } @@ -88,13 +93,73 @@ Cube::Cube(const SizeCube& s) , n_elem_slice(s.n_rows*s.n_cols) , n_slices(s.n_slices) , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const arma_initmode_indicator&) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Cube::Cube(const SizeCube& s, const arma_initmode_indicator&) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Cube::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } } @@ -109,20 +174,20 @@ Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_sl , n_elem_slice(in_n_rows*in_n_cols) , n_slices(in_n_slices) , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); init_cold(); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } - if(is_same_type::yes) { arma_debug_check(true, "Cube::Cube(): unsupported fill type"); } + arma_static_check( (is_same_type::yes), "Cube::Cube(): unsupported fill type" ); } @@ -136,59 +201,97 @@ Cube::Cube(const SizeCube& s, const fill::fill_class&) , n_elem_slice(s.n_rows*s.n_cols) , n_slices(s.n_slices) , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); init_cold(); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } - if(is_same_type::yes) { arma_debug_check(true, "Cube::Cube(): unsupported fill type"); } + arma_static_check( (is_same_type::yes), "Cube::Cube(): unsupported fill type" ); } -#if defined(ARMA_USE_CXX11) +//! construct the cube to have user specified dimensions and fill with specified value +template +inline +Cube::Cube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const fill::scalar_holder f) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem_slice(in_n_rows*in_n_cols) + , n_slices(in_n_slices) + , n_elem(in_n_rows*in_n_cols*in_n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); - template - inline - Cube::Cube(Cube&& in_cube) - : n_rows(0) - , n_cols(0) - , n_elem_slice(0) - , n_slices(0) - , n_elem(0) - , mem_state(0) - , mem() - , mat_ptrs(0) - { - arma_extra_debug_sigprint_this(this); - arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); - - (*this).steal_mem(in_cube); - } - + init_cold(); + (*this).fill(f.scalar); + } + + + +template +inline +Cube::Cube(const SizeCube& s, const fill::scalar_holder f) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem_slice(s.n_rows*s.n_cols) + , n_slices(s.n_slices) + , n_elem(s.n_rows*s.n_cols*s.n_slices) + , n_alloc() + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); - template - inline - Cube& - Cube::operator=(Cube&& in_cube) - { - arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); - - (*this).steal_mem(in_cube); - - return *this; - } + init_cold(); -#endif + (*this).fill(f.scalar); + } + + + +template +inline +Cube::Cube(Cube&& in_cube) + : n_rows(0) + , n_cols(0) + , n_elem_slice(0) + , n_slices(0) + , n_elem(0) + , n_alloc(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); + + (*this).steal_mem(in_cube, true); + } + + + +template +inline +Cube& +Cube::operator=(Cube&& in_cube) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &in_cube); + + (*this).steal_mem(in_cube, true); + + return *this; + } @@ -197,12 +300,12 @@ inline void Cube::init_cold() { - arma_extra_debug_sigprint( arma_str::format("n_rows = %d, n_cols = %d, n_slices = %d") % n_rows % n_cols % n_slices ); + arma_extra_debug_sigprint( arma_str::format("n_rows = %u, n_cols = %u, n_slices = %u") % n_rows % n_cols % n_slices ); - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "Cube::init(): requested size is too large"; #else - const char* error_message = "Cube::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "Cube::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_check @@ -218,20 +321,17 @@ Cube::init_cold() if(n_elem <= Cube_prealloc::mem_n_elem) { - if(n_elem == 0) - { - access::rw(mem) = NULL; - } - else - { - arma_extra_debug_print("Cube::init(): using local memory"); - access::rw(mem) = mem_local; - } + if(n_elem > 0) { arma_extra_debug_print("Cube::init(): using local memory"); } + + access::rw(mem) = (n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; } else { arma_extra_debug_print("Cube::init(): acquiring memory"); - access::rw(mem) = memory::acquire(n_elem); + + access::rw(mem) = memory::acquire(n_elem); + access::rw(n_alloc) = n_elem; } create_mat(); @@ -244,21 +344,23 @@ inline void Cube::init_warm(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) { - arma_extra_debug_sigprint( arma_str::format("in_n_rows = %d, in_n_cols = %d, in_n_slices = %d") % in_n_rows % in_n_cols % in_n_slices ); + arma_extra_debug_sigprint( arma_str::format("in_n_rows = %u, in_n_cols = %u, in_n_slices = %u") % in_n_rows % in_n_cols % in_n_slices ); if( (n_rows == in_n_rows) && (n_cols == in_n_cols) && (n_slices == in_n_slices) ) { return; } const uword t_mem_state = mem_state; bool err_state = false; - char* err_msg = 0; + char* err_msg = nullptr; - arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), "Cube::init(): size is fixed and hence cannot be changed" ); + const char* error_message_1 = "Cube::init(): size is fixed and hence cannot be changed"; - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) - const char* error_message = "Cube::init(): requested size is too large"; + arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), error_message_1 ); + + #if defined(ARMA_64BIT_WORD) + const char* error_message_2 = "Cube::init(): requested size is too large"; #else - const char* error_message = "Cube::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message_2 = "Cube::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_set_error @@ -270,7 +372,7 @@ Cube::init_warm(const uword in_n_rows, const uword in_n_cols, const uword in ? ( (double(in_n_rows) * double(in_n_cols) * double(in_n_slices)) > double(ARMA_MAX_UWORD) ) : false ), - error_message + error_message_2 ); arma_debug_check(err_state, err_msg); @@ -290,68 +392,64 @@ Cube::init_warm(const uword in_n_rows, const uword in_n_cols, const uword in access::rw(n_slices) = in_n_slices; create_mat(); + + return; } - else // condition: old_n_elem != new_n_elem + + arma_debug_check( (t_mem_state == 2), "Cube::init(): mismatch between size of auxiliary memory and requested size" ); + + delete_mat(); + + if(new_n_elem <= Cube_prealloc::mem_n_elem) { - arma_debug_check( (t_mem_state == 2), "Cube::init(): requested size is not compatible with the size of auxiliary memory" ); - - delete_mat(); - - if(new_n_elem < old_n_elem) // reuse existing memory if possible + if(n_alloc > 0) { - if( (t_mem_state == 0) && (new_n_elem <= Cube_prealloc::mem_n_elem) ) - { - if(old_n_elem > Cube_prealloc::mem_n_elem) - { - arma_extra_debug_print("Cube::init(): releasing memory"); - memory::release( access::rw(mem) ); - } - - if(new_n_elem == 0) - { - access::rw(mem) = NULL; - } - else - { - arma_extra_debug_print("Cube::init(): using local memory"); - access::rw(mem) = mem_local; - } - } - else - { - arma_extra_debug_print("Cube::init(): reusing memory"); - } + arma_extra_debug_print("Cube::init(): releasing memory"); + memory::release( access::rw(mem) ); } - else // condition: new_n_elem > old_n_elem + + if(new_n_elem > 0) { arma_extra_debug_print("Cube::init(): using local memory"); } + + access::rw(mem) = (new_n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else // condition: new_n_elem > Cube_prealloc::mem_n_elem + { + if(new_n_elem > n_alloc) { - if( (t_mem_state == 0) && (old_n_elem > Cube_prealloc::mem_n_elem) ) + if(n_alloc > 0) { arma_extra_debug_print("Cube::init(): releasing memory"); memory::release( access::rw(mem) ); + + // in case memory::acquire() throws an exception + access::rw(mem) = nullptr; + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem_slice) = 0; + access::rw(n_slices) = 0; + access::rw(n_elem) = 0; + access::rw(n_alloc) = 0; } - if(new_n_elem <= Cube_prealloc::mem_n_elem) - { - arma_extra_debug_print("Cube::init(): using local memory"); - access::rw(mem) = mem_local; - } - else - { - arma_extra_debug_print("Cube::init(): acquiring memory"); - access::rw(mem) = memory::acquire(new_n_elem); - } - - access::rw(mem_state) = 0; + arma_extra_debug_print("Cube::init(): acquiring memory"); + access::rw(mem) = memory::acquire(new_n_elem); + access::rw(n_alloc) = new_n_elem; + } + else // condition: new_n_elem <= n_alloc + { + arma_extra_debug_print("Cube::init(): reusing memory"); } - - access::rw(n_rows) = in_n_rows; - access::rw(n_cols) = in_n_cols; - access::rw(n_elem_slice) = in_n_rows*in_n_cols; - access::rw(n_slices) = in_n_slices; - access::rw(n_elem) = new_n_elem; - - create_mat(); } + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem_slice) = in_n_rows*in_n_cols; + access::rw(n_slices) = in_n_slices; + access::rw(n_elem) = new_n_elem; + access::rw(mem_state) = 0; + + create_mat(); } @@ -427,16 +525,25 @@ Cube::delete_mat() { arma_extra_debug_sigprint(); - if((n_slices > 0) && (mat_ptrs != NULL)) + if((n_slices > 0) && (mat_ptrs != nullptr)) { - for(uword uslice = 0; uslice < n_slices; ++uslice) + for(uword s=0; s < n_slices; ++s) { - if(mat_ptrs[uslice] != NULL) { delete access::rw(mat_ptrs[uslice]); } + raw_mat_ptr_type mat_ptr = raw_mat_ptr_type(mat_ptrs[s]); // explicit cast to indicate load from std::atomic*> + + if(mat_ptr != nullptr) + { + arma_extra_debug_print( arma_str::format("Cube::delete_mat(): destroying matrix %u") % s ); + delete mat_ptr; + mat_ptrs[s] = nullptr; + } } if( (mem_state <= 2) && (n_slices > Cube_prealloc::mat_ptrs_size) ) { + arma_extra_debug_print("Cube::delete_mat(): freeing mat_ptrs array"); delete [] mat_ptrs; + mat_ptrs = nullptr; } } } @@ -450,31 +557,110 @@ Cube::create_mat() { arma_extra_debug_sigprint(); - if(n_slices == 0) + if(n_slices == 0) { mat_ptrs = nullptr; return; } + + if(mem_state <= 2) { - access::rw(mat_ptrs) = NULL; + if(n_slices <= Cube_prealloc::mat_ptrs_size) + { + arma_extra_debug_print("Cube::create_mat(): using local memory for mat_ptrs array"); + + mat_ptrs = mat_ptrs_local; + } + else + { + arma_extra_debug_print("Cube::create_mat(): allocating mat_ptrs array"); + + mat_ptrs = new(std::nothrow) atomic_mat_ptr_type[n_slices]; + + arma_check_bad_alloc( (mat_ptrs == nullptr), "Cube::create_mat(): out of memory" ); + } } - else + + for(uword s=0; s < n_slices; ++s) { mat_ptrs[s] = nullptr; } + } + + + +template +inline +Mat* +Cube::create_mat_ptr(const uword in_slice) const + { + arma_extra_debug_sigprint(); + + arma_extra_debug_print( arma_str::format("Cube::create_mat_ptr(): creating matrix %u") % in_slice ); + + const eT* mat_mem = (n_elem_slice > 0) ? slice_memptr(in_slice) : nullptr; + + Mat* mat_ptr = new(std::nothrow) Mat('j', mat_mem, n_rows, n_cols); + + return mat_ptr; + } + + + +template +inline +Mat* +Cube::get_mat_ptr(const uword in_slice) const + { + arma_extra_debug_sigprint(); + + raw_mat_ptr_type mat_ptr = nullptr; + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp atomic read + mat_ptr = mat_ptrs[in_slice]; + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + mat_ptr = mat_ptrs[in_slice].load(); + } + #else + { + mat_ptr = mat_ptrs[in_slice]; + } + #endif + + if(mat_ptr == nullptr) { - if(mem_state <= 2) + #if defined(ARMA_USE_OPENMP) { - if(n_slices <= Cube_prealloc::mat_ptrs_size) - { - access::rw(mat_ptrs) = const_cast< const Mat** >(mat_ptrs_local); - } - else + #pragma omp critical (arma_Cube_mat_ptrs) { - access::rw(mat_ptrs) = new(std::nothrow) const Mat*[n_slices]; + #pragma omp atomic read + mat_ptr = mat_ptrs[in_slice]; - arma_check_bad_alloc( (mat_ptrs == 0), "Cube::create_mat(): out of memory" ); + if(mat_ptr == nullptr) { mat_ptr = create_mat_ptr(in_slice); } + + #pragma omp atomic write + mat_ptrs[in_slice] = mat_ptr; } } - - for(uword uslice = 0; uslice < n_slices; ++uslice) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(mat_mutex); + + mat_ptr = mat_ptrs[in_slice].load(); + + if(mat_ptr == nullptr) { mat_ptr = create_mat_ptr(in_slice); } + + mat_ptrs[in_slice].store(mat_ptr); + } + #else { - mat_ptrs[uslice] = NULL; + mat_ptr = create_mat_ptr(in_slice); + + mat_ptrs[in_slice] = mat_ptr; } + #endif + + arma_check_bad_alloc( (mat_ptr == nullptr), "Cube::get_mat_ptr(): out of memory" ); } + + return mat_ptr; } @@ -489,7 +675,9 @@ Cube::operator=(const eT val) arma_extra_debug_sigprint(); init_warm(1,1,1); + access::rw(mem[0]) = val; + return *this; } @@ -564,9 +752,9 @@ Cube::Cube(const Cube& x) , n_elem_slice(x.n_elem_slice) , n_slices(x.n_slices) , n_elem(x.n_elem) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); arma_extra_debug_sigprint(arma_str::format("this = %x in_cube = %x") % this % &x); @@ -609,15 +797,15 @@ Cube::Cube(eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, cons , n_elem_slice( aux_n_rows*aux_n_cols ) , n_slices ( aux_n_slices ) , n_elem ( aux_n_rows*aux_n_cols*aux_n_slices ) + , n_alloc ( 0 ) , mem_state ( copy_aux_mem ? 0 : (strict ? 2 : 1) ) - , mem ( copy_aux_mem ? 0 : aux_mem ) - , mat_ptrs ( 0 ) + , mem ( copy_aux_mem ? nullptr : aux_mem ) { arma_extra_debug_sigprint_this(this); - if(prealloc_mat == true) { arma_debug_warn("Cube::Cube(): parameter 'prealloc_mat' ignored as it's no longer used"); } + arma_ignore(prealloc_mat); // kept only for compatibility with old user code - if(copy_aux_mem == true) + if(copy_aux_mem) { init_cold(); @@ -641,9 +829,9 @@ Cube::Cube(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols , n_elem_slice(aux_n_rows*aux_n_cols) , n_slices(aux_n_slices) , n_elem(aux_n_rows*aux_n_cols*aux_n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -736,9 +924,9 @@ Cube::Cube , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -747,7 +935,7 @@ Cube::Cube -//! construct a cube from a subview_cube instance (e.g. construct a cube from a delayed subcube operation) +//! construct a cube from a subview_cube instance (eg. construct a cube from a delayed subcube operation) template inline Cube::Cube(const subview_cube& X) @@ -756,9 +944,9 @@ Cube::Cube(const subview_cube& X) , n_elem_slice(X.n_elem_slice) , n_slices(X.n_slices) , n_elem(X.n_elem) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -769,7 +957,7 @@ Cube::Cube(const subview_cube& X) -//! construct a cube from a subview_cube instance (e.g. construct a cube from a delayed subcube operation) +//! construct a cube from a subview_cube instance (eg. construct a cube from a delayed subcube operation) template inline Cube& @@ -866,9 +1054,9 @@ Cube::Cube(const subview_cube_slices& X) , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -971,7 +1159,7 @@ Cube::row(const uword in_row) { arma_extra_debug_sigprint(); - arma_debug_check( (in_row >= n_rows), "Cube::row(): index out of bounds" ); + arma_debug_check_bounds( (in_row >= n_rows), "Cube::row(): index out of bounds" ); return (*this).rows(in_row, in_row); } @@ -986,7 +1174,7 @@ Cube::row(const uword in_row) const { arma_extra_debug_sigprint(); - arma_debug_check( (in_row >= n_rows), "Cube::row(): index out of bounds" ); + arma_debug_check_bounds( (in_row >= n_rows), "Cube::row(): index out of bounds" ); return (*this).rows(in_row, in_row); } @@ -1001,7 +1189,7 @@ Cube::col(const uword in_col) { arma_extra_debug_sigprint(); - arma_debug_check( (in_col >= n_cols), "Cube::col(): index out of bounds" ); + arma_debug_check_bounds( (in_col >= n_cols), "Cube::col(): index out of bounds" ); return (*this).cols(in_col, in_col); } @@ -1016,7 +1204,7 @@ Cube::col(const uword in_col) const { arma_extra_debug_sigprint(); - arma_debug_check( (in_col >= n_cols), "Cube::col(): index out of bounds" ); + arma_debug_check_bounds( (in_col >= n_cols), "Cube::col(): index out of bounds" ); return (*this).cols(in_col, in_col); } @@ -1031,16 +1219,9 @@ Cube::slice(const uword in_slice) { arma_extra_debug_sigprint(); - arma_debug_check( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); - - if(mat_ptrs[in_slice] == NULL) - { - const eT* ptr = (n_elem_slice > 0) ? slice_memptr(in_slice) : NULL; - - mat_ptrs[in_slice] = new Mat('j', ptr, n_rows, n_cols); - } + arma_debug_check_bounds( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); - return const_cast< Mat& >( *(mat_ptrs[in_slice]) ); + return *(get_mat_ptr(in_slice)); } @@ -1053,16 +1234,9 @@ Cube::slice(const uword in_slice) const { arma_extra_debug_sigprint(); - arma_debug_check( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); - - if(mat_ptrs[in_slice] == NULL) - { - const eT* ptr = (n_elem_slice > 0) ? slice_memptr(in_slice) : NULL; - - mat_ptrs[in_slice] = new Mat('j', ptr, n_rows, n_cols); - } + arma_debug_check_bounds( (in_slice >= n_slices), "Cube::slice(): index out of bounds" ); - return *(mat_ptrs[in_slice]); + return *(get_mat_ptr(in_slice)); } @@ -1075,7 +1249,7 @@ Cube::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Cube::rows(): indices out of bounds or incorrectly used" @@ -1096,7 +1270,7 @@ Cube::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Cube::rows(): indices out of bounds or incorrectly used" @@ -1117,7 +1291,7 @@ Cube::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Cube::cols(): indices out of bounds or incorrectly used" @@ -1138,7 +1312,7 @@ Cube::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Cube::cols(): indices out of bounds or incorrectly used" @@ -1159,7 +1333,7 @@ Cube::slices(const uword in_slice1, const uword in_slice2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices), "Cube::slices(): indices out of bounds or incorrectly used" @@ -1180,7 +1354,7 @@ Cube::slices(const uword in_slice1, const uword in_slice2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices), "Cube::slices(): indices out of bounds or incorrectly used" @@ -1201,7 +1375,7 @@ Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), @@ -1225,7 +1399,7 @@ Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), @@ -1257,7 +1431,7 @@ Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice const uword s_n_cols = s.n_cols; const uword s_n_slices = s.n_slices; - arma_debug_check + arma_debug_check_bounds ( ( in_row1 >= l_n_rows) || ( in_col1 >= l_n_cols) || ( in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + s_n_slices) > l_n_slices), @@ -1285,7 +1459,7 @@ Cube::subcube(const uword in_row1, const uword in_col1, const uword in_slice const uword s_n_cols = s.n_cols; const uword s_n_slices = s.n_slices; - arma_debug_check + arma_debug_check_bounds ( ( in_row1 >= l_n_rows) || ( in_col1 >= l_n_cols) || ( in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + s_n_slices) > l_n_slices), @@ -1325,7 +1499,7 @@ Cube::subcube(const span& row_span, const span& col_span, const span& slice_ const uword in_slice2 = slice_span.b; const uword subcube_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1369,7 +1543,7 @@ Cube::subcube(const span& row_span, const span& col_span, const span& slice_ const uword in_slice2 = slice_span.b; const uword subcube_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1440,7 +1614,7 @@ Cube::tube(const uword in_row1, const uword in_col1) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= n_rows) || (in_col1 >= n_cols)), "Cube::tube(): indices out of bounds" @@ -1458,7 +1632,7 @@ Cube::tube(const uword in_row1, const uword in_col1) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= n_rows) || (in_col1 >= n_cols)), "Cube::tube(): indices out of bounds" @@ -1476,7 +1650,7 @@ Cube::tube(const uword in_row1, const uword in_col1, const uword in_row2, co { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), @@ -1498,7 +1672,7 @@ Cube::tube(const uword in_row1, const uword in_col1, const uword in_row2, co { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), @@ -1526,7 +1700,7 @@ Cube::tube(const uword in_row1, const uword in_col1, const SizeMat& s) const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "Cube::tube(): indices or size out of bounds" @@ -1550,7 +1724,7 @@ Cube::tube(const uword in_row1, const uword in_col1, const SizeMat& s) const const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "Cube::tube(): indices or size out of bounds" @@ -1582,7 +1756,7 @@ Cube::tube(const span& row_span, const span& col_span) const uword in_col2 = col_span.b; const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1617,7 +1791,7 @@ Cube::tube(const span& row_span, const span& col_span) const const uword in_col2 = col_span.b; const uword subcube_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1638,7 +1812,7 @@ Cube::head_slices(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_slices), "Cube::head_slices(): size out of bounds" ); + arma_debug_check_bounds( (N > n_slices), "Cube::head_slices(): size out of bounds" ); return subview_cube(*this, 0, 0, 0, n_rows, n_cols, N); } @@ -1652,7 +1826,7 @@ Cube::head_slices(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_slices), "Cube::head_slices(): size out of bounds" ); + arma_debug_check_bounds( (N > n_slices), "Cube::head_slices(): size out of bounds" ); return subview_cube(*this, 0, 0, 0, n_rows, n_cols, N); } @@ -1666,7 +1840,7 @@ Cube::tail_slices(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); + arma_debug_check_bounds( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); const uword start_slice = n_slices - N; @@ -1682,7 +1856,7 @@ Cube::tail_slices(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); + arma_debug_check_bounds( (N > n_slices), "Cube::tail_slices(): size out of bounds" ); const uword start_slice = n_slices - N; @@ -1793,109 +1967,106 @@ Cube::each_slice(const Base& indices) const -#if defined(ARMA_USE_CXX11) - - //! apply a lambda function to each slice, where each slice is interpreted as a matrix - template - inline - const Cube& - Cube::each_slice(const std::function< void(Mat&) >& F) - { - arma_extra_debug_sigprint(); - - for(uword slice_id=0; slice_id < n_slices; ++slice_id) - { - Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - - F(tmp); - } +//! apply a lambda function to each slice, where each slice is interpreted as a matrix +template +inline +Cube& +Cube::each_slice(const std::function< void(Mat&) >& F) + { + arma_extra_debug_sigprint(); + + for(uword slice_id=0; slice_id < n_slices; ++slice_id) + { + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - return *this; + F(tmp); } + return *this; + } + + + +template +inline +const Cube& +Cube::each_slice(const std::function< void(const Mat&) >& F) const + { + arma_extra_debug_sigprint(); - - template - inline - const Cube& - Cube::each_slice(const std::function< void(const Mat&) >& F) const + for(uword slice_id=0; slice_id < n_slices; ++slice_id) { - arma_extra_debug_sigprint(); + const Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - for(uword slice_id=0; slice_id < n_slices; ++slice_id) - { - const Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - - F(tmp); - } - - return *this; + F(tmp); } + return *this; + } + + + +template +inline +Cube& +Cube::each_slice(const std::function< void(Mat&) >& F, const bool use_mp) + { + arma_extra_debug_sigprint(); + if((use_mp == false) || (arma_config::openmp == false)) + { + return (*this).each_slice(F); + } - template - inline - const Cube& - Cube::each_slice(const std::function< void(Mat&) >& F, const bool use_mp) + #if defined(ARMA_USE_OPENMP) { - arma_extra_debug_sigprint(); + const uword local_n_slices = n_slices; + const int n_threads = mp_thread_limit::get(); - if((use_mp == false) || (arma_config::openmp == false)) + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) { - return (*this).each_slice(F); - } - - #if defined(ARMA_USE_OPENMP) - { - const uword local_n_slices = n_slices; - const int n_threads = mp_thread_limit::get(); + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) - { - Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - - F(tmp); - } + F(tmp); } - #endif - - return *this; } + #endif + return *this; + } + + + +template +inline +const Cube& +Cube::each_slice(const std::function< void(const Mat&) >& F, const bool use_mp) const + { + arma_extra_debug_sigprint(); + if((use_mp == false) || (arma_config::openmp == false)) + { + return (*this).each_slice(F); + } - template - inline - const Cube& - Cube::each_slice(const std::function< void(const Mat&) >& F, const bool use_mp) const + #if defined(ARMA_USE_OPENMP) { - arma_extra_debug_sigprint(); - - if((use_mp == false) || (arma_config::openmp == false)) - { - return (*this).each_slice(F); - } + const uword local_n_slices = n_slices; + const int n_threads = mp_thread_limit::get(); - #if defined(ARMA_USE_OPENMP) + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) { - const uword local_n_slices = n_slices; - const int n_threads = mp_thread_limit::get(); + Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword slice_id=0; slice_id < local_n_slices; ++slice_id) - { - Mat tmp('j', slice_memptr(slice_id), n_rows, n_cols); - - F(tmp); - } + F(tmp); } - #endif - - return *this; } -#endif + #endif + + return *this; + } @@ -1933,7 +2104,7 @@ Cube::shed_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "Cube::shed_row(): index out of bounds"); + arma_debug_check_bounds( row_num >= n_rows, "Cube::shed_row(): index out of bounds" ); shed_rows(row_num, row_num); } @@ -1948,7 +2119,7 @@ Cube::shed_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Cube::shed_col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Cube::shed_col(): index out of bounds" ); shed_cols(col_num, col_num); } @@ -1963,7 +2134,7 @@ Cube::shed_slice(const uword slice_num) { arma_extra_debug_sigprint(); - arma_debug_check( slice_num >= n_slices, "Cube::shed_slice(): index out of bounds"); + arma_debug_check_bounds( slice_num >= n_slices, "Cube::shed_slice(): index out of bounds" ); shed_slices(slice_num, slice_num); } @@ -1978,7 +2149,7 @@ Cube::shed_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Cube::shed_rows(): indices out of bounds or incorrectly used" @@ -1987,7 +2158,7 @@ Cube::shed_rows(const uword in_row1, const uword in_row2) const uword n_keep_front = in_row1; const uword n_keep_back = n_rows - (in_row2 + 1); - Cube X(n_keep_front + n_keep_back, n_cols, n_slices); + Cube X(n_keep_front + n_keep_back, n_cols, n_slices, arma_nozeros_indicator()); if(n_keep_front > 0) { @@ -2012,7 +2183,7 @@ Cube::shed_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Cube::shed_cols(): indices out of bounds or incorrectly used" @@ -2021,7 +2192,7 @@ Cube::shed_cols(const uword in_col1, const uword in_col2) const uword n_keep_front = in_col1; const uword n_keep_back = n_cols - (in_col2 + 1); - Cube X(n_rows, n_keep_front + n_keep_back, n_slices); + Cube X(n_rows, n_keep_front + n_keep_back, n_slices, arma_nozeros_indicator()); if(n_keep_front > 0) { @@ -2046,7 +2217,7 @@ Cube::shed_slices(const uword in_slice1, const uword in_slice2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices), "Cube::shed_slices(): indices out of bounds or incorrectly used" @@ -2055,7 +2226,7 @@ Cube::shed_slices(const uword in_slice1, const uword in_slice2) const uword n_keep_front = in_slice1; const uword n_keep_back = n_slices - (in_slice2 + 1); - Cube X(n_rows, n_cols, n_keep_front + n_keep_back); + Cube X(n_rows, n_cols, n_keep_front + n_keep_back, arma_nozeros_indicator()); if(n_keep_front > 0) { @@ -2101,11 +2272,11 @@ Cube::shed_slices(const Base& indices) { for(uword i=0; i= n_slices), "Cube::shed_slices(): indices out of bounds" ); + arma_debug_check_bounds( (slices_to_shed_mem[i] >= n_slices), "Cube::shed_slices(): indices out of bounds" ); } } - Col tmp3(n_slices); + Col tmp3(n_slices, arma_nozeros_indicator()); uword* tmp3_mem = tmp3.memptr(); @@ -2149,35 +2320,45 @@ Cube::insert_rows(const uword row_num, const uword N, const bool set_to_zero { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Cube::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_rows = n_rows; const uword A_n_rows = row_num; const uword B_n_rows = t_n_rows - row_num; // insertion at row_num == n_rows is in effect an append operation - arma_debug_check( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds"); + arma_debug_check_bounds( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } - if(N > 0) + Cube out(t_n_rows + N, n_cols, n_slices, arma_nozeros_indicator()); + + if(A_n_rows > 0) { - Cube out(t_n_rows + N, n_cols, n_slices); - - if(A_n_rows > 0) - { - out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); - } - - if(B_n_rows > 0) - { - out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); - } - - if(set_to_zero == true) - { - out.rows(row_num, row_num + N - 1).zeros(); - } - - steal_mem(out); + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); + } + + out.rows(row_num, row_num + N - 1).zeros(); + + steal_mem(out); } @@ -2189,41 +2370,49 @@ Cube::insert_cols(const uword col_num, const uword N, const bool set_to_zero { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Cube::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_cols = n_cols; const uword A_n_cols = col_num; const uword B_n_cols = t_n_cols - col_num; // insertion at col_num == n_cols is in effect an append operation - arma_debug_check( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds"); + arma_debug_check_bounds( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } + + Cube out(n_rows, t_n_cols + N, n_slices, arma_nozeros_indicator()); - if(N > 0) + if(A_n_cols > 0) { - Cube out(n_rows, t_n_cols + N, n_slices); - - if(A_n_cols > 0) - { - out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); - } - - if(B_n_cols > 0) - { - out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); - } - - if(set_to_zero == true) - { - out.cols(col_num, col_num + N - 1).zeros(); - } - - steal_mem(out); + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); } + + out.cols(col_num, col_num + N - 1).zeros(); + + steal_mem(out); } -//! insert N slices at the specified slice position, -//! optionally setting the elements of the inserted slices to zero template inline void @@ -2231,40 +2420,50 @@ Cube::insert_slices(const uword slice_num, const uword N, const bool set_to_ { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_slices(slice_num, N); + } + + + +template +inline +void +Cube::insert_slices(const uword slice_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_slices = n_slices; const uword A_n_slices = slice_num; const uword B_n_slices = t_n_slices - slice_num; // insertion at slice_num == n_slices is in effect an append operation - arma_debug_check( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds"); + arma_debug_check_bounds( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds" ); + + if(N == 0) { return; } - if(N > 0) + Cube out(n_rows, n_cols, t_n_slices + N, arma_nozeros_indicator()); + + if(A_n_slices > 0) { - Cube out(n_rows, n_cols, t_n_slices + N); - - if(A_n_slices > 0) - { - out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); - } - - if(B_n_slices > 0) - { - out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices-1); - } - - if(set_to_zero == true) - { - //out.slices(slice_num, slice_num + N - 1).zeros(); - - for(uword i=slice_num; i < (slice_num + N); ++i) - { - arrayops::fill_zeros(out.slice_memptr(i), out.n_elem_slice); - } - } - - steal_mem(out); + out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); + } + + if(B_n_slices > 0) + { + out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices-1); + } + + //out.slices(slice_num, slice_num + N - 1).zeros(); + + for(uword i=slice_num; i < (slice_num + N); ++i) + { + arrayops::fill_zeros(out.slice_memptr(i), out.n_elem_slice); } + + steal_mem(out); } @@ -2288,7 +2487,7 @@ Cube::insert_rows(const uword row_num, const BaseCube& X) const uword B_n_rows = t_n_rows - row_num; // insertion at row_num == n_rows is in effect an append operation - arma_debug_check( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds"); + arma_debug_check_bounds( (row_num > t_n_rows), "Cube::insert_rows(): index out of bounds" ); arma_debug_check ( @@ -2296,24 +2495,23 @@ Cube::insert_rows(const uword row_num, const BaseCube& X) "Cube::insert_rows(): given object has incompatible dimensions" ); - if(N > 0) + if(N == 0) { return; } + + Cube out(t_n_rows + N, n_cols, n_slices, arma_nozeros_indicator()); + + if(A_n_rows > 0) { - Cube out(t_n_rows + N, n_cols, n_slices); - - if(A_n_rows > 0) - { - out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); - } - - if(B_n_rows > 0) - { - out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows - 1); - } - - out.rows(row_num, row_num + N - 1) = C; - - steal_mem(out); + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows - 1); } + + out.rows(row_num, row_num + N - 1) = C; + + steal_mem(out); } @@ -2337,7 +2535,7 @@ Cube::insert_cols(const uword col_num, const BaseCube& X) const uword B_n_cols = t_n_cols - col_num; // insertion at col_num == n_cols is in effect an append operation - arma_debug_check( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds"); + arma_debug_check_bounds( (col_num > t_n_cols), "Cube::insert_cols(): index out of bounds" ); arma_debug_check ( @@ -2345,24 +2543,23 @@ Cube::insert_cols(const uword col_num, const BaseCube& X) "Cube::insert_cols(): given object has incompatible dimensions" ); - if(N > 0) + if(N == 0) { return; } + + Cube out(n_rows, t_n_cols + N, n_slices, arma_nozeros_indicator()); + + if(A_n_cols > 0) { - Cube out(n_rows, t_n_cols + N, n_slices); - - if(A_n_cols > 0) - { - out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); - } - - if(B_n_cols > 0) - { - out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols - 1); - } - - out.cols(col_num, col_num + N - 1) = C; - - steal_mem(out); + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); + } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols - 1); } + + out.cols(col_num, col_num + N - 1) = C; + + steal_mem(out); } @@ -2388,7 +2585,7 @@ Cube::insert_slices(const uword slice_num, const BaseCube& X) const uword B_n_slices = t_n_slices - slice_num; // insertion at slice_num == n_slices is in effect an append operation - arma_debug_check( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds"); + arma_debug_check_bounds( (slice_num > t_n_slices), "Cube::insert_slices(): index out of bounds" ); arma_debug_check ( @@ -2396,29 +2593,45 @@ Cube::insert_slices(const uword slice_num, const BaseCube& X) "Cube::insert_slices(): given object has incompatible dimensions" ); - if(N > 0) + if(N == 0) { return; } + + Cube out(n_rows, n_cols, t_n_slices + N, arma_nozeros_indicator()); + + if(A_n_slices > 0) { - Cube out(n_rows, n_cols, t_n_slices + N); - - if(A_n_slices > 0) - { - out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); - } - - if(B_n_slices > 0) - { - out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices - 1); - } - - out.slices(slice_num, slice_num + N - 1) = C; - - steal_mem(out); + out.slices(0, A_n_slices-1) = slices(0, A_n_slices-1); } + + if(B_n_slices > 0) + { + out.slices(slice_num + N, t_n_slices + N - 1) = slices(slice_num, t_n_slices - 1); + } + + out.slices(slice_num, slice_num + N - 1) = C; + + steal_mem(out); + } + + + +template +template +inline +void +Cube::insert_slices(const uword slice_num, const Base& X) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(X.get_ref()); + + const Cube C(const_cast(U.M.memptr()), U.M.n_rows, U.M.n_cols, uword(1), false, true); + + (*this).insert_slices(slice_num, C); } -//! create a cube from GenCube, i.e. run the previously delayed element generation operations +//! create a cube from GenCube, ie. run the previously delayed element generation operations template template inline @@ -2428,9 +2641,9 @@ Cube::Cube(const GenCube& X) , n_elem_slice(X.n_rows*X.n_cols) , n_slices(X.n_slices) , n_elem(X.n_rows*X.n_cols*X.n_slices) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -2518,7 +2731,7 @@ Cube::operator/=(const GenCube& X) -//! create a cube from OpCube, i.e. run the previously delayed unary operations +//! create a cube from OpCube, ie. run the previously delayed unary operations template template inline @@ -2528,9 +2741,9 @@ Cube::Cube(const OpCube& X) , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -2541,7 +2754,7 @@ Cube::Cube(const OpCube& X) -//! create a cube from OpCube, i.e. run the previously delayed unary operations +//! create a cube from OpCube, ie. run the previously delayed unary operations template template inline @@ -2631,7 +2844,7 @@ Cube::operator/=(const OpCube& X) -//! create a cube from eOpCube, i.e. run the previously delayed unary operations +//! create a cube from eOpCube, ie. run the previously delayed unary operations template template inline @@ -2641,9 +2854,9 @@ Cube::Cube(const eOpCube& X) , n_elem_slice(X.get_n_elem_slice()) , n_slices(X.get_n_slices()) , n_elem(X.get_n_elem()) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -2656,7 +2869,7 @@ Cube::Cube(const eOpCube& X) -//! create a cube from eOpCube, i.e. run the previously delayed unary operations +//! create a cube from eOpCube, ie. run the previously delayed unary operations template template inline @@ -2669,18 +2882,11 @@ Cube::operator=(const eOpCube& X) const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); - if(bad_alias == false) - { - init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); - - eop_type::apply(*this, X); - } - else - { - Cube tmp(X); - - steal_mem(tmp); - } + if(bad_alias) { Cube tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); + + eop_type::apply(*this, X); return *this; } @@ -2698,6 +2904,10 @@ Cube::operator+=(const eOpCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator+=(tmp); } + eop_type::apply_inplace_plus(*this, X); return *this; @@ -2716,6 +2926,10 @@ Cube::operator-=(const eOpCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator-=(tmp); } + eop_type::apply_inplace_minus(*this, X); return *this; @@ -2733,7 +2947,11 @@ Cube::operator%=(const eOpCube& X) arma_extra_debug_sigprint(); arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); - + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator%=(tmp); } + eop_type::apply_inplace_schur(*this, X); return *this; @@ -2751,7 +2969,11 @@ Cube::operator/=(const eOpCube& X) arma_extra_debug_sigprint(); arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); - + + const bool bad_alias = ( X.P.has_subview && X.P.is_alias(*this) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator/=(tmp); } + eop_type::apply_inplace_div(*this, X); return *this; @@ -2768,9 +2990,9 @@ Cube::Cube(const mtOpCube& X) , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -2854,7 +3076,7 @@ Cube::operator/=(const mtOpCube& X) -//! create a cube from GlueCube, i.e. run the previously delayed binary operations +//! create a cube from GlueCube, ie. run the previously delayed binary operations template template inline @@ -2864,17 +3086,18 @@ Cube::Cube(const GlueCube& X) , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); + this->operator=(X); } -//! create a cube from GlueCube, i.e. run the previously delayed binary operations +//! create a cube from GlueCube, ie. run the previously delayed binary operations template template inline @@ -2968,7 +3191,7 @@ Cube::operator/=(const GlueCube& X) -//! create a cube from eGlueCube, i.e. run the previously delayed binary operations +//! create a cube from eGlueCube, ie. run the previously delayed binary operations template template inline @@ -2978,9 +3201,9 @@ Cube::Cube(const eGlueCube& X) , n_elem_slice(X.get_n_elem_slice()) , n_slices(X.get_n_slices()) , n_elem(X.get_n_elem()) + , n_alloc() , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -2994,7 +3217,7 @@ Cube::Cube(const eGlueCube& X) -//! create a cube from eGlueCube, i.e. run the previously delayed binary operations +//! create a cube from eGlueCube, ie. run the previously delayed binary operations template template inline @@ -3008,18 +3231,11 @@ Cube::operator=(const eGlueCube& X) const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); - if(bad_alias == false) - { - init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); - - eglue_type::apply(*this, X); - } - else - { - Cube tmp(X); - - steal_mem(tmp); - } + if(bad_alias) { Cube tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols(), X.get_n_slices()); + + eglue_type::apply(*this, X); return *this; } @@ -3038,6 +3254,10 @@ Cube::operator+=(const eGlueCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator+=(tmp); } + eglue_type::apply_inplace_plus(*this, X); return *this; @@ -3057,6 +3277,10 @@ Cube::operator-=(const eGlueCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator-=(tmp); } + eglue_type::apply_inplace_minus(*this, X); return *this; @@ -3076,6 +3300,10 @@ Cube::operator%=(const eGlueCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator%=(tmp); } + eglue_type::apply_inplace_schur(*this, X); return *this; @@ -3095,6 +3323,10 @@ Cube::operator/=(const eGlueCube& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = ( (X.P1.has_subview && X.P1.is_alias(*this)) || (X.P2.has_subview && X.P2.is_alias(*this)) ); + + if(bad_alias) { const Cube tmp(X); return (*this).operator/=(tmp); } + eglue_type::apply_inplace_div(*this, X); return *this; @@ -3111,9 +3343,9 @@ Cube::Cube(const mtGlueCube& X) , n_elem_slice(0) , n_slices(0) , n_elem(0) + , n_alloc(0) , mem_state(0) , mem() - , mat_ptrs(0) { arma_extra_debug_sigprint_this(this); @@ -3200,11 +3432,11 @@ Cube::operator/=(const mtGlueCube& X) //! linear element accessor (treats the cube as a vector); no bounds check; assumes memory is aligned template arma_inline -arma_warn_unused const eT& Cube::at_alt(const uword i) const { const eT* mem_aligned = mem; + memory::mark_as_aligned(mem_aligned); return mem_aligned[i]; @@ -3215,11 +3447,11 @@ Cube::at_alt(const uword i) const //! linear element accessor (treats the cube as a vector); bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused eT& Cube::operator() (const uword i) { - arma_debug_check( (i >= n_elem), "Cube::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "Cube::operator(): index out of bounds" ); + return access::rw(mem[i]); } @@ -3228,11 +3460,11 @@ Cube::operator() (const uword i) //! linear element accessor (treats the cube as a vector); bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused const eT& Cube::operator() (const uword i) const { - arma_debug_check( (i >= n_elem), "Cube::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "Cube::operator(): index out of bounds" ); + return mem[i]; } @@ -3240,7 +3472,6 @@ Cube::operator() (const uword i) const //! linear element accessor (treats the cube as a vector); no bounds check. template arma_inline -arma_warn_unused eT& Cube::operator[] (const uword i) { @@ -3252,7 +3483,6 @@ Cube::operator[] (const uword i) //! linear element accessor (treats the cube as a vector); no bounds check template arma_inline -arma_warn_unused const eT& Cube::operator[] (const uword i) const { @@ -3264,7 +3494,6 @@ Cube::operator[] (const uword i) const //! linear element accessor (treats the cube as a vector); no bounds check. template arma_inline -arma_warn_unused eT& Cube::at(const uword i) { @@ -3276,7 +3505,6 @@ Cube::at(const uword i) //! linear element accessor (treats the cube as a vector); no bounds check template arma_inline -arma_warn_unused const eT& Cube::at(const uword i) const { @@ -3288,11 +3516,10 @@ Cube::at(const uword i) const //! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused eT& Cube::operator() (const uword in_row, const uword in_col, const uword in_slice) { - arma_debug_check + arma_debug_check_bounds ( (in_row >= n_rows) || (in_col >= n_cols) || @@ -3309,11 +3536,10 @@ Cube::operator() (const uword in_row, const uword in_col, const uword in_sli //! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused const eT& Cube::operator() (const uword in_row, const uword in_col, const uword in_slice) const { - arma_debug_check + arma_debug_check_bounds ( (in_row >= n_rows) || (in_col >= n_cols) || @@ -3327,10 +3553,35 @@ Cube::operator() (const uword in_row, const uword in_col, const uword in_sli +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + eT& + Cube::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + return access::rw( mem[in_slice*n_elem_slice + in_col*n_rows + in_row] ); + } + + + + //! element accessor; no bounds check + template + arma_inline + const eT& + Cube::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + return mem[in_slice*n_elem_slice + in_col*n_rows + in_row]; + } + +#endif + + + //! element accessor; no bounds check template arma_inline -arma_warn_unused eT& Cube::at(const uword in_row, const uword in_col, const uword in_slice) { @@ -3342,7 +3593,6 @@ Cube::at(const uword in_row, const uword in_col, const uword in_slice) //! element accessor; no bounds check template arma_inline -arma_warn_unused const eT& Cube::at(const uword in_row, const uword in_col, const uword in_slice) const { @@ -3358,6 +3608,7 @@ const Cube& Cube::operator++() { Cube_aux::prefix_pp(*this); + return *this; } @@ -3397,54 +3648,61 @@ Cube::operator--(int) -//! returns true if all of the elements are finite +//! returns true if the cube has no elements +template +arma_inline +bool +Cube::is_empty() const + { + return (n_elem == 0); + } + + + template inline -arma_warn_unused bool -Cube::is_finite() const +Cube::internal_is_finite() const { arma_extra_debug_sigprint(); - return arrayops::is_finite( memptr(), n_elem ); + return arrayops::is_finite(memptr(), n_elem); } -//! returns true if the cube has no elements template -arma_inline -arma_warn_unused +inline bool -Cube::is_empty() const +Cube::internal_has_inf() const { - return (n_elem == 0); + arma_extra_debug_sigprint(); + + return arrayops::has_inf(memptr(), n_elem); } template inline -arma_warn_unused bool -Cube::has_inf() const +Cube::internal_has_nan() const { arma_extra_debug_sigprint(); - return arrayops::has_inf( memptr(), n_elem ); + return arrayops::has_nan(memptr(), n_elem); } template inline -arma_warn_unused bool -Cube::has_nan() const +Cube::internal_has_nonfinite() const { arma_extra_debug_sigprint(); - return arrayops::has_nan( memptr(), n_elem ); + return (arrayops::is_finite(memptr(), n_elem) == false); } @@ -3452,7 +3710,6 @@ Cube::has_nan() const //! returns true if the given index is currently in range template arma_inline -arma_warn_unused bool Cube::in_range(const uword i) const { @@ -3464,13 +3721,12 @@ Cube::in_range(const uword i) const //! returns true if the given start and end indices are currently in range template arma_inline -arma_warn_unused bool Cube::in_range(const span& x) const { arma_extra_debug_sigprint(); - if(x.whole == true) + if(x.whole) { return true; } @@ -3488,7 +3744,6 @@ Cube::in_range(const span& x) const //! returns true if the given location is currently in range template arma_inline -arma_warn_unused bool Cube::in_range(const uword in_row, const uword in_col, const uword in_slice) const { @@ -3499,7 +3754,6 @@ Cube::in_range(const uword in_row, const uword in_col, const uword in_slice) template inline -arma_warn_unused bool Cube::in_range(const span& row_span, const span& col_span, const span& slice_span) const { @@ -3520,14 +3774,13 @@ Cube::in_range(const span& row_span, const span& col_span, const span& slice const bool slices_ok = slice_span.whole ? true : ( (in_slice1 <= in_slice2) && (in_slice2 < n_slices) ); - return ( (rows_ok == true) && (cols_ok == true) && (slices_ok == true) ); + return ( rows_ok && cols_ok && slices_ok ); } template inline -arma_warn_unused bool Cube::in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const { @@ -3553,7 +3806,6 @@ Cube::in_range(const uword in_row, const uword in_col, const uword in_slice, //! returns a pointer to array of eTs used by the cube template arma_inline -arma_warn_unused eT* Cube::memptr() { @@ -3565,186 +3817,84 @@ Cube::memptr() //! returns a pointer to array of eTs used by the cube template arma_inline -arma_warn_unused -const eT* -Cube::memptr() const - { - return mem; - } - - - -//! returns a pointer to array of eTs used by the specified slice in the cube -template -arma_inline -arma_warn_unused -eT* -Cube::slice_memptr(const uword uslice) - { - return const_cast( &mem[ uslice*n_elem_slice ] ); - } - - - -//! returns a pointer to array of eTs used by the specified slice in the cube -template -arma_inline -arma_warn_unused -const eT* -Cube::slice_memptr(const uword uslice) const - { - return &mem[ uslice*n_elem_slice ]; - } - - - -//! returns a pointer to array of eTs used by the specified slice in the cube -template -arma_inline -arma_warn_unused -eT* -Cube::slice_colptr(const uword uslice, const uword col) - { - return const_cast( &mem[ uslice*n_elem_slice + col*n_rows] ); - } - - - -//! returns a pointer to array of eTs used by the specified slice in the cube -template -arma_inline -arma_warn_unused -const eT* -Cube::slice_colptr(const uword uslice, const uword col) const - { - return &mem[ uslice*n_elem_slice + col*n_rows ]; - } - - - -//! print contents of the cube (to the cout stream), -//! optionally preceding with a user specified line of text. -//! the precision and cell width are modified. -//! on return, the stream's state are restored to their original values. -template -arma_cold -inline -void -Cube::impl_print(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - get_cout_stream() << extra_text << '\n'; - } - - arma_ostream::print(get_cout_stream(), *this, true); +const eT* +Cube::memptr() const + { + return mem; } -//! print contents of the cube to a user specified stream, -//! optionally preceding with a user specified line of text. -//! the precision and cell width are modified. -//! on return, the stream's state are restored to their original values. + +//! returns a pointer to array of eTs used by the specified slice in the cube template -arma_cold -inline -void -Cube::impl_print(std::ostream& user_stream, const std::string& extra_text) const +arma_inline +eT* +Cube::slice_memptr(const uword uslice) { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - user_stream << extra_text << '\n'; - } - - arma_ostream::print(user_stream, *this, true); + return const_cast( &mem[ uslice*n_elem_slice ] ); } -//! print contents of the cube (to the cout stream), -//! optionally preceding with a user specified line of text. -//! the stream's state are used as is and are not modified -//! (i.e. the precision and cell width are not modified). +//! returns a pointer to array of eTs used by the specified slice in the cube template -arma_cold -inline -void -Cube::impl_raw_print(const std::string& extra_text) const +arma_inline +const eT* +Cube::slice_memptr(const uword uslice) const { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - get_cout_stream() << extra_text << '\n'; - } - - arma_ostream::print(get_cout_stream(), *this, false); + return &mem[ uslice*n_elem_slice ]; } -//! print contents of the cube to a user specified stream, -//! optionally preceding with a user specified line of text. -//! the stream's state are used as is and are not modified. -//! (i.e. the precision and cell width are not modified). +//! returns a pointer to array of eTs used by the specified slice in the cube template -arma_cold -inline -void -Cube::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const +arma_inline +eT* +Cube::slice_colptr(const uword uslice, const uword col) { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - user_stream << extra_text << '\n'; - } - - arma_ostream::print(user_stream, *this, false); + return const_cast( &mem[ uslice*n_elem_slice + col*n_rows] ); } -//! change the cube to have user specified dimensions (data is not preserved) +//! returns a pointer to array of eTs used by the specified slice in the cube template -inline -void -Cube::set_size(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) +arma_inline +const eT* +Cube::slice_colptr(const uword uslice, const uword col) const { - arma_extra_debug_sigprint(); - - init_warm(in_n_rows, in_n_cols, in_n_slices); + return &mem[ uslice*n_elem_slice + col*n_rows ]; } -//! change the cube to have user specified dimensions (data is preserved) +//! change the cube to have user specified dimensions (data is not preserved) template inline -void -Cube::reshape(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::set_size(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, in_rows, in_cols, in_slices); + init_warm(new_n_rows, new_n_cols, new_n_slices); + + return *this; } -//! NOTE: don't use this form; it's deprecated and will be removed +//! change the cube to have user specified dimensions (data is preserved) template -arma_deprecated inline -void -Cube::reshape(const uword in_rows, const uword in_cols, const uword in_slices, const uword dim) +Cube& +Cube::reshape(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, in_rows, in_cols, in_slices, dim); + op_reshape::apply_cube_inplace((*this), new_n_rows, new_n_cols, new_n_slices); + + return *this; } @@ -3752,48 +3902,56 @@ Cube::reshape(const uword in_rows, const uword in_cols, const uword in_slice //! change the cube to have user specified dimensions (data is preserved) template inline -void -Cube::resize(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::resize(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - *this = arma::resize(*this, in_rows, in_cols, in_slices); + op_resize::apply_cube_inplace((*this), new_n_rows, new_n_cols, new_n_slices); + + return *this; } template inline -void +Cube& Cube::set_size(const SizeCube& s) { arma_extra_debug_sigprint(); init_warm(s.n_rows, s.n_cols, s.n_slices); + + return *this; } template inline -void +Cube& Cube::reshape(const SizeCube& s) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, s.n_rows, s.n_cols, s.n_slices, 0); + op_reshape::apply_cube_inplace((*this), s.n_rows, s.n_cols, s.n_slices); + + return *this; } template inline -void +Cube& Cube::resize(const SizeCube& s) { arma_extra_debug_sigprint(); - *this = arma::resize(*this, s.n_rows, s.n_cols, s.n_slices); + op_resize::apply_cube_inplace((*this), s.n_rows, s.n_cols, s.n_slices); + + return *this; } @@ -3802,12 +3960,14 @@ Cube::resize(const SizeCube& s) template template inline -void +Cube& Cube::copy_size(const Cube& m) { arma_extra_debug_sigprint(); init_warm(m.n_rows, m.n_cols, m.n_slices); + + return *this; } @@ -3816,7 +3976,7 @@ Cube::copy_size(const Cube& m) template template inline -const Cube& +Cube& Cube::for_each(functor F) { arma_extra_debug_sigprint(); @@ -3877,7 +4037,7 @@ Cube::for_each(functor F) const template template inline -const Cube& +Cube& Cube::transform(functor F) { arma_extra_debug_sigprint(); @@ -3914,7 +4074,7 @@ Cube::transform(functor F) template template inline -const Cube& +Cube& Cube::imbue(functor F) { arma_extra_debug_sigprint(); @@ -3946,7 +4106,7 @@ Cube::imbue(functor F) template inline -const Cube& +Cube& Cube::replace(const eT old_val, const eT new_val) { arma_extra_debug_sigprint(); @@ -3960,7 +4120,7 @@ Cube::replace(const eT old_val, const eT new_val) template inline -const Cube& +Cube& Cube::clean(const typename get_pod_type::result threshold) { arma_extra_debug_sigprint(); @@ -3972,10 +4132,34 @@ Cube::clean(const typename get_pod_type::result threshold) +template +inline +Cube& +Cube::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Cube::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Cube::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "Cube::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + arrayops::clamp(memptr(), n_elem, min_val, max_val); + + return *this; + } + + + //! fill the cube with the specified value template inline -const Cube& +Cube& Cube::fill(const eT val) { arma_extra_debug_sigprint(); @@ -3989,7 +4173,7 @@ Cube::fill(const eT val) template inline -const Cube& +Cube& Cube::zeros() { arma_extra_debug_sigprint(); @@ -4003,12 +4187,12 @@ Cube::zeros() template inline -const Cube& -Cube::zeros(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::zeros(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols, in_slices); + set_size(new_n_rows, new_n_cols, new_n_slices); return (*this).zeros(); } @@ -4017,7 +4201,7 @@ Cube::zeros(const uword in_rows, const uword in_cols, const uword in_slices) template inline -const Cube& +Cube& Cube::zeros(const SizeCube& s) { arma_extra_debug_sigprint(); @@ -4029,7 +4213,7 @@ Cube::zeros(const SizeCube& s) template inline -const Cube& +Cube& Cube::ones() { arma_extra_debug_sigprint(); @@ -4041,12 +4225,12 @@ Cube::ones() template inline -const Cube& -Cube::ones(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::ones(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols, in_slices); + set_size(new_n_rows, new_n_cols, new_n_slices); return (*this).fill(eT(1)); } @@ -4055,7 +4239,7 @@ Cube::ones(const uword in_rows, const uword in_cols, const uword in_slices) template inline -const Cube& +Cube& Cube::ones(const SizeCube& s) { arma_extra_debug_sigprint(); @@ -4067,7 +4251,7 @@ Cube::ones(const SizeCube& s) template inline -const Cube& +Cube& Cube::randu() { arma_extra_debug_sigprint(); @@ -4081,12 +4265,12 @@ Cube::randu() template inline -const Cube& -Cube::randu(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::randu(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols, in_slices); + set_size(new_n_rows, new_n_cols, new_n_slices); return (*this).randu(); } @@ -4095,7 +4279,7 @@ Cube::randu(const uword in_rows, const uword in_cols, const uword in_slices) template inline -const Cube& +Cube& Cube::randu(const SizeCube& s) { arma_extra_debug_sigprint(); @@ -4107,7 +4291,7 @@ Cube::randu(const SizeCube& s) template inline -const Cube& +Cube& Cube::randn() { arma_extra_debug_sigprint(); @@ -4121,12 +4305,12 @@ Cube::randn() template inline -const Cube& -Cube::randn(const uword in_rows, const uword in_cols, const uword in_slices) +Cube& +Cube::randn(const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols, in_slices); + set_size(new_n_rows, new_n_cols, new_n_slices); return (*this).randn(); } @@ -4135,7 +4319,7 @@ Cube::randn(const uword in_rows, const uword in_cols, const uword in_slices) template inline -const Cube& +Cube& Cube::randn(const SizeCube& s) { arma_extra_debug_sigprint(); @@ -4171,7 +4355,7 @@ Cube::soft_reset() } else { - fill(Datum::nan); + zeros(); } } @@ -4205,7 +4389,6 @@ Cube::set_imag(const BaseCube::pod_type,T1>& X) template inline -arma_warn_unused eT Cube::min() const { @@ -4225,7 +4408,6 @@ Cube::min() const template inline -arma_warn_unused eT Cube::max() const { @@ -4358,9 +4540,8 @@ Cube::max(uword& row_of_max_val, uword& col_of_max_val, uword& slice_of_max_ //! save the cube to a file template inline -arma_cold bool -Cube::save(const std::string name, const file_type type, const bool print_status) const +Cube::save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); @@ -4397,11 +4578,11 @@ Cube::save(const std::string name, const file_type type, const bool print_st break; default: - if(print_status) { arma_debug_warn("Cube::save(): unsupported file type"); } + arma_debug_warn_level(1, "Cube::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("Cube::save(): couldn't write to ", name); } + if(save_okay == false) { arma_debug_warn_level(3, "Cube::save(): write failed; file: ", name); } return save_okay; } @@ -4410,9 +4591,8 @@ Cube::save(const std::string name, const file_type type, const bool print_st template inline -arma_cold bool -Cube::save(const hdf5_name& spec, const file_type type, const bool print_status) const +Cube::save(const hdf5_name& spec, const file_type type) const { arma_extra_debug_sigprint(); @@ -4420,7 +4600,7 @@ Cube::save(const hdf5_name& spec, const file_type type, const bool print_sta if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) { - arma_debug_check(true, "Cube::save(): unsupported file type for hdf5_name()"); + arma_stop_runtime_error("Cube::save(): unsupported file type for hdf5_name()"); return false; } @@ -4430,7 +4610,7 @@ Cube::save(const hdf5_name& spec, const file_type type, const bool print_sta if(append && replace) { - arma_debug_check(true, "Cube::save(): only one of 'append' or 'replace' options can be used"); + arma_stop_runtime_error("Cube::save(): only one of 'append' or 'replace' options can be used"); return false; } @@ -4450,15 +4630,15 @@ Cube::save(const hdf5_name& spec, const file_type type, const bool print_sta save_okay = diskio::save_hdf5_binary(*this, spec, err_msg); } - if((print_status == true) && (save_okay == false)) + if(save_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("Cube::save(): ", err_msg, spec.filename); + arma_debug_warn_level(3, "Cube::save(): ", err_msg, "; file: ", spec.filename); } else { - arma_debug_warn("Cube::save(): couldn't write to ", spec.filename); + arma_debug_warn_level(3, "Cube::save(): write failed; file: ", spec.filename); } } @@ -4470,9 +4650,8 @@ Cube::save(const hdf5_name& spec, const file_type type, const bool print_sta //! save the cube to a stream template inline -arma_cold bool -Cube::save(std::ostream& os, const file_type type, const bool print_status) const +Cube::save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); @@ -4501,11 +4680,11 @@ Cube::save(std::ostream& os, const file_type type, const bool print_status) break; default: - if(print_status) { arma_debug_warn("Cube::save(): unsupported file type"); } + arma_debug_warn_level(1, "Cube::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("Cube::save(): couldn't write to given stream"); } + if(save_okay == false) { arma_debug_warn_level(3, "Cube::save(): stream write failed"); } return save_okay; } @@ -4515,9 +4694,8 @@ Cube::save(std::ostream& os, const file_type type, const bool print_status) //! load a cube from a file template inline -arma_cold bool -Cube::load(const std::string name, const file_type type, const bool print_status) +Cube::load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); @@ -4559,27 +4737,24 @@ Cube::load(const std::string name, const file_type type, const bool print_st break; default: - if(print_status) { arma_debug_warn("Cube::load(): unsupported file type"); } + arma_debug_warn_level(1, "Cube::load(): unsupported file type"); load_okay = false; } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { + (*this).soft_reset(); + if(err_msg.length() > 0) { - arma_debug_warn("Cube::load(): ", err_msg, name); + arma_debug_warn_level(3, "Cube::load(): ", err_msg, "; file: ", name); } else { - arma_debug_warn("Cube::load(): couldn't read ", name); + arma_debug_warn_level(3, "Cube::load(): read failed; file: ", name); } } - if(load_okay == false) - { - (*this).soft_reset(); - } - return load_okay; } @@ -4587,16 +4762,14 @@ Cube::load(const std::string name, const file_type type, const bool print_st template inline -arma_cold bool -Cube::load(const hdf5_name& spec, const file_type type, const bool print_status) +Cube::load(const hdf5_name& spec, const file_type type) { arma_extra_debug_sigprint(); if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) { - if(print_status) { arma_debug_warn("Cube::load(): unsupported file type for hdf5_name()"); } - (*this).soft_reset(); + arma_stop_runtime_error("Cube::load(): unsupported file type for hdf5_name()"); return false; } @@ -4619,23 +4792,20 @@ Cube::load(const hdf5_name& spec, const file_type type, const bool print_sta } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { + (*this).soft_reset(); + if(err_msg.length() > 0) { - arma_debug_warn("Cube::load(): ", err_msg, spec.filename); + arma_debug_warn_level(3, "Cube::load(): ", err_msg, "; file: ", spec.filename); } else { - arma_debug_warn("Cube::load(): couldn't read ", spec.filename); + arma_debug_warn_level(3, "Cube::load(): read failed; file: ", spec.filename); } } - if(load_okay == false) - { - (*this).soft_reset(); - } - return load_okay; } @@ -4644,9 +4814,8 @@ Cube::load(const hdf5_name& spec, const file_type type, const bool print_sta //! load a cube from a stream template inline -arma_cold bool -Cube::load(std::istream& is, const file_type type, const bool print_status) +Cube::load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); @@ -4680,110 +4849,97 @@ Cube::load(std::istream& is, const file_type type, const bool print_status) break; default: - if(print_status) { arma_debug_warn("Cube::load(): unsupported file type"); } + arma_debug_warn_level(1, "Cube::load(): unsupported file type"); load_okay = false; } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { + (*this).soft_reset(); + if(err_msg.length() > 0) { - arma_debug_warn("Cube::load(): ", err_msg, "the given stream"); + arma_debug_warn_level(3, "Cube::load(): ", err_msg); } else { - arma_debug_warn("Cube::load(): couldn't load from the given stream"); + arma_debug_warn_level(3, "Cube::load(): stream read failed"); } } - if(load_okay == false) - { - (*this).soft_reset(); - } - return load_okay; } -//! save the cube to a file, without printing any error messages template inline -arma_cold bool Cube::quiet_save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(name, type, false); + return (*this).save(name, type); } template inline -arma_cold bool Cube::quiet_save(const hdf5_name& spec, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(spec, type, false); + return (*this).save(spec, type); } -//! save the cube to a stream, without printing any error messages template inline -arma_cold bool Cube::quiet_save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(os, type, false); + return (*this).save(os, type); } -//! load a cube from a file, without printing any error messages template inline -arma_cold bool Cube::quiet_load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(name, type, false); + return (*this).load(name, type); } template inline -arma_cold bool Cube::quiet_load(const hdf5_name& spec, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(spec, type, false); + return (*this).load(spec, type); } -//! load a cube from a stream, without printing any error messages template inline -arma_cold bool Cube::quiet_load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(is, type, false); + return (*this).load(is, type); } @@ -4867,7 +5023,7 @@ Cube::begin_slice(const uword slice_num) { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "begin_slice(): index out of bounds"); + arma_debug_check_bounds( (slice_num >= n_slices), "begin_slice(): index out of bounds" ); return slice_memptr(slice_num); } @@ -4881,7 +5037,7 @@ Cube::begin_slice(const uword slice_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "begin_slice(): index out of bounds"); + arma_debug_check_bounds( (slice_num >= n_slices), "begin_slice(): index out of bounds" ); return slice_memptr(slice_num); } @@ -4895,7 +5051,7 @@ Cube::end_slice(const uword slice_num) { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "end_slice(): index out of bounds"); + arma_debug_check_bounds( (slice_num >= n_slices), "end_slice(): index out of bounds" ); return slice_memptr(slice_num) + n_elem_slice; } @@ -4909,7 +5065,7 @@ Cube::end_slice(const uword slice_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "end_slice(): index out of bounds"); + arma_debug_check_bounds( (slice_num >= n_slices), "end_slice(): index out of bounds" ); return slice_memptr(slice_num) + n_elem_slice; } @@ -4949,6 +5105,54 @@ Cube::size() const +template +inline +eT& +Cube::front() + { + arma_debug_check( (n_elem == 0), "Cube::front(): cube is empty" ); + + return access::rw(mem[0]); + } + + + +template +inline +const eT& +Cube::front() const + { + arma_debug_check( (n_elem == 0), "Cube::front(): cube is empty" ); + + return mem[0]; + } + + + +template +inline +eT& +Cube::back() + { + arma_debug_check( (n_elem == 0), "Cube::back(): cube is empty" ); + + return access::rw(mem[n_elem-1]); + } + + + +template +inline +const eT& +Cube::back() const + { + arma_debug_check( (n_elem == 0), "Cube::back(): cube is empty" ); + + return mem[n_elem-1]; + } + + + template inline void @@ -5027,10 +5231,24 @@ Cube::steal_mem(Cube& x) { arma_extra_debug_sigprint(); + (*this).steal_mem(x, false); + } + + + +template +inline +void +Cube::steal_mem(Cube& x, const bool is_move) + { + arma_extra_debug_sigprint(); + if(this == &x) { return; } - if( (mem_state <= 1) && ( ((x.mem_state == 0) && (x.n_elem > Cube_prealloc::mem_n_elem)) || (x.mem_state == 1) ) ) + if( (mem_state <= 1) && ( (x.n_alloc > Cube_prealloc::mem_n_elem) || (x.mem_state == 1) || (is_move && (x.mem_state == 2)) ) ) { + arma_extra_debug_print("Cube::steal_mem(): stealing memory"); + reset(); const uword x_n_slices = x.n_slices; @@ -5040,22 +5258,27 @@ Cube::steal_mem(Cube& x) access::rw(n_elem_slice) = x.n_elem_slice; access::rw(n_slices) = x_n_slices; access::rw(n_elem) = x.n_elem; + access::rw(n_alloc) = x.n_alloc; access::rw(mem_state) = x.mem_state; access::rw(mem) = x.mem; if(x_n_slices > Cube_prealloc::mat_ptrs_size) { - access::rw( mat_ptrs) = x.mat_ptrs; - access::rw(x.mat_ptrs) = 0; + arma_extra_debug_print("Cube::steal_mem(): stealing mat_ptrs array"); + + mat_ptrs = x.mat_ptrs; + x.mat_ptrs = nullptr; } else { - access::rw(mat_ptrs) = const_cast< const Mat** >(mat_ptrs_local); + arma_extra_debug_print("Cube::steal_mem(): copying mat_ptrs array"); + + mat_ptrs = mat_ptrs_local; for(uword i=0; i < x_n_slices; ++i) { - mat_ptrs[i] = x.mat_ptrs[i]; - x.mat_ptrs[i] = 0; + mat_ptrs[i] = raw_mat_ptr_type(x.mat_ptrs[i]); // cast required by std::atomic + x.mat_ptrs[i] = nullptr; } } @@ -5064,12 +5287,20 @@ Cube::steal_mem(Cube& x) access::rw(x.n_elem_slice) = 0; access::rw(x.n_slices) = 0; access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; access::rw(x.mem_state) = 0; - access::rw(x.mem) = 0; + access::rw(x.mem) = nullptr; } else { + arma_extra_debug_print("Cube::steal_mem(): copying memory"); + (*this).operator=(x); + + if( (is_move) && (x.mem_state == 0) && (x.n_alloc <= Cube_prealloc::mem_n_elem) ) + { + x.reset(); + } } } @@ -5095,10 +5326,10 @@ Cube::fixed::mem_setup() access::rw(Cube::n_elem_slice) = fixed_n_rows * fixed_n_cols; access::rw(Cube::n_slices) = fixed_n_slices; access::rw(Cube::n_elem) = fixed_n_elem; + access::rw(Cube::n_alloc) = 0; access::rw(Cube::mem_state) = 3; access::rw(Cube::mem) = (fixed_n_elem > Cube_prealloc::mem_n_elem) ? mem_local_extra : mem_local; - access::rw(Cube::mat_ptrs) = const_cast< const Mat** >( \ - (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? mat_ptrs_local_extra : mat_ptrs_local ); + Cube::mat_ptrs = (fixed_n_slices > Cube_prealloc::mat_ptrs_size) ? mat_ptrs_local_extra : mat_ptrs_local; create_mat(); } @@ -5109,9 +5340,10 @@ Cube::fixed::mem_setup() access::rw(Cube::n_elem_slice) = 0; access::rw(Cube::n_slices) = 0; access::rw(Cube::n_elem) = 0; + access::rw(Cube::n_alloc) = 0; access::rw(Cube::mem_state) = 3; - access::rw(Cube::mem) = 0; - access::rw(Cube::mat_ptrs) = 0; + access::rw(Cube::mem) = nullptr; + Cube::mat_ptrs = nullptr; } } @@ -5125,6 +5357,15 @@ Cube::fixed::fixed() arma_extra_debug_sigprint_this(this); mem_setup(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Cube::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::fill_zeros(mem_use, fixed_n_elem); + } } @@ -5146,6 +5387,20 @@ Cube::fixed::fixed(const fixed +template +inline +Cube::fixed::fixed(const fill::scalar_holder f) + { + arma_extra_debug_sigprint_this(this); + + mem_setup(); + + (*this).fill(f.scalar); + } + + + template template template @@ -5156,12 +5411,12 @@ Cube::fixed::fixed(const fill::f mem_setup(); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } - if(is_same_type::yes) { arma_debug_check(true, "Cube::fixed::fixed(): unsupported fill type"); } + arma_static_check( (is_same_type::yes), "Cube::fixed::fixed(): unsupported fill type" ); } @@ -5217,7 +5472,6 @@ Cube::fixed::operator=(const fix template template arma_inline -arma_warn_unused eT& Cube::fixed::operator[] (const uword i) { @@ -5229,7 +5483,6 @@ Cube::fixed::operator[] (const u template template arma_inline -arma_warn_unused const eT& Cube::fixed::operator[] (const uword i) const { @@ -5241,7 +5494,6 @@ Cube::fixed::operator[] (const u template template arma_inline -arma_warn_unused eT& Cube::fixed::at(const uword i) { @@ -5253,7 +5505,6 @@ Cube::fixed::at(const uword i) template template arma_inline -arma_warn_unused const eT& Cube::fixed::at(const uword i) const { @@ -5265,11 +5516,10 @@ Cube::fixed::at(const uword i) c template template arma_inline -arma_warn_unused eT& Cube::fixed::operator() (const uword i) { - arma_debug_check( (i >= fixed_n_elem), "Cube::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= fixed_n_elem), "Cube::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[i] : mem_local[i]; } @@ -5279,21 +5529,49 @@ Cube::fixed::operator() (const u template template arma_inline -arma_warn_unused const eT& Cube::fixed::operator() (const uword i) const { - arma_debug_check( (i >= fixed_n_elem), "Cube::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= fixed_n_elem), "Cube::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[i] : mem_local[i]; } +#if defined(__cpp_multidimensional_subscript) + + template + template + arma_inline + eT& + Cube::fixed::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + + + + template + template + arma_inline + const eT& + Cube::fixed::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + const uword i = in_slice*fixed_n_elem_slice + in_col*fixed_n_rows + in_row; + + return (use_extra) ? mem_local_extra[i] : mem_local[i]; + } + +#endif + + + template template arma_inline -arma_warn_unused eT& Cube::fixed::at(const uword in_row, const uword in_col, const uword in_slice) { @@ -5307,7 +5585,6 @@ Cube::fixed::at(const uword in_r template template arma_inline -arma_warn_unused const eT& Cube::fixed::at(const uword in_row, const uword in_col, const uword in_slice) const { @@ -5321,11 +5598,10 @@ Cube::fixed::at(const uword in_r template template arma_inline -arma_warn_unused eT& Cube::fixed::operator() (const uword in_row, const uword in_col, const uword in_slice) { - arma_debug_check + arma_debug_check_bounds ( (in_row >= fixed_n_rows ) || (in_col >= fixed_n_cols ) || @@ -5344,11 +5620,10 @@ Cube::fixed::operator() (const u template template arma_inline -arma_warn_unused const eT& Cube::fixed::operator() (const uword in_row, const uword in_col, const uword in_slice) const { - arma_debug_check + arma_debug_check_bounds ( (in_row >= fixed_n_rows ) || (in_col >= fixed_n_cols ) || @@ -5572,11 +5847,7 @@ Cube_aux::set_real(Cube< std::complex >& out, const BaseCube& X) const uword N = out.n_elem; - for(uword i=0; i( A[i], out_mem[i].imag() ); - } + for(uword i=0; i >& out, const BaseCube& X) for(uword col = 0; col < local_n_cols; ++col ) for(uword row = 0; row < local_n_rows; ++row ) { - (*out_mem) = std::complex( P.at(row,col,slice), (*out_mem).imag() ); + (*out_mem).real(P.at(row,col,slice)); out_mem++; } } @@ -5624,11 +5895,7 @@ Cube_aux::set_imag(Cube< std::complex >& out, const BaseCube& X) const uword N = out.n_elem; - for(uword i=0; i( out_mem[i].real(), A[i] ); - } + for(uword i=0; i >& out, const BaseCube& X) for(uword col = 0; col < local_n_cols; ++col ) for(uword row = 0; row < local_n_rows; ++row ) { - (*out_mem) = std::complex( (*out_mem).real(), P.at(row,col,slice) ); + (*out_mem).imag(P.at(row,col,slice)); out_mem++; } } @@ -5644,7 +5911,7 @@ Cube_aux::set_imag(Cube< std::complex >& out, const BaseCube& X) -#ifdef ARMA_EXTRA_CUBE_MEAT +#if defined(ARMA_EXTRA_CUBE_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_CUBE_MEAT) #endif diff --git a/src/armadillo_bits/GenCube_bones.hpp b/src/armadillo_bits/GenCube_bones.hpp index 62a4982f..3c6099f3 100644 --- a/src/armadillo_bits/GenCube_bones.hpp +++ b/src/armadillo_bits/GenCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,19 +20,18 @@ //! @{ -//! support class for generator functions (eg. zeros, randu, randn, ...) +//! support class for generator functions (zeros, ones) template class GenCube - : public BaseCube > - , public GenSpecialiser::yes, is_same_type::yes, is_same_type::yes, is_same_type::yes> + : public BaseCube< eT, GenCube > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool use_at = false; - static const bool is_simple = (is_same_type::value) || (is_same_type::value); + static constexpr bool use_at = false; + static constexpr bool is_simple = (is_same_type::value) || (is_same_type::value); arma_aligned const uword n_rows; arma_aligned const uword n_cols; @@ -39,9 +40,9 @@ class GenCube arma_inline GenCube(const uword in_n_rows, const uword in_n_cols, const uword in_n_slices); arma_inline ~GenCube(); - arma_inline eT operator[] (const uword i) const; - arma_inline eT at (const uword row, const uword col, const uword slice) const; - arma_inline eT at_alt (const uword i) const; + arma_inline eT operator[] (const uword i) const; + arma_inline eT at (const uword r, const uword c, const uword s) const; + arma_inline eT at_alt (const uword i) const; inline void apply (Cube& out) const; inline void apply_inplace_plus (Cube& out) const; diff --git a/src/armadillo_bits/GenCube_meat.hpp b/src/armadillo_bits/GenCube_meat.hpp index 182d80ab..61735f3b 100644 --- a/src/armadillo_bits/GenCube_meat.hpp +++ b/src/armadillo_bits/GenCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -45,7 +47,10 @@ arma_inline eT GenCube::operator[](const uword) const { - return (*this).generate(); + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings } @@ -55,7 +60,10 @@ arma_inline eT GenCube::at(const uword, const uword, const uword) const { - return (*this).generate(); + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings } @@ -65,7 +73,10 @@ arma_inline eT GenCube::at_alt(const uword) const { - return (*this).generate(); + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + + return eT(0); // prevent pedantic compiler warnings } @@ -80,10 +91,8 @@ GenCube::apply(Cube& out) const // NOTE: we're assuming that the cube has already been set to the correct size; // this is done by either the Cube contructor or operator=() - if(is_same_type::yes) { out.ones(); } - else if(is_same_type::yes) { out.zeros(); } - else if(is_same_type::yes) { out.randu(); } - else if(is_same_type::yes) { out.randn(); } + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } } @@ -97,24 +106,9 @@ GenCube::apply_inplace_plus(Cube& out) const arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "addition"); - - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword i,j; - - for(i=0, j=1; j::yes) { - const eT tmp_i = (*this).generate(); - const eT tmp_j = (*this).generate(); - - out_mem[i] += tmp_i; - out_mem[j] += tmp_j; - } - - if(i < n_elem) - { - out_mem[i] += (*this).generate(); + arrayops::inplace_plus(out.memptr(), eT(1), out.n_elem); } } @@ -130,24 +124,9 @@ GenCube::apply_inplace_minus(Cube& out) const arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "subtraction"); - - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword i,j; - - for(i=0, j=1; j::yes) { - out_mem[i] -= (*this).generate(); + arrayops::inplace_minus(out.memptr(), eT(1), out.n_elem); } } @@ -163,24 +142,10 @@ GenCube::apply_inplace_schur(Cube& out) const arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise multiplication"); - - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword i,j; - - for(i=0, j=1; j::yes) { - out_mem[i] *= (*this).generate(); + arrayops::inplace_mul(out.memptr(), eT(0), out.n_elem); + // NOTE: not using arrayops::fill_zeros(), as 'out' may have NaN elements } } @@ -196,24 +161,9 @@ GenCube::apply_inplace_div(Cube& out) const arma_debug_assert_same_size(out.n_rows, out.n_cols, out.n_slices, n_rows, n_cols, n_slices, "element-wise division"); - - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword i,j; - - for(i=0, j=1; j::yes) { - out_mem[i] /= (*this).generate(); + arrayops::inplace_div(out.memptr(), eT(0), out.n_elem); } } @@ -229,10 +179,8 @@ GenCube::apply(subview_cube& out) const // NOTE: we're assuming that the subcube has the same dimensions as the GenCube object // this is checked by subview_cube::operator=() - if(is_same_type::yes) { out.ones(); } - else if(is_same_type::yes) { out.zeros(); } - else if(is_same_type::yes) { out.randu(); } - else if(is_same_type::yes) { out.randn(); } + if(is_same_type::yes) { out.zeros(); } + else if(is_same_type::yes) { out.ones(); } } diff --git a/src/armadillo_bits/GenSpecialiser.hpp b/src/armadillo_bits/GenSpecialiser.hpp deleted file mode 100644 index b67608bd..00000000 --- a/src/armadillo_bits/GenSpecialiser.hpp +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -//! \addtogroup GenSpecialiser -//! @{ - - -template -struct GenSpecialiser - { - arma_inline elem_type generate() const { return elem_type(); } - }; - - -template -struct GenSpecialiser - { - arma_inline elem_type generate() const { return elem_type(0); } - }; - - -template -struct GenSpecialiser - { - arma_inline elem_type generate() const { return elem_type(1); } - }; - - -template -struct GenSpecialiser - { - arma_inline elem_type generate() const { return elem_type(arma_rng::randu()); } - }; - - -template -struct GenSpecialiser - { - arma_inline elem_type generate() const { return elem_type(arma_rng::randn()); } - }; - - -//! @} diff --git a/src/armadillo_bits/Gen_bones.hpp b/src/armadillo_bits/Gen_bones.hpp index be481048..172e5b9c 100644 --- a/src/armadillo_bits/Gen_bones.hpp +++ b/src/armadillo_bits/Gen_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,23 +20,22 @@ //! @{ -//! support class for generator functions (eg. zeros, randu, randn, ...) +//! support class for generator functions (zeros, ones, eye) template class Gen - : public Base > - , public GenSpecialiser::yes, is_same_type::yes, is_same_type::yes, is_same_type::yes> + : public Base< typename T1::elem_type, Gen > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - static const bool use_at = (is_same_type::value); - static const bool is_simple = (is_same_type::value) || (is_same_type::value); + static constexpr bool use_at = (is_same_type::value); + static constexpr bool is_simple = (is_same_type::value) || (is_same_type::value); - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; arma_aligned const uword n_rows; arma_aligned const uword n_cols; @@ -42,9 +43,9 @@ class Gen arma_inline Gen(const uword in_n_rows, const uword in_n_cols); arma_inline ~Gen(); - arma_inline elem_type operator[] (const uword ii) const; - arma_inline elem_type at (const uword row, const uword col) const; - arma_inline elem_type at_alt (const uword ii) const; + arma_inline elem_type operator[] (const uword ii) const; + arma_inline elem_type at (const uword r, const uword c) const; + arma_inline elem_type at_alt (const uword ii) const; inline void apply (Mat& out) const; inline void apply_inplace_plus (Mat& out) const; diff --git a/src/armadillo_bits/Gen_meat.hpp b/src/armadillo_bits/Gen_meat.hpp index 6c3da397..96b8940c 100644 --- a/src/armadillo_bits/Gen_meat.hpp +++ b/src/armadillo_bits/Gen_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -46,14 +48,11 @@ Gen::operator[](const uword ii) const { typedef typename T1::elem_type eT; - if(is_same_type::yes) - { - return ((ii % n_rows) == (ii / n_rows)) ? eT(1) : eT(0); - } - else - { - return (*this).generate(); - } + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + else if(is_same_type::yes) { return ((ii % n_rows) == (ii / n_rows)) ? eT(1) : eT(0); } + + return eT(0); // prevent pedantic compiler warnings } @@ -61,18 +60,15 @@ Gen::operator[](const uword ii) const template arma_inline typename T1::elem_type -Gen::at(const uword row, const uword col) const +Gen::at(const uword r, const uword c) const { typedef typename T1::elem_type eT; - if(is_same_type::yes) - { - return (row == col) ? eT(1) : eT(0); - } - else - { - return (*this).generate(); - } + if(is_same_type::yes) { return eT(0); } + else if(is_same_type::yes) { return eT(1); } + else if(is_same_type::yes) { return (r == c) ? eT(1) : eT(0); } + + return eT(0); // prevent pedantic compiler warnings } @@ -97,11 +93,9 @@ Gen::apply(Mat& out) const // NOTE: we're assuming that the matrix has already been set to the correct size; // this is done by either the Mat contructor or operator=() - if(is_same_type::yes) { out.eye(); } + if(is_same_type::yes) { out.zeros(); } else if(is_same_type::yes) { out.ones(); } - else if(is_same_type::yes) { out.zeros(); } - else if(is_same_type::yes) { out.randu(); } - else if(is_same_type::yes) { out.randn(); } + else if(is_same_type::yes) { out.eye(); } } @@ -117,35 +111,16 @@ Gen::apply_inplace_plus(Mat& out) const typedef typename T1::elem_type eT; - - if(is_same_type::yes) + if(is_same_type::yes) { - const uword N = (std::min)(n_rows, n_cols); - - for(uword iq=0; iq < N; ++iq) - { - out.at(iq,iq) += eT(1); - } + arrayops::inplace_plus(out.memptr(), eT(1), out.n_elem); } else + if(is_same_type::yes) { - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword iq,jq; - for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2) - { - const eT tmp_i = (*this).generate(); - const eT tmp_j = (*this).generate(); - - out_mem[iq] += tmp_i; - out_mem[jq] += tmp_j; - } + const uword N = (std::min)(n_rows, n_cols); - if(iq < n_elem) - { - out_mem[iq] += (*this).generate(); - } + for(uword ii=0; ii < N; ++ii) { out.at(ii,ii) += eT(1); } } } @@ -163,35 +138,16 @@ Gen::apply_inplace_minus(Mat& out) const typedef typename T1::elem_type eT; - - if(is_same_type::yes) + if(is_same_type::yes) { - const uword N = (std::min)(n_rows, n_cols); - - for(uword iq=0; iq < N; ++iq) - { - out.at(iq,iq) -= eT(1); - } + arrayops::inplace_minus(out.memptr(), eT(1), out.n_elem); } else + if(is_same_type::yes) { - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword iq,jq; - for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2) - { - const eT tmp_i = (*this).generate(); - const eT tmp_j = (*this).generate(); - - out_mem[iq] -= tmp_i; - out_mem[jq] -= tmp_j; - } + const uword N = (std::min)(n_rows, n_cols); - if(iq < n_elem) - { - out_mem[iq] -= (*this).generate(); - } + for(uword ii=0; ii < N; ++ii) { out.at(ii,ii) -= eT(1); } } } @@ -209,35 +165,18 @@ Gen::apply_inplace_schur(Mat& out) const typedef typename T1::elem_type eT; - - if(is_same_type::yes) + if(is_same_type::yes) { - const uword N = (std::min)(n_rows, n_cols); - - for(uword iq=0; iq < N; ++iq) - { - for(uword row=0; row < iq; ++row) { out.at(row,iq) = eT(0); } - for(uword row=iq+1; row < n_rows; ++row) { out.at(row,iq) = eT(0); } - } + arrayops::inplace_mul(out.memptr(), eT(0), out.n_elem); + // NOTE: not using arrayops::fill_zeros(), as 'out' may have NaN elements } else + if(is_same_type::yes) { - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword iq,jq; - for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2) - { - const eT tmp_i = (*this).generate(); - const eT tmp_j = (*this).generate(); - - out_mem[iq] *= tmp_i; - out_mem[jq] *= tmp_j; - } - - if(iq < n_elem) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) { - out_mem[iq] *= (*this).generate(); + if(r != c) { out.at(r,c) *= eT(0); } } } } @@ -256,37 +195,17 @@ Gen::apply_inplace_div(Mat& out) const typedef typename T1::elem_type eT; - - if(is_same_type::yes) + if(is_same_type::yes) { - const uword N = (std::min)(n_rows, n_cols); - - for(uword iq=0; iq < N; ++iq) - { - const eT zero = eT(0); - - for(uword row=0; row < iq; ++row) { out.at(row,iq) /= zero; } - for(uword row=iq+1; row < n_rows; ++row) { out.at(row,iq) /= zero; } - } + arrayops::inplace_div(out.memptr(), eT(0), out.n_elem); } else + if(is_same_type::yes) { - eT* out_mem = out.memptr(); - const uword n_elem = out.n_elem; - - uword iq,jq; - for(iq=0, jq=1; jq < n_elem; iq+=2, jq+=2) - { - const eT tmp_i = (*this).generate(); - const eT tmp_j = (*this).generate(); - - out_mem[iq] /= tmp_i; - out_mem[jq] /= tmp_j; - } - - if(iq < n_elem) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) { - out_mem[iq] /= (*this).generate(); + if(r != c) { out.at(r,c) /= eT(0); } } } } @@ -303,11 +222,9 @@ Gen::apply(subview& out) const // NOTE: we're assuming that the submatrix has the same dimensions as the Gen object // this is checked by subview::operator=() - if(is_same_type::yes) { out.eye(); } + if(is_same_type::yes) { out.zeros(); } else if(is_same_type::yes) { out.ones(); } - else if(is_same_type::yes) { out.zeros(); } - else if(is_same_type::yes) { out.randu(); } - else if(is_same_type::yes) { out.randn(); } + else if(is_same_type::yes) { out.eye(); } } diff --git a/src/armadillo_bits/GlueCube_bones.hpp b/src/armadillo_bits/GlueCube_bones.hpp index c7a13507..75173ef9 100644 --- a/src/armadillo_bits/GlueCube_bones.hpp +++ b/src/armadillo_bits/GlueCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,19 +23,18 @@ //! analog of the Glue class, intended for Cube objects template -class GlueCube : public BaseCube > +class GlueCube : public BaseCube< typename T1::elem_type, GlueCube > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - - arma_inline GlueCube(const BaseCube& in_A, const BaseCube& in_B); - arma_inline ~GlueCube(); + + inline GlueCube(const BaseCube& in_A, const BaseCube& in_B); + inline ~GlueCube(); const T1& A; //!< first operand; must be derived from BaseCube const T2& B; //!< second operand; must be derived from BaseCube - }; diff --git a/src/armadillo_bits/GlueCube_meat.hpp b/src/armadillo_bits/GlueCube_meat.hpp index 5f42dca4..41956507 100644 --- a/src/armadillo_bits/GlueCube_meat.hpp +++ b/src/armadillo_bits/GlueCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/Glue_bones.hpp b/src/armadillo_bits/Glue_bones.hpp index c0c42931..197ae746 100644 --- a/src/armadillo_bits/Glue_bones.hpp +++ b/src/armadillo_bits/Glue_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,33 +28,33 @@ struct Glue_traits {}; template struct Glue_traits { - static const bool is_row = glue_type::template traits::is_row; - static const bool is_col = glue_type::template traits::is_col; - static const bool is_xvec = glue_type::template traits::is_xvec; + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; }; template struct Glue_traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; template class Glue - : public Base > - , public Glue_traits::value > + : public Base< typename T1::elem_type, Glue > + , public Glue_traits::value> { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - arma_inline Glue(const T1& in_A, const T2& in_B); - arma_inline Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword); - arma_inline ~Glue(); + inline Glue(const T1& in_A, const T2& in_B); + inline Glue(const T1& in_A, const T2& in_B, const uword in_aux_uword); + inline ~Glue(); const T1& A; //!< first operand; must be derived from Base const T2& B; //!< second operand; must be derived from Base diff --git a/src/armadillo_bits/Glue_meat.hpp b/src/armadillo_bits/Glue_meat.hpp index fa0e0ae2..713fb16a 100644 --- a/src/armadillo_bits/Glue_meat.hpp +++ b/src/armadillo_bits/Glue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/MapMat_bones.hpp b/src/armadillo_bits/MapMat_bones.hpp index 71a57ed7..7ab46b06 100644 --- a/src/armadillo_bits/MapMat_bones.hpp +++ b/src/armadillo_bits/MapMat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,9 +30,9 @@ class MapMat typedef eT elem_type; //!< the type of elements stored in the matrix typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const uword n_rows; //!< number of rows (read-only) const uword n_cols; //!< number of columns (read-only) @@ -58,10 +60,8 @@ class MapMat inline explicit MapMat(const SpMat& x); inline void operator=(const SpMat& x); - #if defined(ARMA_USE_CXX11) inline MapMat(MapMat&& x); inline void operator=(MapMat&& x); - #endif inline void reset(); inline void set_size(const uword in_n_rows); @@ -81,23 +81,23 @@ class MapMat inline void speye(const uword in_n_rows, const uword in_n_cols); inline void speye(const SizeMat& s); - arma_inline arma_warn_unused MapMat_val operator[](const uword index); - inline arma_warn_unused eT operator[](const uword index) const; + arma_warn_unused arma_inline MapMat_val operator[](const uword index); + arma_warn_unused inline eT operator[](const uword index) const; - arma_inline arma_warn_unused MapMat_val operator()(const uword index); - inline arma_warn_unused eT operator()(const uword index) const; + arma_warn_unused arma_inline MapMat_val operator()(const uword index); + arma_warn_unused inline eT operator()(const uword index) const; - arma_inline arma_warn_unused MapMat_val at(const uword in_row, const uword in_col); - inline arma_warn_unused eT at(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline MapMat_val at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused MapMat_val operator()(const uword in_row, const uword in_col); - inline arma_warn_unused eT operator()(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline MapMat_val operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; - inline arma_warn_unused bool is_empty() const; - inline arma_warn_unused bool is_vec() const; - inline arma_warn_unused bool is_rowvec() const; - inline arma_warn_unused bool is_colvec() const; - inline arma_warn_unused bool is_square() const; + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_square() const; inline void sprandu(const uword in_n_rows, const uword in_n_cols, const double density); @@ -195,11 +195,11 @@ class SpMat_MapMat_val inline SpMat_MapMat_val& operator*=(const eT in_val); inline SpMat_MapMat_val& operator/=(const eT in_val); - inline SpMat_MapMat_val& operator++(); - inline arma_warn_unused eT operator++(int); + inline SpMat_MapMat_val& operator++(); + arma_warn_unused inline eT operator++(int); - inline SpMat_MapMat_val& operator--(); - inline arma_warn_unused eT operator--(int); + inline SpMat_MapMat_val& operator--(); + arma_warn_unused inline eT operator--(int); inline void set(const eT in_val); inline void add(const eT in_val); @@ -235,11 +235,11 @@ class SpSubview_MapMat_val : public SpMat_MapMat_val inline SpSubview_MapMat_val& operator*=(const eT in_val); inline SpSubview_MapMat_val& operator/=(const eT in_val); - inline SpSubview_MapMat_val& operator++(); - inline arma_warn_unused eT operator++(int); + inline SpSubview_MapMat_val& operator++(); + arma_warn_unused inline eT operator++(int); - inline SpSubview_MapMat_val& operator--(); - inline arma_warn_unused eT operator--(int); + inline SpSubview_MapMat_val& operator--(); + arma_warn_unused inline eT operator--(int); }; diff --git a/src/armadillo_bits/MapMat_meat.hpp b/src/armadillo_bits/MapMat_meat.hpp index b7d047d8..e67a3072 100644 --- a/src/armadillo_bits/MapMat_meat.hpp +++ b/src/armadillo_bits/MapMat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,7 +30,7 @@ MapMat::~MapMat() if(map_ptr) { (*map_ptr).clear(); delete map_ptr; } // try to expose buggy user code that accesses deleted objects - if(arma_config::debug) { map_ptr = NULL; } + if(arma_config::debug) { map_ptr = nullptr; } arma_type_check(( is_supported_elem_type::value == false )); } @@ -41,7 +43,7 @@ MapMat::MapMat() : n_rows (0) , n_cols (0) , n_elem (0) - , map_ptr(NULL) + , map_ptr(nullptr) { arma_extra_debug_sigprint_this(this); @@ -56,7 +58,7 @@ MapMat::MapMat(const uword in_n_rows, const uword in_n_cols) : n_rows (in_n_rows) , n_cols (in_n_cols) , n_elem (in_n_rows * in_n_cols) - , map_ptr(NULL) + , map_ptr(nullptr) { arma_extra_debug_sigprint_this(this); @@ -71,7 +73,7 @@ MapMat::MapMat(const SizeMat& s) : n_rows (s.n_rows) , n_cols (s.n_cols) , n_elem (s.n_rows * s.n_cols) - , map_ptr(NULL) + , map_ptr(nullptr) { arma_extra_debug_sigprint_this(this); @@ -86,7 +88,7 @@ MapMat::MapMat(const MapMat& x) : n_rows (0) , n_cols (0) , n_elem (0) - , map_ptr(NULL) + , map_ptr(nullptr) { arma_extra_debug_sigprint_this(this); @@ -121,7 +123,7 @@ MapMat::MapMat(const SpMat& x) : n_rows (0) , n_cols (0) , n_elem (0) - , map_ptr(NULL) + , map_ptr(nullptr) { arma_extra_debug_sigprint_this(this); @@ -164,60 +166,54 @@ MapMat::operator=(const SpMat& x) const uword index = (x_n_rows * col) + row; - #if defined(ARMA_USE_CXX11) - map_ref.emplace_hint(map_ref.cend(), index, val); - #else - map_ref.operator[](index) = val; - #endif + map_ref.emplace_hint(map_ref.cend(), index, val); } } } -#if defined(ARMA_USE_CXX11) +template +inline +MapMat::MapMat(MapMat&& x) + : n_rows (x.n_rows ) + , n_cols (x.n_cols ) + , n_elem (x.n_elem ) + , map_ptr(x.map_ptr) + { + arma_extra_debug_sigprint_this(this); + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.map_ptr) = nullptr; + } - template - inline - MapMat::MapMat(MapMat&& x) - : n_rows (x.n_rows ) - , n_cols (x.n_cols ) - , n_elem (x.n_elem ) - , map_ptr(x.map_ptr) - { - arma_extra_debug_sigprint_this(this); - - access::rw(x.n_rows) = 0; - access::rw(x.n_cols) = 0; - access::rw(x.n_elem) = 0; - access::rw(x.map_ptr) = NULL; - } + + +template +inline +void +MapMat::operator=(MapMat&& x) + { + arma_extra_debug_sigprint(); + if(this == &x) { return; } + reset(); - template - inline - void - MapMat::operator=(MapMat&& x) - { - arma_extra_debug_sigprint(); - - reset(); - - if(map_ptr) { delete map_ptr; } - - access::rw(n_rows) = x.n_rows; - access::rw(n_cols) = x.n_cols; - access::rw(n_elem) = x.n_elem; - access::rw(map_ptr) = x.map_ptr; - - access::rw(x.n_rows) = 0; - access::rw(x.n_cols) = 0; - access::rw(x.n_elem) = 0; - access::rw(x.map_ptr) = NULL; - } - -#endif + if(map_ptr) { delete map_ptr; } + + access::rw(n_rows) = x.n_rows; + access::rw(n_cols) = x.n_cols; + access::rw(n_elem) = x.n_elem; + access::rw(map_ptr) = x.map_ptr; + + access::rw(x.n_rows) = 0; + access::rw(x.n_cols) = 0; + access::rw(x.n_elem) = 0; + access::rw(x.map_ptr) = nullptr; + } @@ -356,11 +352,7 @@ MapMat::eye(const uword in_n_rows, const uword in_n_cols) { const uword index = (in_n_rows * i) + i; - #if defined(ARMA_USE_CXX11) - map_ref.emplace_hint(map_ref.cend(), index, eT(1)); - #else - map_ref.operator[](index) = eT(1); - #endif + map_ref.emplace_hint(map_ref.cend(), index, eT(1)); } } @@ -416,7 +408,6 @@ MapMat::speye(const SizeMat& s) template arma_inline -arma_warn_unused MapMat_val MapMat::operator[](const uword index) { @@ -427,7 +418,6 @@ MapMat::operator[](const uword index) template inline -arma_warn_unused eT MapMat::operator[](const uword index) const { @@ -443,11 +433,10 @@ MapMat::operator[](const uword index) const template arma_inline -arma_warn_unused MapMat_val MapMat::operator()(const uword index) { - arma_debug_check( (index >= n_elem), "MapMat::operator(): index out of bounds" ); + arma_debug_check_bounds( (index >= n_elem), "MapMat::operator(): index out of bounds" ); return MapMat_val(*this, index); } @@ -456,11 +445,10 @@ MapMat::operator()(const uword index) template inline -arma_warn_unused eT MapMat::operator()(const uword index) const { - arma_debug_check( (index >= n_elem), "MapMat::operator(): index out of bounds" ); + arma_debug_check_bounds( (index >= n_elem), "MapMat::operator(): index out of bounds" ); map_type& map_ref = (*map_ptr); @@ -474,7 +462,6 @@ MapMat::operator()(const uword index) const template arma_inline -arma_warn_unused MapMat_val MapMat::at(const uword in_row, const uword in_col) { @@ -487,7 +474,6 @@ MapMat::at(const uword in_row, const uword in_col) template inline -arma_warn_unused eT MapMat::at(const uword in_row, const uword in_col) const { @@ -505,11 +491,10 @@ MapMat::at(const uword in_row, const uword in_col) const template arma_inline -arma_warn_unused MapMat_val MapMat::operator()(const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); const uword index = (n_rows * in_col) + in_row; @@ -520,11 +505,10 @@ MapMat::operator()(const uword in_row, const uword in_col) template inline -arma_warn_unused eT MapMat::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "MapMat::operator(): index out of bounds" ); const uword index = (n_rows * in_col) + in_row; @@ -540,7 +524,6 @@ MapMat::operator()(const uword in_row, const uword in_col) const template inline -arma_warn_unused bool MapMat::is_empty() const { @@ -551,7 +534,6 @@ MapMat::is_empty() const template inline -arma_warn_unused bool MapMat::is_vec() const { @@ -562,7 +544,6 @@ MapMat::is_vec() const template inline -arma_warn_unused bool MapMat::is_rowvec() const { @@ -574,7 +555,6 @@ MapMat::is_rowvec() const //! returns true if the object can be interpreted as a column vector template inline -arma_warn_unused bool MapMat::is_colvec() const { @@ -585,7 +565,6 @@ MapMat::is_colvec() const template inline -arma_warn_unused bool MapMat::is_square() const { @@ -619,11 +598,7 @@ MapMat::sprandu(const uword in_n_rows, const uword in_n_cols, const double d const uword index = indx_mem[i]; const eT val = vals_mem[i]; - #if defined(ARMA_USE_CXX11) - map_ref.emplace_hint(map_ref.cend(), index, val); - #else - map_ref.operator[](index) = val; - #endif + map_ref.emplace_hint(map_ref.cend(), index, val); } } @@ -744,10 +719,10 @@ MapMat::init_cold() // ensure that n_elem can hold the result of (n_rows * n_cols) - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "MapMat(): requested size is too large"; #else - const char* error_message = "MapMat(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "MapMat(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_check @@ -762,7 +737,7 @@ MapMat::init_cold() map_ptr = new (std::nothrow) map_type; - arma_check_bad_alloc( (map_ptr == NULL), "MapMat(): out of memory" ); + arma_check_bad_alloc( (map_ptr == nullptr), "MapMat(): out of memory" ); } @@ -778,10 +753,10 @@ MapMat::init_warm(const uword in_n_rows, const uword in_n_cols) // ensure that n_elem can hold the result of (n_rows * n_cols) - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "MapMat(): requested size is too large"; #else - const char* error_message = "MapMat(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "MapMat(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_check @@ -814,24 +789,16 @@ MapMat::set_val(const uword index, const eT& in_val) if(in_val != eT(0)) { - #if defined(ARMA_USE_CXX11) + map_type& map_ref = (*map_ptr); + + if( (map_ref.empty() == false) && (index > uword(map_ref.crbegin()->first)) ) { - map_type& map_ref = (*map_ptr); - - if( (map_ref.empty() == false) && (index > uword(map_ref.crbegin()->first)) ) - { - map_ref.emplace_hint(map_ref.cend(), index, in_val); - } - else - { - map_ref.operator[](index) = in_val; - } + map_ref.emplace_hint(map_ref.cend(), index, in_val); } - #else + else { - (*map_ptr).operator[](index) = in_val; + map_ref.operator[](index) = in_val; } - #endif } else { @@ -1204,13 +1171,11 @@ SpMat_MapMat_val::operator=(const eT in_val) (*this).set(in_val); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { - s_parent.cache_mutex.lock(); + const std::lock_guard lock(s_parent.cache_mutex); (*this).set(in_val); - - s_parent.cache_mutex.unlock(); } #else { @@ -1239,13 +1204,11 @@ SpMat_MapMat_val::operator+=(const eT in_val) (*this).add(in_val); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { - s_parent.cache_mutex.lock(); + const std::lock_guard lock(s_parent.cache_mutex); (*this).add(in_val); - - s_parent.cache_mutex.unlock(); } #else { @@ -1274,13 +1237,11 @@ SpMat_MapMat_val::operator-=(const eT in_val) (*this).sub(in_val); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { - s_parent.cache_mutex.lock(); + const std::lock_guard lock(s_parent.cache_mutex); (*this).sub(in_val); - - s_parent.cache_mutex.unlock(); } #else { @@ -1307,13 +1268,11 @@ SpMat_MapMat_val::operator*=(const eT in_val) (*this).mul(in_val); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { - s_parent.cache_mutex.lock(); + const std::lock_guard lock(s_parent.cache_mutex); (*this).mul(in_val); - - s_parent.cache_mutex.unlock(); } #else { @@ -1340,13 +1299,11 @@ SpMat_MapMat_val::operator/=(const eT in_val) (*this).div(in_val); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { - s_parent.cache_mutex.lock(); + const std::lock_guard lock(s_parent.cache_mutex); (*this).div(in_val); - - s_parent.cache_mutex.unlock(); } #else { @@ -1373,7 +1330,6 @@ SpMat_MapMat_val::operator++() template inline -arma_warn_unused eT SpMat_MapMat_val::operator++(int) { @@ -1402,7 +1358,6 @@ SpMat_MapMat_val::operator--() template inline -arma_warn_unused eT SpMat_MapMat_val::operator--(int) { @@ -1765,7 +1720,6 @@ SpSubview_MapMat_val::operator++() template inline -arma_warn_unused eT SpSubview_MapMat_val::operator++(int) { @@ -1804,7 +1758,6 @@ SpSubview_MapMat_val::operator--() template inline -arma_warn_unused eT SpSubview_MapMat_val::operator--(int) { diff --git a/src/armadillo_bits/Mat_bones.hpp b/src/armadillo_bits/Mat_bones.hpp index c80999bb..baa41daf 100644 --- a/src/armadillo_bits/Mat_bones.hpp +++ b/src/armadillo_bits/Mat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,9 +31,10 @@ class Mat : public Base< eT, Mat > typedef eT elem_type; //!< the type of elements stored in the matrix typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT - const uword n_rows; //!< number of rows (read-only) - const uword n_cols; //!< number of columns (read-only) - const uword n_elem; //!< number of elements (read-only) + const uword n_rows; //!< number of rows (read-only) + const uword n_cols; //!< number of columns (read-only) + const uword n_elem; //!< number of elements (read-only) + const uword n_alloc; //!< number of allocated elements (read-only); NOTE: n_alloc can be 0, even if n_elem > 0 const uhword vec_state; //!< 0: matrix layout; 1: column vector layout; 2: row vector layout const uhword mem_state; @@ -50,29 +53,34 @@ class Mat : public Base< eT, Mat > public: - static const bool is_col = false; - static const bool is_row = false; - static const bool is_xvec = false; + static constexpr bool is_col = false; + static constexpr bool is_row = false; + static constexpr bool is_xvec = false; inline ~Mat(); inline Mat(); - inline explicit Mat(const uword in_rows, const uword in_cols); + inline explicit Mat(const uword in_n_rows, const uword in_n_cols); inline explicit Mat(const SizeMat& s); - template inline Mat(const uword in_rows, const uword in_cols, const fill::fill_class& f); - template inline Mat(const SizeMat& s, const fill::fill_class& f); + template inline explicit Mat(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&); + template inline explicit Mat(const SizeMat& s, const arma_initmode_indicator&); + + template inline Mat(const uword in_n_rows, const uword in_n_cols, const fill::fill_class& f); + template inline Mat(const SizeMat& s, const fill::fill_class& f); + + inline Mat(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f); + inline Mat(const SizeMat& s, const fill::scalar_holder f); - inline arma_cold Mat(const char* text); - inline arma_cold Mat& operator=(const char* text); + arma_cold inline Mat(const char* text); + arma_cold inline Mat& operator=(const char* text); - inline arma_cold Mat(const std::string& text); - inline arma_cold Mat& operator=(const std::string& text); + arma_cold inline Mat(const std::string& text); + arma_cold inline Mat& operator=(const std::string& text); inline Mat(const std::vector& x); inline Mat& operator=(const std::vector& x); - #if defined(ARMA_USE_CXX11) inline Mat(const std::initializer_list& list); inline Mat& operator=(const std::initializer_list& list); @@ -81,19 +89,18 @@ class Mat : public Base< eT, Mat > inline Mat(Mat&& m); inline Mat& operator=(Mat&& m); - #endif inline Mat( eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const bool copy_aux_mem = true, const bool strict = false); inline Mat(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols); - inline Mat& operator=(const eT val); + inline Mat& operator= (const eT val); inline Mat& operator+=(const eT val); inline Mat& operator-=(const eT val); inline Mat& operator*=(const eT val); inline Mat& operator/=(const eT val); inline Mat(const Mat& m); - inline Mat& operator=(const Mat& m); + inline Mat& operator= (const Mat& m); inline Mat& operator+=(const Mat& m); inline Mat& operator-=(const Mat& m); inline Mat& operator*=(const Mat& m); @@ -101,7 +108,7 @@ class Mat : public Base< eT, Mat > inline Mat& operator/=(const Mat& m); template inline Mat(const BaseCube& X); - template inline Mat& operator=(const BaseCube& X); + template inline Mat& operator= (const BaseCube& X); template inline Mat& operator+=(const BaseCube& X); template inline Mat& operator-=(const BaseCube& X); template inline Mat& operator*=(const BaseCube& X); @@ -111,10 +118,10 @@ class Mat : public Base< eT, Mat > template inline explicit Mat(const Base& A, const Base& B); - inline explicit Mat(const subview& X, const bool use_colmem); // only to be used by the quasi_unwrap class + inline explicit Mat(const subview& X, const bool use_colmem); // only to be used by the quasi_unwrap class inline Mat(const subview& X); - inline Mat& operator=(const subview& X); + inline Mat& operator= (const subview& X); inline Mat& operator+=(const subview& X); inline Mat& operator-=(const subview& X); inline Mat& operator*=(const subview& X); @@ -129,7 +136,7 @@ class Mat : public Base< eT, Mat > inline Mat(const xtrans_mat& X); // xtrans_mat can only be generated by the Proxy class inline Mat(const subview_cube& X); - inline Mat& operator=(const subview_cube& X); + inline Mat& operator= (const subview_cube& X); inline Mat& operator+=(const subview_cube& X); inline Mat& operator-=(const subview_cube& X); inline Mat& operator*=(const subview_cube& X); @@ -137,7 +144,7 @@ class Mat : public Base< eT, Mat > inline Mat& operator/=(const subview_cube& X); inline Mat(const diagview& X); - inline Mat& operator=(const diagview& X); + inline Mat& operator= (const diagview& X); inline Mat& operator+=(const diagview& X); inline Mat& operator-=(const diagview& X); inline Mat& operator*=(const diagview& X); @@ -162,15 +169,18 @@ class Mat : public Base< eT, Mat > // Operators on sparse matrices (and subviews) template inline explicit Mat(const SpBase& m); - template inline Mat& operator=(const SpBase& m); + template inline Mat& operator= (const SpBase& m); template inline Mat& operator+=(const SpBase& m); template inline Mat& operator-=(const SpBase& m); template inline Mat& operator*=(const SpBase& m); template inline Mat& operator%=(const SpBase& m); template inline Mat& operator/=(const SpBase& m); + inline explicit Mat(const SpSubview& X); + inline Mat& operator= (const SpSubview& X); + inline explicit Mat(const spdiagview& X); - inline Mat& operator=(const spdiagview& X); + inline Mat& operator= (const spdiagview& X); inline Mat& operator+=(const spdiagview& X); inline Mat& operator-=(const spdiagview& X); inline Mat& operator*=(const spdiagview& X); @@ -178,8 +188,8 @@ class Mat : public Base< eT, Mat > inline Mat& operator/=(const spdiagview& X); - inline mat_injector operator<<(const eT val); - inline mat_injector operator<<(const injector_end_of_row<>& x); + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const injector_end_of_row<>& x); arma_inline subview_row row(const uword row_num); @@ -199,17 +209,17 @@ class Mat : public Base< eT, Mat > inline const Col unsafe_col(const uword col_num) const; - arma_inline subview rows(const uword in_row1, const uword in_row2); - arma_inline const subview rows(const uword in_row1, const uword in_row2) const; + arma_inline subview rows(const uword in_row1, const uword in_row2); + arma_inline const subview rows(const uword in_row1, const uword in_row2) const; - arma_inline subview cols(const uword in_col1, const uword in_col2); - arma_inline const subview cols(const uword in_col1, const uword in_col2) const; + arma_inline subview_cols cols(const uword in_col1, const uword in_col2); + arma_inline const subview_cols cols(const uword in_col1, const uword in_col2) const; - inline subview rows(const span& row_span); - inline const subview rows(const span& row_span) const; + inline subview rows(const span& row_span); + inline const subview rows(const span& row_span) const; - arma_inline subview cols(const span& col_span); - arma_inline const subview cols(const span& col_span) const; + arma_inline subview_cols cols(const span& col_span); + arma_inline const subview_cols cols(const span& col_span) const; arma_inline subview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); @@ -233,11 +243,11 @@ class Mat : public Base< eT, Mat > inline subview tail_rows(const uword N); inline const subview tail_rows(const uword N) const; - inline subview head_cols(const uword N); - inline const subview head_cols(const uword N) const; + inline subview_cols head_cols(const uword N); + inline const subview_cols head_cols(const uword N) const; - inline subview tail_cols(const uword N); - inline const subview tail_cols(const uword N) const; + inline subview_cols tail_cols(const uword N); + inline const subview_cols tail_cols(const uword N) const; template arma_inline subview_elem1 elem(const Base& a); template arma_inline const subview_elem1 elem(const Base& a) const; @@ -275,13 +285,11 @@ class Mat : public Base< eT, Mat > template inline const subview_each2< Mat, 0, T1 > each_col(const Base& indices) const; template inline const subview_each2< Mat, 1, T1 > each_row(const Base& indices) const; - #if defined(ARMA_USE_CXX11) - inline const Mat& each_col(const std::function< void( Col&) >& F); + inline Mat& each_col(const std::function< void( Col&) >& F); inline const Mat& each_col(const std::function< void(const Col&) >& F) const; - inline const Mat& each_row(const std::function< void( Row&) >& F); + inline Mat& each_row(const std::function< void( Row&) >& F); inline const Mat& each_row(const std::function< void(const Row&) >& F) const; - #endif arma_inline diagview diag(const sword in_id = 0); @@ -300,15 +308,18 @@ class Mat : public Base< eT, Mat > template inline void shed_rows(const Base& indices); template inline void shed_cols(const Base& indices); - inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero = true); - inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero = true); + arma_deprecated inline void insert_rows(const uword row_num, const uword N, const bool set_to_zero); + arma_deprecated inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero); + + inline void insert_rows(const uword row_num, const uword N); + inline void insert_cols(const uword col_num, const uword N); template inline void insert_rows(const uword row_num, const Base& X); template inline void insert_cols(const uword col_num, const Base& X); template inline Mat(const Gen& X); - template inline Mat& operator=(const Gen& X); + template inline Mat& operator= (const Gen& X); template inline Mat& operator+=(const Gen& X); template inline Mat& operator-=(const Gen& X); template inline Mat& operator*=(const Gen& X); @@ -316,7 +327,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const Gen& X); template inline Mat(const Op& X); - template inline Mat& operator=(const Op& X); + template inline Mat& operator= (const Op& X); template inline Mat& operator+=(const Op& X); template inline Mat& operator-=(const Op& X); template inline Mat& operator*=(const Op& X); @@ -324,7 +335,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const Op& X); template inline Mat(const eOp& X); - template inline Mat& operator=(const eOp& X); + template inline Mat& operator= (const eOp& X); template inline Mat& operator+=(const eOp& X); template inline Mat& operator-=(const eOp& X); template inline Mat& operator*=(const eOp& X); @@ -332,7 +343,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const eOp& X); template inline Mat(const mtOp& X); - template inline Mat& operator=(const mtOp& X); + template inline Mat& operator= (const mtOp& X); template inline Mat& operator+=(const mtOp& X); template inline Mat& operator-=(const mtOp& X); template inline Mat& operator*=(const mtOp& X); @@ -340,7 +351,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const mtOp& X); template inline Mat(const CubeToMatOp& X); - template inline Mat& operator=(const CubeToMatOp& X); + template inline Mat& operator= (const CubeToMatOp& X); template inline Mat& operator+=(const CubeToMatOp& X); template inline Mat& operator-=(const CubeToMatOp& X); template inline Mat& operator*=(const CubeToMatOp& X); @@ -348,7 +359,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const CubeToMatOp& X); template inline Mat(const SpToDOp& X); - template inline Mat& operator=(const SpToDOp& X); + template inline Mat& operator= (const SpToDOp& X); template inline Mat& operator+=(const SpToDOp& X); template inline Mat& operator-=(const SpToDOp& X); template inline Mat& operator*=(const SpToDOp& X); @@ -356,7 +367,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const SpToDOp& X); template inline Mat(const Glue& X); - template inline Mat& operator=(const Glue& X); + template inline Mat& operator= (const Glue& X); template inline Mat& operator+=(const Glue& X); template inline Mat& operator-=(const Glue& X); template inline Mat& operator*=(const Glue& X); @@ -367,7 +378,7 @@ class Mat : public Base< eT, Mat > template inline Mat& operator-=(const Glue& X); template inline Mat(const eGlue& X); - template inline Mat& operator=(const eGlue& X); + template inline Mat& operator= (const eGlue& X); template inline Mat& operator+=(const eGlue& X); template inline Mat& operator-=(const eGlue& X); template inline Mat& operator*=(const eGlue& X); @@ -375,27 +386,40 @@ class Mat : public Base< eT, Mat > template inline Mat& operator/=(const eGlue& X); template inline Mat(const mtGlue& X); - template inline Mat& operator=(const mtGlue& X); + template inline Mat& operator= (const mtGlue& X); template inline Mat& operator+=(const mtGlue& X); template inline Mat& operator-=(const mtGlue& X); template inline Mat& operator*=(const mtGlue& X); template inline Mat& operator%=(const mtGlue& X); template inline Mat& operator/=(const mtGlue& X); + template inline Mat(const SpToDGlue& X); + template inline Mat& operator= (const SpToDGlue& X); + template inline Mat& operator+=(const SpToDGlue& X); + template inline Mat& operator-=(const SpToDGlue& X); + template inline Mat& operator*=(const SpToDGlue& X); + template inline Mat& operator%=(const SpToDGlue& X); + template inline Mat& operator/=(const SpToDGlue& X); + + + arma_warn_unused arma_inline const eT& at_alt (const uword ii) const; - arma_inline arma_warn_unused const eT& at_alt (const uword ii) const; + arma_warn_unused arma_inline eT& operator[] (const uword ii); + arma_warn_unused arma_inline const eT& operator[] (const uword ii) const; + arma_warn_unused arma_inline eT& at (const uword ii); + arma_warn_unused arma_inline const eT& at (const uword ii) const; + arma_warn_unused arma_inline eT& operator() (const uword ii); + arma_warn_unused arma_inline const eT& operator() (const uword ii) const; - arma_inline arma_warn_unused eT& operator[] (const uword ii); - arma_inline arma_warn_unused const eT& operator[] (const uword ii) const; - arma_inline arma_warn_unused eT& at (const uword ii); - arma_inline arma_warn_unused const eT& at (const uword ii) const; - arma_inline arma_warn_unused eT& operator() (const uword ii); - arma_inline arma_warn_unused const eT& operator() (const uword ii) const; + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col) const; + #endif - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; arma_inline const Mat& operator++(); arma_inline void operator++(int); @@ -403,113 +427,109 @@ class Mat : public Base< eT, Mat > arma_inline const Mat& operator--(); arma_inline void operator--(int); - arma_inline arma_warn_unused bool is_empty() const; - arma_inline arma_warn_unused bool is_vec() const; - arma_inline arma_warn_unused bool is_rowvec() const; - arma_inline arma_warn_unused bool is_colvec() const; - arma_inline arma_warn_unused bool is_square() const; - inline arma_warn_unused bool is_finite() const; + arma_warn_unused arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_vec() const; + arma_warn_unused arma_inline bool is_rowvec() const; + arma_warn_unused arma_inline bool is_colvec() const; + arma_warn_unused arma_inline bool is_square() const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; - inline arma_warn_unused bool is_sorted(const char* direction = "ascend") const; - inline arma_warn_unused bool is_sorted(const char* direction, const uword dim) const; + arma_warn_unused inline bool is_sorted(const char* direction = "ascend") const; + arma_warn_unused inline bool is_sorted(const char* direction, const uword dim) const; template - inline arma_warn_unused bool is_sorted_helper(const comparator& comp, const uword dim) const; - - arma_inline arma_warn_unused bool in_range(const uword ii) const; - arma_inline arma_warn_unused bool in_range(const span& x ) const; - - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const span& col_span) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const span& col_span) const; + arma_warn_unused inline bool is_sorted_helper(const comparator& comp, const uword dim) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + arma_warn_unused arma_inline bool in_range(const uword ii) const; + arma_warn_unused arma_inline bool in_range(const span& x ) const; - arma_inline arma_warn_unused eT* colptr(const uword in_col); - arma_inline arma_warn_unused const eT* colptr(const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; - arma_inline arma_warn_unused eT* memptr(); - arma_inline arma_warn_unused const eT* memptr() const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + arma_warn_unused arma_inline eT* colptr(const uword in_col); + arma_warn_unused arma_inline const eT* colptr(const uword in_col) const; - arma_cold inline void impl_print( const std::string& extra_text) const; - arma_cold inline void impl_print(std::ostream& user_stream, const std::string& extra_text) const; - - arma_cold inline void impl_raw_print( const std::string& extra_text) const; - arma_cold inline void impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; template - inline void copy_size(const Base& X); + inline Mat& copy_size(const Base& X); - inline void set_size(const uword in_elem); - inline void set_size(const uword in_rows, const uword in_cols); - inline void set_size(const SizeMat& s); + inline Mat& set_size(const uword new_n_elem); + inline Mat& set_size(const uword new_n_rows, const uword new_n_cols); + inline Mat& set_size(const SizeMat& s); - inline void resize(const uword in_elem); - inline void resize(const uword in_rows, const uword in_cols); - inline void resize(const SizeMat& s); + inline Mat& resize(const uword new_n_elem); + inline Mat& resize(const uword new_n_rows, const uword new_n_cols); + inline Mat& resize(const SizeMat& s); - inline void reshape(const uword in_rows, const uword in_cols); - inline void reshape(const SizeMat& s); + inline Mat& reshape(const uword new_n_rows, const uword new_n_cols); + inline Mat& reshape(const SizeMat& s); - arma_deprecated inline void reshape(const uword in_rows, const uword in_cols, const uword dim); //!< NOTE: don't use this form: it will be removed + arma_deprecated inline void reshape(const uword new_n_rows, const uword new_n_cols, const uword dim); //!< NOTE: don't use this form: it will be removed - template inline const Mat& for_each(functor F); + template inline Mat& for_each(functor F); template inline const Mat& for_each(functor F) const; - template inline const Mat& transform(functor F); - template inline const Mat& imbue(functor F); + template inline Mat& transform(functor F); + template inline Mat& imbue(functor F); + + inline Mat& replace(const eT old_val, const eT new_val); - inline const Mat& replace(const eT old_val, const eT new_val); + inline Mat& clean(const pod_type threshold); - inline const Mat& clean(const pod_type threshold); + inline Mat& clamp(const eT min_val, const eT max_val); - inline const Mat& fill(const eT val); + inline Mat& fill(const eT val); template - inline const Mat& fill(const fill::fill_class& f); + inline Mat& fill(const fill::fill_class& f); - inline const Mat& zeros(); - inline const Mat& zeros(const uword in_elem); - inline const Mat& zeros(const uword in_rows, const uword in_cols); - inline const Mat& zeros(const SizeMat& s); + inline Mat& zeros(); + inline Mat& zeros(const uword new_n_elem); + inline Mat& zeros(const uword new_n_rows, const uword new_n_cols); + inline Mat& zeros(const SizeMat& s); - inline const Mat& ones(); - inline const Mat& ones(const uword in_elem); - inline const Mat& ones(const uword in_rows, const uword in_cols); - inline const Mat& ones(const SizeMat& s); + inline Mat& ones(); + inline Mat& ones(const uword new_n_elem); + inline Mat& ones(const uword new_n_rows, const uword new_n_cols); + inline Mat& ones(const SizeMat& s); - inline const Mat& randu(); - inline const Mat& randu(const uword in_elem); - inline const Mat& randu(const uword in_rows, const uword in_cols); - inline const Mat& randu(const SizeMat& s); + inline Mat& randu(); + inline Mat& randu(const uword new_n_elem); + inline Mat& randu(const uword new_n_rows, const uword new_n_cols); + inline Mat& randu(const SizeMat& s); - inline const Mat& randn(); - inline const Mat& randn(const uword in_elem); - inline const Mat& randn(const uword in_rows, const uword in_cols); - inline const Mat& randn(const SizeMat& s); + inline Mat& randn(); + inline Mat& randn(const uword new_n_elem); + inline Mat& randn(const uword new_n_rows, const uword new_n_cols); + inline Mat& randn(const SizeMat& s); - inline const Mat& eye(); - inline const Mat& eye(const uword in_rows, const uword in_cols); - inline const Mat& eye(const SizeMat& s); + inline Mat& eye(); + inline Mat& eye(const uword new_n_rows, const uword new_n_cols); + inline Mat& eye(const SizeMat& s); - inline arma_cold void reset(); - inline arma_cold void soft_reset(); + arma_cold inline void reset(); + arma_cold inline void soft_reset(); template inline void set_real(const Base& X); template inline void set_imag(const Base& X); - inline arma_warn_unused eT min() const; - inline arma_warn_unused eT max() const; + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; inline eT min(uword& index_of_min_val) const; inline eT max(uword& index_of_max_val) const; @@ -518,21 +538,25 @@ class Mat : public Base< eT, Mat > inline eT max(uword& row_of_max_val, uword& col_of_max_val) const; - inline arma_cold bool save(const std::string name, const file_type type = arma_binary, const bool print_status = true) const; - inline arma_cold bool save(const hdf5_name& spec, const file_type type = hdf5_binary, const bool print_status = true) const; - inline arma_cold bool save( std::ostream& os, const file_type type = arma_binary, const bool print_status = true) const; + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_cold inline bool save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool load(const std::string name, const file_type type = auto_detect, const bool print_status = true); - inline arma_cold bool load(const hdf5_name& spec, const file_type type = hdf5_binary, const bool print_status = true); - inline arma_cold bool load( std::istream& is, const file_type type = auto_detect, const bool print_status = true); + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_cold inline bool load(const csv_name& spec, const file_type type = csv_ascii); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); - inline arma_cold bool quiet_save(const std::string name, const file_type type = arma_binary) const; - inline arma_cold bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; - inline arma_cold bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const hdf5_name& spec, const file_type type = hdf5_binary) const; + arma_deprecated inline bool quiet_save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool quiet_load(const std::string name, const file_type type = auto_detect); - inline arma_cold bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); - inline arma_cold bool quiet_load( std::istream& is, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const hdf5_name& spec, const file_type type = hdf5_binary); + arma_deprecated inline bool quiet_load(const csv_name& spec, const file_type type = csv_ascii); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); // for container-like functionality @@ -554,20 +578,20 @@ class Mat : public Base< eT, Mat > inline row_iterator(); inline row_iterator(const row_iterator& X); - inline row_iterator(Mat& in_M, const uword in_row); + inline row_iterator(Mat& in_M, const uword in_row, const uword in_col); - inline arma_warn_unused eT& operator* (); + arma_warn_unused inline eT& operator* (); - inline row_iterator& operator++(); - inline arma_warn_unused row_iterator operator++(int); + inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); - inline row_iterator& operator--(); - inline arma_warn_unused row_iterator operator--(int); + inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); - inline arma_warn_unused bool operator!=(const row_iterator& X) const; - inline arma_warn_unused bool operator==(const row_iterator& X) const; - inline arma_warn_unused bool operator!=(const const_row_iterator& X) const; - inline arma_warn_unused bool operator==(const const_row_iterator& X) const; + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; typedef std::bidirectional_iterator_tag iterator_category; typedef eT value_type; @@ -576,7 +600,6 @@ class Mat : public Base< eT, Mat > typedef eT& reference; arma_aligned Mat* M; - arma_aligned eT* current_ptr; arma_aligned uword current_row; arma_aligned uword current_col; }; @@ -589,20 +612,20 @@ class Mat : public Base< eT, Mat > inline const_row_iterator(); inline const_row_iterator(const row_iterator& X); inline const_row_iterator(const const_row_iterator& X); - inline const_row_iterator(const Mat& in_M, const uword in_row); + inline const_row_iterator(const Mat& in_M, const uword in_row, const uword in_col); - inline arma_warn_unused const eT& operator*() const; + arma_warn_unused inline const eT& operator*() const; - inline const_row_iterator& operator++(); - inline arma_warn_unused const_row_iterator operator++(int); + inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); - inline const_row_iterator& operator--(); - inline arma_warn_unused const_row_iterator operator--(int); + inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); - inline arma_warn_unused bool operator!=(const row_iterator& X) const; - inline arma_warn_unused bool operator==(const row_iterator& X) const; - inline arma_warn_unused bool operator!=(const const_row_iterator& X) const; - inline arma_warn_unused bool operator==(const const_row_iterator& X) const; + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; typedef std::bidirectional_iterator_tag iterator_category; typedef eT value_type; @@ -611,7 +634,6 @@ class Mat : public Base< eT, Mat > typedef const eT& reference; arma_aligned const Mat* M; - arma_aligned const eT* current_ptr; arma_aligned uword current_row; arma_aligned uword current_col; }; @@ -627,21 +649,21 @@ class Mat : public Base< eT, Mat > inline row_col_iterator(const row_col_iterator& in_it); inline row_col_iterator(Mat& in_M, const uword row = 0, const uword col = 0); - inline arma_warn_unused eT& operator*(); + arma_warn_unused inline eT& operator*(); - inline row_col_iterator& operator++(); - inline arma_warn_unused row_col_iterator operator++(int); + inline row_col_iterator& operator++(); + arma_warn_unused inline row_col_iterator operator++(int); - inline row_col_iterator& operator--(); - inline arma_warn_unused row_col_iterator operator--(int); + inline row_col_iterator& operator--(); + arma_warn_unused inline row_col_iterator operator--(int); - inline arma_warn_unused uword row() const; - inline arma_warn_unused uword col() const; + arma_warn_unused inline uword row() const; + arma_warn_unused inline uword col() const; - inline arma_warn_unused bool operator==(const row_col_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const row_col_iterator& rhs) const; - inline arma_warn_unused bool operator==(const const_row_col_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_row_col_iterator& rhs) const; typedef std::bidirectional_iterator_tag iterator_category; typedef eT value_type; @@ -665,21 +687,21 @@ class Mat : public Base< eT, Mat > inline const_row_col_iterator(const const_row_col_iterator& in_it); inline const_row_col_iterator(const Mat& in_M, const uword row = 0, const uword col = 0); - inline arma_warn_unused const eT& operator*() const; + arma_warn_unused inline const eT& operator*() const; - inline const_row_col_iterator& operator++(); - inline arma_warn_unused const_row_col_iterator operator++(int); + inline const_row_col_iterator& operator++(); + arma_warn_unused inline const_row_col_iterator operator++(int); - inline const_row_col_iterator& operator--(); - inline arma_warn_unused const_row_col_iterator operator--(int); + inline const_row_col_iterator& operator--(); + arma_warn_unused inline const_row_col_iterator operator--(int); - inline arma_warn_unused uword row() const; - inline arma_warn_unused uword col() const; + arma_warn_unused inline uword row() const; + arma_warn_unused inline uword col() const; - inline arma_warn_unused bool operator==(const const_row_col_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_row_col_iterator& rhs) const; - inline arma_warn_unused bool operator==(const row_col_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_row_col_iterator& rhs) const; + arma_warn_unused inline bool operator==(const row_col_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const row_col_iterator& rhs) const; // So that we satisfy the STL iterator types. typedef std::bidirectional_iterator_tag iterator_category; @@ -726,15 +748,16 @@ class Mat : public Base< eT, Mat > inline bool empty() const; inline uword size() const; - inline eT& front(); - inline const eT& front() const; + arma_warn_unused inline eT& front(); + arma_warn_unused inline const eT& front() const; - inline eT& back(); - inline const eT& back() const; + arma_warn_unused inline eT& back(); + arma_warn_unused inline const eT& back() const; inline void swap(Mat& B); - inline void steal_mem(Mat& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Mat& X); //!< don't use this unless you're writing code internal to Armadillo + inline void steal_mem(Mat& X, const bool is_move); //!< don't use this unless you're writing code internal to Armadillo inline void steal_mem_col(Mat& X, const uword max_n_rows); @@ -745,14 +768,12 @@ class Mat : public Base< eT, Mat > protected: inline void init_cold(); - inline void init_warm(uword in_rows, uword in_cols); + inline void init_warm(uword in_n_rows, uword in_n_cols); - inline arma_cold void init(const std::string& text); + arma_cold inline void init(const std::string& text); - #if defined(ARMA_USE_CXX11) - inline void init(const std::initializer_list& list); - inline void init(const std::initializer_list< std::initializer_list >& list); - #endif + inline void init(const std::initializer_list& list); + inline void init(const std::initializer_list< std::initializer_list >& list); template inline void init(const Base& A, const Base& B); @@ -778,7 +799,7 @@ class Mat : public Base< eT, Mat > public: - #ifdef ARMA_EXTRA_MAT_PROTO + #if defined(ARMA_EXTRA_MAT_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_MAT_PROTO) #endif }; @@ -791,8 +812,8 @@ class Mat::fixed : public Mat { private: - static const uword fixed_n_elem = fixed_n_rows * fixed_n_cols; - static const bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + static constexpr uword fixed_n_elem = fixed_n_rows * fixed_n_cols; + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; @@ -804,9 +825,9 @@ class Mat::fixed : public Mat typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_col = (fixed_n_cols == 1); - static const bool is_row = (fixed_n_rows == 1); - static const bool is_xvec = false; + static constexpr bool is_col = (fixed_n_cols == 1); + static constexpr bool is_row = (fixed_n_rows == 1); + static constexpr bool is_xvec = false; static const uword n_rows; // value provided below the class definition static const uword n_cols; // value provided below the class definition @@ -815,6 +836,7 @@ class Mat::fixed : public Mat arma_inline fixed(); arma_inline fixed(const fixed& X); + inline fixed(const fill::scalar_holder f); template inline fixed(const fill::fill_class& f); template inline fixed(const Base& A); template inline fixed(const Base& A, const Base& B); @@ -827,13 +849,11 @@ class Mat::fixed : public Mat using Mat::operator=; using Mat::operator(); - #if defined(ARMA_USE_CXX11) - inline fixed(const std::initializer_list& list); - inline Mat& operator=(const std::initializer_list& list); - - inline fixed(const std::initializer_list< std::initializer_list >& list); - inline Mat& operator=(const std::initializer_list< std::initializer_list >& list); - #endif + inline fixed(const std::initializer_list& list); + inline Mat& operator=(const std::initializer_list& list); + + inline fixed(const std::initializer_list< std::initializer_list >& list); + inline Mat& operator=(const std::initializer_list< std::initializer_list >& list); arma_inline Mat& operator=(const fixed& X); @@ -842,31 +862,36 @@ class Mat::fixed : public Mat template inline Mat& operator=(const eGlue& X); #endif - arma_inline const Op< Mat_fixed_type, op_htrans > t() const; - arma_inline const Op< Mat_fixed_type, op_htrans > ht() const; - arma_inline const Op< Mat_fixed_type, op_strans > st() const; + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Mat_fixed_type, op_strans > st() const; + + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; - arma_inline arma_warn_unused const eT& at_alt (const uword i) const; + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; - arma_inline arma_warn_unused eT& operator[] (const uword i); - arma_inline arma_warn_unused const eT& operator[] (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword i); - arma_inline arma_warn_unused const eT& at (const uword i) const; - arma_inline arma_warn_unused eT& operator() (const uword i); - arma_inline arma_warn_unused const eT& operator() (const uword i) const; + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline eT& operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator[] (const uword in_row, const uword in_col) const; + #endif - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT* colptr(const uword in_col); - arma_inline arma_warn_unused const eT* colptr(const uword in_col) const; + arma_warn_unused arma_inline eT* colptr(const uword in_col); + arma_warn_unused arma_inline const eT* colptr(const uword in_col) const; - arma_inline arma_warn_unused eT* memptr(); - arma_inline arma_warn_unused const eT* memptr() const; + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; - arma_inline arma_warn_unused bool is_vec() const; + arma_warn_unused arma_inline bool is_vec() const; inline const Mat& fill(const eT val); inline const Mat& zeros(); diff --git a/src/armadillo_bits/Mat_meat.hpp b/src/armadillo_bits/Mat_meat.hpp index 9d78a01d..cb92ffde 100644 --- a/src/armadillo_bits/Mat_meat.hpp +++ b/src/armadillo_bits/Mat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -24,13 +26,14 @@ Mat::~Mat() { arma_extra_debug_sigprint_this(this); - if( (mem_state == 0) && (n_elem > arma_config::mat_prealloc) ) + if(n_alloc > 0) { + arma_extra_debug_print("Mat::destructor: releasing memory"); memory::release( access::rw(mem) ); } // try to expose buggy user code that accesses deleted objects - if(arma_config::debug) { access::rw(mem) = 0; } + if(arma_config::debug) { access::rw(mem) = nullptr; } arma_type_check(( is_supported_elem_type::value == false )); } @@ -43,6 +46,7 @@ Mat::Mat() : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -59,6 +63,7 @@ Mat::Mat(const uword in_n_rows, const uword in_n_cols) : n_rows(in_n_rows) , n_cols(in_n_cols) , n_elem(in_n_rows*in_n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -66,6 +71,12 @@ Mat::Mat(const uword in_n_rows, const uword in_n_cols) arma_extra_debug_sigprint_this(this); init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } } @@ -76,6 +87,33 @@ Mat::Mat(const SizeMat& s) : n_rows(s.n_rows) , n_cols(s.n_cols) , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -83,6 +121,38 @@ Mat::Mat(const SizeMat& s) arma_extra_debug_sigprint_this(this); init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } + } + + + +//! internal use only +template +template +inline +Mat::Mat(const SizeMat& s, const arma_initmode_indicator&) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + if(do_zeros) + { + arma_extra_debug_print("Mat::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } } @@ -95,6 +165,7 @@ Mat::Mat(const uword in_n_rows, const uword in_n_cols, const fill::fill_clas : n_rows(in_n_rows) , n_cols(in_n_cols) , n_elem(in_n_rows*in_n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -115,6 +186,7 @@ Mat::Mat(const SizeMat& s, const fill::fill_class& f) : n_rows(s.n_rows) , n_cols(s.n_cols) , n_elem(s.n_rows*s.n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -128,6 +200,47 @@ Mat::Mat(const SizeMat& s, const fill::fill_class& f) +//! construct the matrix to have user specified dimensions and fill with specified value +template +inline +Mat::Mat(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : n_rows(in_n_rows) + , n_cols(in_n_cols) + , n_elem(in_n_rows*in_n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + +template +inline +Mat::Mat(const SizeMat& s, const fill::scalar_holder f) + : n_rows(s.n_rows) + , n_cols(s.n_cols) + , n_elem(s.n_rows*s.n_cols) + , n_alloc() + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + init_cold(); + + (*this).fill(f.scalar); + } + + + //! constructor used by Row and Col classes template inline @@ -135,6 +248,7 @@ Mat::Mat(const arma_vec_indicator&, const uhword in_vec_state) : n_rows( (in_vec_state == 2) ? 1 : 0 ) , n_cols( (in_vec_state == 1) ? 1 : 0 ) , n_elem(0) + , n_alloc(0) , vec_state(in_vec_state) , mem_state(0) , mem() @@ -151,6 +265,7 @@ Mat::Mat(const arma_vec_indicator&, const uword in_n_rows, const uword in_n_ : n_rows(in_n_rows) , n_cols(in_n_cols) , n_elem(in_n_rows*in_n_cols) + , n_alloc() , vec_state(in_vec_state) , mem_state(0) , mem() @@ -168,6 +283,7 @@ Mat::Mat(const arma_fixed_indicator&, const uword in_n_rows, const uword in_ : n_rows (in_n_rows) , n_cols (in_n_cols) , n_elem (in_n_rows*in_n_cols) + , n_alloc (0) , vec_state (in_vec_state) , mem_state (3) , mem (in_mem) @@ -182,14 +298,14 @@ inline void Mat::init_cold() { - arma_extra_debug_sigprint( arma_str::format("n_rows = %d, n_cols = %d") % n_rows % n_cols ); + arma_extra_debug_sigprint( arma_str::format("n_rows = %u, n_cols = %u") % n_rows % n_cols ); // ensure that n_elem can hold the result of (n_rows * n_cols) - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "Mat::init(): requested size is too large"; #else - const char* error_message = "Mat::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "Mat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_check @@ -204,20 +320,17 @@ Mat::init_cold() if(n_elem <= arma_config::mat_prealloc) { - if(n_elem == 0) - { - access::rw(mem) = NULL; - } - else - { - arma_extra_debug_print("Mat::init(): using local memory"); - access::rw(mem) = mem_local; - } + if(n_elem > 0) { arma_extra_debug_print("Mat::init(): using local memory"); } + + access::rw(mem) = (n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; } else { arma_extra_debug_print("Mat::init(): acquiring memory"); - access::rw(mem) = memory::acquire(n_elem); + + access::rw(mem) = memory::acquire(n_elem); + access::rw(n_alloc) = n_elem; } } @@ -228,17 +341,21 @@ inline void Mat::init_warm(uword in_n_rows, uword in_n_cols) { - arma_extra_debug_sigprint( arma_str::format("in_n_rows = %d, in_n_cols = %d") % in_n_rows % in_n_cols ); + arma_extra_debug_sigprint( arma_str::format("in_n_rows = %u, in_n_cols = %u") % in_n_rows % in_n_cols ); if( (n_rows == in_n_rows) && (n_cols == in_n_cols) ) { return; } bool err_state = false; - char* err_msg = 0; + char* err_msg = nullptr; const uhword t_vec_state = vec_state; const uhword t_mem_state = mem_state; - arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), "Mat::init(): size is fixed and hence cannot be changed" ); + const char* error_message_1 = "Mat::init(): size is fixed and hence cannot be changed"; + const char* error_message_2 = "Mat::init(): requested size is not compatible with column vector layout"; + const char* error_message_3 = "Mat::init(): requested size is not compatible with row vector layout"; + + arma_debug_set_error( err_state, err_msg, (t_mem_state == 3), error_message_1 ); if(t_vec_state > 0) { @@ -249,17 +366,17 @@ Mat::init_warm(uword in_n_rows, uword in_n_cols) } else { - if(t_vec_state == 1) { arma_debug_set_error( err_state, err_msg, (in_n_cols != 1), "Mat::init(): requested size is not compatible with column vector layout" ); } - if(t_vec_state == 2) { arma_debug_set_error( err_state, err_msg, (in_n_rows != 1), "Mat::init(): requested size is not compatible with row vector layout" ); } + if(t_vec_state == 1) { arma_debug_set_error( err_state, err_msg, (in_n_cols != 1), error_message_2 ); } + if(t_vec_state == 2) { arma_debug_set_error( err_state, err_msg, (in_n_rows != 1), error_message_3 ); } } } // ensure that n_elem can hold the result of (n_rows * n_cols) - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) - const char* error_message = "Mat::init(): requested size is too large"; + #if defined(ARMA_64BIT_WORD) + const char* error_message_4 = "Mat::init(): requested size is too large"; #else - const char* error_message = "Mat::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message_4 = "Mat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_set_error @@ -271,7 +388,7 @@ Mat::init_warm(uword in_n_rows, uword in_n_cols) ? ( (double(in_n_rows) * double(in_n_cols)) > double(ARMA_MAX_UWORD) ) : false ), - error_message + error_message_4 ); arma_debug_check(err_state, err_msg); @@ -284,62 +401,55 @@ Mat::init_warm(uword in_n_rows, uword in_n_cols) arma_extra_debug_print("Mat::init(): reusing memory"); access::rw(n_rows) = in_n_rows; access::rw(n_cols) = in_n_cols; + return; } - else // condition: old_n_elem != new_n_elem + + arma_debug_check( (t_mem_state == 2), "Mat::init(): mismatch between size of auxiliary memory and requested size" ); + + if(new_n_elem <= arma_config::mat_prealloc) { - arma_debug_check( (t_mem_state == 2), "Mat::init(): mismatch between size of auxiliary memory and requested size" ); - - if(new_n_elem < old_n_elem) // reuse existing memory if possible + if(n_alloc > 0) { - if( (t_mem_state == 0) && (new_n_elem <= arma_config::mat_prealloc) ) - { - if(old_n_elem > arma_config::mat_prealloc) - { - arma_extra_debug_print("Mat::init(): releasing memory"); - memory::release( access::rw(mem) ); - } - - if(new_n_elem == 0) - { - access::rw(mem) = NULL; - } - else - { - arma_extra_debug_print("Mat::init(): using local memory"); - access::rw(mem) = mem_local; - } - } - else - { - arma_extra_debug_print("Mat::init(): reusing memory"); - } + arma_extra_debug_print("Mat::init(): releasing memory"); + memory::release( access::rw(mem) ); } - else // condition: new_n_elem > old_n_elem + + if(new_n_elem > 0) { arma_extra_debug_print("Mat::init(): using local memory"); } + + access::rw(mem) = (new_n_elem == 0) ? nullptr : mem_local; + access::rw(n_alloc) = 0; + } + else // condition: new_n_elem > arma_config::mat_prealloc + { + if(new_n_elem > n_alloc) { - if( (t_mem_state == 0) && (old_n_elem > arma_config::mat_prealloc) ) + if(n_alloc > 0) { arma_extra_debug_print("Mat::init(): releasing memory"); memory::release( access::rw(mem) ); + + // in case memory::acquire() throws an exception + access::rw(mem) = nullptr; + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem) = 0; + access::rw(n_alloc) = 0; } - if(new_n_elem <= arma_config::mat_prealloc) - { - arma_extra_debug_print("Mat::init(): using local memory"); - access::rw(mem) = mem_local; - } - else - { - arma_extra_debug_print("Mat::init(): acquiring memory"); - access::rw(mem) = memory::acquire(new_n_elem); - } - - access::rw(mem_state) = 0; + arma_extra_debug_print("Mat::init(): acquiring memory"); + access::rw(mem) = memory::acquire(new_n_elem); + access::rw(n_alloc) = new_n_elem; + } + else // condition: new_n_elem <= n_alloc + { + arma_extra_debug_print("Mat::init(): reusing memory"); } - - access::rw(n_rows) = in_n_rows; - access::rw(n_cols) = in_n_cols; - access::rw(n_elem) = new_n_elem; } + + access::rw(n_rows) = in_n_rows; + access::rw(n_cols) = in_n_cols; + access::rw(n_elem) = new_n_elem; + access::rw(mem_state) = 0; } @@ -347,11 +457,11 @@ Mat::init_warm(uword in_n_rows, uword in_n_cols) //! create the matrix from a textual description template inline -arma_cold Mat::Mat(const char* text) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -366,13 +476,13 @@ Mat::Mat(const char* text) //! create the matrix from a textual description template inline -arma_cold Mat& Mat::operator=(const char* text) { arma_extra_debug_sigprint(); init( std::string(text) ); + return *this; } @@ -381,11 +491,11 @@ Mat::operator=(const char* text) //! create the matrix from a textual description template inline -arma_cold Mat::Mat(const std::string& text) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -400,13 +510,13 @@ Mat::Mat(const std::string& text) //! create the matrix from a textual description template inline -arma_cold Mat& Mat::operator=(const std::string& text) { arma_extra_debug_sigprint(); init(text); + return *this; } @@ -415,7 +525,6 @@ Mat::operator=(const std::string& text) //! internal function to create the matrix from a textual description template inline -arma_cold void Mat::init(const std::string& text_orig) { @@ -537,6 +646,7 @@ Mat::Mat(const std::vector& x) : n_rows(uword(x.size())) , n_cols(1) , n_elem(uword(x.size())) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -545,10 +655,7 @@ Mat::Mat(const std::vector& x) init_cold(); - if(n_elem > 0) - { - arrayops::copy( memptr(), &(x[0]), n_elem ); - } + if(n_elem > 0) { arrayops::copy( memptr(), &(x[0]), n_elem ); } } @@ -563,143 +670,132 @@ Mat::operator=(const std::vector& x) init_warm(uword(x.size()), 1); - if(x.size() > 0) - { - arrayops::copy( memptr(), &(x[0]), uword(x.size()) ); - } + if(x.size() > 0) { arrayops::copy( memptr(), &(x[0]), uword(x.size()) ); } return *this; } -#if defined(ARMA_USE_CXX11) - - template - inline - Mat::Mat(const std::initializer_list& list) - : n_rows(0) - , n_cols(0) - , n_elem(0) - , vec_state(0) - , mem_state(0) - , mem() - { - arma_extra_debug_sigprint_this(this); - - init(list); - } - - - - template - inline - Mat& - Mat::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - init(list); - - return *this; - } +template +inline +Mat::Mat(const std::initializer_list& list) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + init(list); + } + + + +template +inline +Mat& +Mat::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + init(list); - template - inline - Mat::Mat(const std::initializer_list< std::initializer_list >& list) - : n_rows(0) - , n_cols(0) - , n_elem(0) - , vec_state(0) - , mem_state(0) - , mem() - { - arma_extra_debug_sigprint_this(this); - - init(list); - } + return *this; + } + + + +template +inline +Mat::Mat(const std::initializer_list< std::initializer_list >& list) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + init(list); + } + + + +template +inline +Mat& +Mat::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + init(list); - template - inline - Mat& - Mat::operator=(const std::initializer_list< std::initializer_list >& list) - { - arma_extra_debug_sigprint(); - - init(list); - - return *this; - } + return *this; + } - template - inline - Mat::Mat(Mat&& X) - : n_rows (X.n_rows) - , n_cols (X.n_cols) - , n_elem (X.n_elem) - , vec_state(0 ) - , mem_state(0 ) - , mem ( ) +template +inline +Mat::Mat(Mat&& X) + : n_rows (X.n_rows ) + , n_cols (X.n_cols ) + , n_elem (X.n_elem ) + , n_alloc (X.n_alloc) + , vec_state(0 ) + , mem_state(0 ) + , mem ( ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + access::rw(mem_state) = X.mem_state; + access::rw(mem) = X.mem; - if( ((X.mem_state == 0) && (X.n_elem > arma_config::mat_prealloc)) || (X.mem_state == 1) || (X.mem_state == 2) ) - { - access::rw(mem_state) = X.mem_state; - access::rw(mem) = X.mem; - - access::rw(X.n_rows) = 0; - access::rw(X.n_cols) = 0; - access::rw(X.n_elem) = 0; - access::rw(X.mem_state) = 0; - access::rw(X.mem) = 0; - } - else - { - init_cold(); - - arrayops::copy( memptr(), X.mem, X.n_elem ); - - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) ) - { - access::rw(X.n_rows) = 0; - access::rw(X.n_cols) = 0; - access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; - } - } + access::rw(X.n_rows) = 0; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; } - - - - template - inline - Mat& - Mat::operator=(Mat&& X) + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + init_cold(); - (*this).steal_mem(X); + arrayops::copy( memptr(), X.mem, X.n_elem ); - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) && (this != &X) ) + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) { access::rw(X.n_rows) = 0; access::rw(X.n_cols) = 0; access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; + access::rw(X.mem) = nullptr; } - - return *this; } + } + + + +template +inline +Mat& +Mat::operator=(Mat&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); -#endif + (*this).steal_mem(X, true); + return *this; + } + //! Set the matrix to be equal to the specified scalar. @@ -712,7 +808,9 @@ Mat::operator=(const eT val) arma_extra_debug_sigprint(); init_warm(1,1); + access::rw(mem[0]) = val; + return *this; } @@ -785,6 +883,7 @@ Mat::Mat(const Mat& in_mat) : n_rows(in_mat.n_rows) , n_cols(in_mat.n_cols) , n_elem(in_mat.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -818,86 +917,74 @@ Mat::operator=(const Mat& in_mat) -#if defined(ARMA_USE_CXX11) +template +inline +void +Mat::init(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); - template - inline - void - Mat::init(const std::initializer_list& list) + const uword N = uword(list.size()); + + set_size(1, N); + + if(N > 0) { arrayops::copy( memptr(), list.begin(), N ); } + } + + + +template +inline +void +Mat::init(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + uword x_n_rows = uword(list.size()); + uword x_n_cols = 0; + + auto it = list.begin(); + auto it_end = list.end(); + + for(; it != it_end; ++it) { x_n_cols = (std::max)(x_n_cols, uword((*it).size())); } + + Mat& t = (*this); + + if(t.mem_state == 3) { - arma_extra_debug_sigprint(); - - const uword N = uword(list.size()); - - set_size(1, N); - - arrayops::copy( memptr(), list.begin(), N ); + arma_debug_check( ((x_n_rows != t.n_rows) || (x_n_cols != t.n_cols)), "Mat::init(): size mismatch between fixed size matrix and initialiser list" ); + } + else + { + t.set_size(x_n_rows, x_n_cols); } + uword row_num = 0; + auto row_it = list.begin(); + auto row_it_end = list.end(); - template - inline - void - Mat::init(const std::initializer_list< std::initializer_list >& list) + for(; row_it != row_it_end; ++row_it) { - arma_extra_debug_sigprint(); - - uword x_n_rows = uword(list.size()); - uword x_n_cols = 0; + uword col_num = 0; - bool x_n_cols_found = false; + auto col_it = (*row_it).begin(); + auto col_it_end = (*row_it).end(); - auto it = list.begin(); - auto it_end = list.end(); - - for(; it != it_end; ++it) + for(; col_it != col_it_end; ++col_it) { - if(x_n_cols_found == false) - { - x_n_cols = uword((*it).size()); - x_n_cols_found = true; - } - else - { - arma_check( (uword((*it).size()) != x_n_cols), "Mat::init(): inconsistent number of columns in initialiser list" ); - } + t.at(row_num, col_num) = (*col_it); + ++col_num; } - Mat& t = (*this); - - if(t.mem_state == 3) - { - arma_debug_check( ((x_n_rows != t.n_rows) || (x_n_cols != t.n_cols)), "Mat::init(): size mismatch between fixed size matrix and initialiser list" ); - } - else + for(uword c=col_num; c < x_n_cols; ++c) { - t.set_size(x_n_rows, x_n_cols); + t.at(row_num, c) = eT(0); } - uword row_num = 0; - - auto row_it = list.begin(); - auto row_it_end = list.end(); - - for(; row_it != row_it_end; ++row_it) - { - uword col_num = 0; - - auto col_it = (*row_it).begin(); - auto col_it_end = (*row_it).end(); - - for(; col_it != col_it_end; ++col_it) - { - t.at(row_num, col_num) = (*col_it); - ++col_num; - } - - ++row_num; - } + ++row_num; } - -#endif + } @@ -998,8 +1085,8 @@ Mat::swap(Mat& B) const uword A_n_elem = A.n_elem; const uword B_n_elem = B.n_elem; - const bool A_use_local_mem = (A_n_elem <= arma_config::mat_prealloc); - const bool B_use_local_mem = (B_n_elem <= arma_config::mat_prealloc); + const bool A_use_local_mem = (A.n_alloc <= arma_config::mat_prealloc); + const bool B_use_local_mem = (B.n_alloc <= arma_config::mat_prealloc); if( (A_use_local_mem == false) && (B_use_local_mem == false) ) { @@ -1041,9 +1128,10 @@ Mat::swap(Mat& B) access::rw(A.mem) = A_mem_local; } - std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); - std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); - std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); + std::swap( access::rw(A.n_rows), access::rw(B.n_rows) ); + std::swap( access::rw(A.n_cols), access::rw(B.n_cols) ); + std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); + std::swap( access::rw(A.n_alloc), access::rw(B.n_alloc) ); } else if( (A_mem_state <= 2) && (B_mem_state <= 2) && (A.n_elem == B.n_elem) && layout_ok ) @@ -1096,7 +1184,19 @@ Mat::swap(Mat& B) template inline void -Mat::steal_mem(Mat& x) +Mat::steal_mem(Mat& x) + { + arma_extra_debug_sigprint(); + + (*this).steal_mem(x, false); + } + + + +template +inline +void +Mat::steal_mem(Mat& x, const bool is_move) { arma_extra_debug_sigprint(); @@ -1105,44 +1205,48 @@ Mat::steal_mem(Mat& x) const uword x_n_rows = x.n_rows; const uword x_n_cols = x.n_cols; const uword x_n_elem = x.n_elem; + const uword x_n_alloc = x.n_alloc; const uhword x_vec_state = x.vec_state; const uhword x_mem_state = x.mem_state; const uhword t_vec_state = vec_state; const uhword t_mem_state = mem_state; - bool layout_ok = false; - - if(t_vec_state == x_vec_state) - { - layout_ok = true; - } - else - { - if( (t_vec_state == 1) && (x_n_cols == 1) ) { layout_ok = true; } - if( (t_vec_state == 2) && (x_n_rows == 1) ) { layout_ok = true; } - } - + const bool layout_ok = (t_vec_state == x_vec_state) || ((t_vec_state == 1) && (x_n_cols == 1)) || ((t_vec_state == 2) && (x_n_rows == 1)); - if( (t_mem_state <= 1) && ( ((x_mem_state == 0) && (x_n_elem > arma_config::mat_prealloc)) || (x_mem_state == 1) ) && layout_ok ) + if( layout_ok && (t_mem_state <= 1) && ( (x_n_alloc > arma_config::mat_prealloc) || (x_mem_state == 1) || (is_move && (x_mem_state == 2)) ) ) { + arma_extra_debug_print("Mat::steal_mem(): stealing memory"); + reset(); access::rw(n_rows) = x_n_rows; access::rw(n_cols) = x_n_cols; access::rw(n_elem) = x_n_elem; + access::rw(n_alloc) = x_n_alloc; access::rw(mem_state) = x_mem_state; access::rw(mem) = x.mem; - access::rw(x.n_rows) = 0; - access::rw(x.n_cols) = 0; + access::rw(x.n_rows) = (x_vec_state == 2) ? 1 : 0; + access::rw(x.n_cols) = (x_vec_state == 1) ? 1 : 0; access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; access::rw(x.mem_state) = 0; - access::rw(x.mem) = 0; + access::rw(x.mem) = nullptr; } else { + arma_extra_debug_print("Mat::steal_mem(): copying memory"); + (*this).operator=(x); + + if( (is_move) && (x_mem_state == 0) && (x_n_alloc <= arma_config::mat_prealloc) ) + { + access::rw(x.n_rows) = (x_vec_state == 2) ? 1 : 0; + access::rw(x.n_cols) = (x_vec_state == 1) ? 1 : 0; + access::rw(x.n_elem) = 0; + access::rw(x.mem) = nullptr; + } } } @@ -1156,6 +1260,7 @@ Mat::steal_mem_col(Mat& x, const uword max_n_rows) arma_extra_debug_sigprint(); const uword x_n_elem = x.n_elem; + const uword x_n_alloc = x.n_alloc; const uhword x_mem_state = x.mem_state; const uhword t_vec_state = vec_state; @@ -1172,7 +1277,7 @@ Mat::steal_mem_col(Mat& x, const uword max_n_rows) if( (this != &x) && (t_vec_state <= 1) && (t_mem_state <= 1) && (x_mem_state <= 1) ) { - if( (x_mem_state == 0) && ((x_n_elem <= arma_config::mat_prealloc) || (alt_n_rows <= arma_config::mat_prealloc)) ) + if( (x_mem_state == 0) && ((x_n_alloc <= arma_config::mat_prealloc) || (alt_n_rows <= arma_config::mat_prealloc)) ) { (*this).set_size(alt_n_rows, uword(1)); @@ -1185,19 +1290,21 @@ Mat::steal_mem_col(Mat& x, const uword max_n_rows) access::rw(n_rows) = alt_n_rows; access::rw(n_cols) = 1; access::rw(n_elem) = alt_n_rows; + access::rw(n_alloc) = x_n_alloc; access::rw(mem_state) = x_mem_state; access::rw(mem) = x.mem; access::rw(x.n_rows) = 0; access::rw(x.n_cols) = 0; access::rw(x.n_elem) = 0; + access::rw(x.n_alloc) = 0; access::rw(x.mem_state) = 0; - access::rw(x.mem) = 0; + access::rw(x.mem) = nullptr; } } else { - Mat tmp(alt_n_rows, 1); + Mat tmp(alt_n_rows, 1, arma_nozeros_indicator()); arrayops::copy( tmp.memptr(), x.memptr(), alt_n_rows ); @@ -1218,9 +1325,10 @@ Mat::Mat(eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols, const : n_rows ( aux_n_rows ) , n_cols ( aux_n_cols ) , n_elem ( aux_n_rows*aux_n_cols ) + , n_alloc ( 0 ) , vec_state( 0 ) , mem_state( copy_aux_mem ? 0 : ( strict ? 2 : 1 ) ) - , mem ( copy_aux_mem ? 0 : aux_mem ) + , mem ( copy_aux_mem ? nullptr : aux_mem ) { arma_extra_debug_sigprint_this(this); @@ -1242,6 +1350,7 @@ Mat::Mat(const eT* aux_mem, const uword aux_n_rows, const uword aux_n_cols) : n_rows(aux_n_rows) , n_cols(aux_n_cols) , n_elem(aux_n_rows*aux_n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -1265,11 +1374,13 @@ Mat::Mat(const char junk, const eT* aux_mem, const uword aux_n_rows, const u : n_rows (aux_n_rows ) , n_cols (aux_n_cols ) , n_elem (aux_n_rows*aux_n_cols) + , n_alloc (0 ) , vec_state(0 ) , mem_state(3 ) , mem (aux_mem ) { arma_extra_debug_sigprint_this(this); + arma_ignore(junk); } @@ -1365,6 +1476,7 @@ Mat::Mat(const BaseCube& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -1816,6 +1928,7 @@ Mat::Mat : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -1833,9 +1946,10 @@ Mat::Mat(const subview& X, const bool use_colmem) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc(0) , vec_state(0) , mem_state(use_colmem ? 3 : 0) - , mem (use_colmem ? X.colptr(0) : NULL) + , mem (use_colmem ? X.colptr(0) : nullptr) { arma_extra_debug_sigprint_this(this); @@ -1853,13 +1967,14 @@ Mat::Mat(const subview& X, const bool use_colmem) -//! construct a matrix from subview (e.g. construct a matrix from a delayed submatrix operation) +//! construct a matrix from subview (eg. construct a matrix from a delayed submatrix operation) template inline Mat::Mat(const subview& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -1873,7 +1988,7 @@ Mat::Mat(const subview& X) -//! construct a matrix from subview (e.g. construct a matrix from a delayed submatrix operation) +//! construct a matrix from subview (eg. construct a matrix from a delayed submatrix operation) template inline Mat& @@ -1980,6 +2095,7 @@ Mat::Mat(const subview_row_strans& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -1999,6 +2115,7 @@ Mat::Mat(const subview_row_htrans& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -2018,6 +2135,7 @@ Mat::Mat(const xvec_htrans& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -2038,6 +2156,7 @@ Mat::Mat(const xtrans_mat& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -2058,6 +2177,7 @@ Mat::Mat(const subview_cube& x) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -2123,6 +2243,7 @@ Mat::operator*=(const subview_cube& X) arma_extra_debug_sigprint(); const Mat tmp(X); + glue_times::apply_inplace(*this, tmp); return *this; @@ -2160,13 +2281,14 @@ Mat::operator/=(const subview_cube& X) -//! construct a matrix from diagview (e.g. construct a matrix from a delayed diag operation) +//! construct a matrix from diagview (eg. construct a matrix from a delayed diag operation) template inline Mat::Mat(const diagview& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -2180,7 +2302,7 @@ Mat::Mat(const diagview& X) -//! construct a matrix from diagview (e.g. construct a matrix from a delayed diag operation) +//! construct a matrix from diagview (eg. construct a matrix from a delayed diag operation) template inline Mat& @@ -2290,6 +2412,7 @@ Mat::Mat(const subview_elem1& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -2398,6 +2521,7 @@ Mat::Mat(const subview_elem2& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -2506,6 +2630,7 @@ Mat::Mat(const SpBase& m) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -2525,46 +2650,30 @@ Mat::operator=(const SpBase& m) { arma_extra_debug_sigprint(); - if( (is_SpMat::value) || (is_SpMat::stored_type>::value) ) + const unwrap_spmat U(m.get_ref()); + const SpMat& x = U.M; + + const uword x_n_cols = x.n_cols; + + (*this).zeros(x.n_rows, x_n_cols); + + if(x.n_nonzero == 0) { return *this; } + + const eT* x_values = x.values; + const uword* x_row_indices = x.row_indices; + const uword* x_col_ptrs = x.col_ptrs; + + for(uword x_col = 0; x_col < x_n_cols; ++x_col) { - const unwrap_spmat U(m.get_ref()); - const SpMat& x = U.M; - - const uword x_n_cols = x.n_cols; + const uword start = x_col_ptrs[x_col ]; + const uword end = x_col_ptrs[x_col + 1]; - (*this).zeros(x.n_rows, x_n_cols); - - const eT* x_values = x.values; - const uword* x_row_indices = x.row_indices; - const uword* x_col_ptrs = x.col_ptrs; - - for(uword x_col = 0; x_col < x_n_cols; ++x_col) + for(uword i = start; i < end; ++i) { - const uword start = x_col_ptrs[x_col ]; - const uword end = x_col_ptrs[x_col + 1]; + const uword x_row = x_row_indices[i]; + const eT x_val = x_values[i]; - for(uword i = start; i < end; ++i) - { - const uword x_row = x_row_indices[i]; - const eT x_val = x_values[i]; - - at(x_row, x_col) = x_val; - } - } - } - else - { - const SpProxy p(m.get_ref()); - - (*this).zeros(p.get_n_rows(), p.get_n_cols()); - - typename SpProxy::const_iterator_type it = p.begin(); - typename SpProxy::const_iterator_type it_end = p.end(); - - while(it != it_end) - { - at(it.row(), it.col()) = (*it); - ++it; + at(x_row, x_col) = x_val; } } @@ -2588,11 +2697,7 @@ Mat::operator+=(const SpBase& m) typename SpProxy::const_iterator_type it = p.begin(); typename SpProxy::const_iterator_type it_end = p.end(); - while(it != it_end) - { - at(it.row(), it.col()) += (*it); - ++it; - } + for(; it != it_end; ++it) { at(it.row(), it.col()) += (*it); } return *this; } @@ -2614,11 +2719,7 @@ Mat::operator-=(const SpBase& m) typename SpProxy::const_iterator_type it = p.begin(); typename SpProxy::const_iterator_type it_end = p.end(); - while(it != it_end) - { - at(it.row(), it.col()) -= (*it); - ++it; - } + for(; it != it_end; ++it) { at(it.row(), it.col()) -= (*it); } return *this; } @@ -2658,7 +2759,7 @@ Mat::operator%=(const SpBase& m) typename SpProxy::const_iterator_type it_end = p.end(); // We have to zero everything that isn't being used. - arrayops::inplace_set(memptr(), eT(0), (it.col() * n_rows) + it.row()); + arrayops::fill_zeros(memptr(), (it.col() * n_rows) + it.row()); while(it != it_end) { @@ -2672,7 +2773,7 @@ Mat::operator%=(const SpBase& m) ? (p.get_n_cols() * n_rows) : (it.col() * n_rows) + it.row(); - arrayops::inplace_set(memptr() + cur_loc + 1, eT(0), (next_loc - cur_loc - 1)); + arrayops::fill_zeros(memptr() + cur_loc + 1, (next_loc - cur_loc - 1)); } return *this; @@ -2688,12 +2789,12 @@ Mat::operator/=(const SpBase& m) { arma_extra_debug_sigprint(); + // NOTE: use of this function is not advised; it is implemented only for completeness + const SpProxy p(m.get_ref()); arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise division"); - // If you use this method, you are probably stupid or misguided, but for completeness it is implemented. - // Unfortunately the best way to do this is loop over every element. for(uword c = 0; c < n_cols; ++c) for(uword r = 0; r < n_rows; ++r) { @@ -2705,12 +2806,82 @@ Mat::operator/=(const SpBase& m) +template +inline +Mat::Mat(const SpSubview& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(X); + } + + + +template +inline +Mat& +Mat::operator=(const SpSubview& X) + { + arma_extra_debug_sigprint(); + + (*this).zeros(X.n_rows, X.n_cols); + + if(X.n_nonzero == 0) { return *this; } + + if(X.n_rows == X.m.n_rows) + { + X.m.sync(); + + const uword sv_col_start = X.aux_col1; + const uword sv_col_end = X.aux_col1 + X.n_cols - 1; + + const eT* m_values = X.m.values; + const uword* m_row_indices = X.m.row_indices; + const uword* m_col_ptrs = X.m.col_ptrs; + + for(uword m_col = sv_col_start; m_col <= sv_col_end; ++m_col) + { + const uword m_col_adjusted = m_col - sv_col_start; + + const uword start = m_col_ptrs[m_col ]; + const uword end = m_col_ptrs[m_col + 1]; + + for(uword ii = start; ii < end; ++ii) + { + const uword m_row = m_row_indices[ii]; + const eT m_val = m_values[ii]; + + at(m_row, m_col_adjusted) = m_val; + } + } + } + else + { + typename SpSubview::const_iterator it = X.begin(); + typename SpSubview::const_iterator it_end = X.end(); + + for(; it != it_end; ++it) { at(it.row(), it.col()) = (*it); } + } + + return *this; + } + + + template inline Mat::Mat(const spdiagview& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(X.n_elem) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -2848,7 +3019,7 @@ Mat::row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "Mat::row(): index out of bounds" ); + arma_debug_check_bounds( row_num >= n_rows, "Mat::row(): index out of bounds" ); return subview_row(*this, row_num); } @@ -2863,7 +3034,7 @@ Mat::row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "Mat::row(): index out of bounds" ); + arma_debug_check_bounds( row_num >= n_rows, "Mat::row(): index out of bounds" ); return subview_row(*this, row_num); } @@ -2885,7 +3056,7 @@ Mat::operator()(const uword row_num, const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -2914,7 +3085,7 @@ Mat::operator()(const uword row_num, const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -2936,7 +3107,7 @@ Mat::col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Mat::col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Mat::col(): index out of bounds" ); return subview_col(*this, col_num); } @@ -2951,7 +3122,7 @@ Mat::col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Mat::col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Mat::col(): index out of bounds" ); return subview_col(*this, col_num); } @@ -2973,7 +3144,7 @@ Mat::operator()(const span& row_span, const uword col_num) const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -3002,7 +3173,7 @@ Mat::operator()(const span& row_span, const uword col_num) const const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -3028,7 +3199,7 @@ Mat::unsafe_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds" ); return Col(colptr(col_num), n_rows, false, true); } @@ -3047,7 +3218,7 @@ Mat::unsafe_col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Mat::unsafe_col(): index out of bounds" ); typedef const Col out_type; @@ -3064,7 +3235,7 @@ Mat::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Mat::rows(): indices out of bounds or incorrectly used" @@ -3085,7 +3256,7 @@ Mat::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Mat::rows(): indices out of bounds or incorrectly used" @@ -3101,12 +3272,12 @@ Mat::rows(const uword in_row1, const uword in_row2) const //! creation of subview (submatrix comprised of specified column vectors) template arma_inline -subview +subview_cols Mat::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Mat::cols(): indices out of bounds or incorrectly used" @@ -3114,7 +3285,7 @@ Mat::cols(const uword in_col1, const uword in_col2) const uword subview_n_cols = in_col2 - in_col1 + 1; - return subview(*this, 0, in_col1, n_rows, subview_n_cols); + return subview_cols(*this, in_col1, subview_n_cols); } @@ -3122,12 +3293,12 @@ Mat::cols(const uword in_col1, const uword in_col2) //! creation of subview (submatrix comprised of specified column vectors) template arma_inline -const subview +const subview_cols Mat::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Mat::cols(): indices out of bounds or incorrectly used" @@ -3135,7 +3306,7 @@ Mat::cols(const uword in_col1, const uword in_col2) const const uword subview_n_cols = in_col2 - in_col1 + 1; - return subview(*this, 0, in_col1, n_rows, subview_n_cols); + return subview_cols(*this, in_col1, subview_n_cols); } @@ -3156,7 +3327,7 @@ Mat::rows(const span& row_span) const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) , @@ -3184,7 +3355,7 @@ Mat::rows(const span& row_span) const const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) , @@ -3199,7 +3370,7 @@ Mat::rows(const span& row_span) const //! creation of subview (submatrix comprised of specified column vectors) template arma_inline -subview +subview_cols Mat::cols(const span& col_span) { arma_extra_debug_sigprint(); @@ -3212,14 +3383,14 @@ Mat::cols(const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) , "Mat::cols(): indices out of bounds or incorrectly used" ); - return subview(*this, 0, in_col1, n_rows, submat_n_cols); + return subview_cols(*this, in_col1, submat_n_cols); } @@ -3227,7 +3398,7 @@ Mat::cols(const span& col_span) //! creation of subview (submatrix comprised of specified column vectors) template arma_inline -const subview +const subview_cols Mat::cols(const span& col_span) const { arma_extra_debug_sigprint(); @@ -3240,14 +3411,14 @@ Mat::cols(const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ) , "Mat::cols(): indices out of bounds or incorrectly used" ); - return subview(*this, 0, in_col1, n_rows, submat_n_cols); + return subview_cols(*this, in_col1, submat_n_cols); } @@ -3260,7 +3431,7 @@ Mat::submat(const uword in_row1, const uword in_col1, const uword in_row2, c { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "Mat::submat(): indices out of bounds or incorrectly used" @@ -3282,7 +3453,7 @@ Mat::submat(const uword in_row1, const uword in_col1, const uword in_row2, c { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "Mat::submat(): indices out of bounds or incorrectly used" @@ -3310,7 +3481,7 @@ Mat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "Mat::submat(): indices or size out of bounds" @@ -3335,7 +3506,7 @@ Mat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) cons const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "Mat::submat(): indices or size out of bounds" @@ -3368,7 +3539,7 @@ Mat::submat(const span& row_span, const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -3404,7 +3575,7 @@ Mat::submat(const span& row_span, const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -3473,7 +3644,7 @@ Mat::head_rows(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "Mat::head_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "Mat::head_rows(): size out of bounds" ); return subview(*this, 0, 0, N, n_cols); } @@ -3487,7 +3658,7 @@ Mat::head_rows(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "Mat::head_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "Mat::head_rows(): size out of bounds" ); return subview(*this, 0, 0, N, n_cols); } @@ -3501,7 +3672,7 @@ Mat::tail_rows(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "Mat::tail_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "Mat::tail_rows(): size out of bounds" ); const uword start_row = n_rows - N; @@ -3517,7 +3688,7 @@ Mat::tail_rows(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "Mat::tail_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "Mat::tail_rows(): size out of bounds" ); const uword start_row = n_rows - N; @@ -3528,60 +3699,60 @@ Mat::tail_rows(const uword N) const template inline -subview +subview_cols Mat::head_cols(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "Mat::head_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "Mat::head_cols(): size out of bounds" ); - return subview(*this, 0, 0, n_rows, N); + return subview_cols(*this, 0, N); } template inline -const subview +const subview_cols Mat::head_cols(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "Mat::head_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "Mat::head_cols(): size out of bounds" ); - return subview(*this, 0, 0, n_rows, N); + return subview_cols(*this, 0, N); } template inline -subview +subview_cols Mat::tail_cols(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "Mat::tail_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "Mat::tail_cols(): size out of bounds" ); const uword start_col = n_cols - N; - return subview(*this, 0, start_col, n_rows, N); + return subview_cols(*this, start_col, N); } template inline -const subview +const subview_cols Mat::tail_cols(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "Mat::tail_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "Mat::tail_cols(): size out of bounds" ); const uword start_col = n_cols - N; - return subview(*this, 0, start_col, n_rows, N); + return subview_cols(*this, start_col, N); } @@ -3868,143 +4039,139 @@ Mat::each_row(const Base& indices) const -#if defined(ARMA_USE_CXX11) +//! apply a lambda function to each column, where each column is interpreted as a column vector +template +inline +Mat& +Mat::each_col(const std::function< void(Col&) >& F) + { + arma_extra_debug_sigprint(); - //! apply a lambda function to each column, where each column is interpreted as a column vector - template - inline - const Mat& - Mat::each_col(const std::function< void(Col&) >& F) + for(uword ii=0; ii < n_cols; ++ii) { - arma_extra_debug_sigprint(); - - for(uword ii=0; ii < n_cols; ++ii) - { - Col tmp(colptr(ii), n_rows, false, true); - F(tmp); - } - - return *this; + Col tmp(colptr(ii), n_rows, false, true); + F(tmp); } + return *this; + } + + + +template +inline +const Mat& +Mat::each_col(const std::function< void(const Col&) >& F) const + { + arma_extra_debug_sigprint(); - - template - inline - const Mat& - Mat::each_col(const std::function< void(const Col&) >& F) const + for(uword ii=0; ii < n_cols; ++ii) { - arma_extra_debug_sigprint(); - - for(uword ii=0; ii < n_cols; ++ii) - { - const Col tmp(const_cast(colptr(ii)), n_rows, false, true); - F(tmp); - } - - return *this; + const Col tmp(const_cast(colptr(ii)), n_rows, false, true); + F(tmp); } + return *this; + } + + + +//! apply a lambda function to each row, where each row is interpreted as a row vector +template +inline +Mat& +Mat::each_row(const std::function< void(Row&) >& F) + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); - //! apply a lambda function to each row, where each row is interpreted as a row vector - template - inline - const Mat& - Mat::each_row(const std::function< void(Row&) >& F) + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) { - arma_extra_debug_sigprint(); - - podarray array1(n_cols); - podarray array2(n_cols); - - Row tmp1( array1.memptr(), n_cols, false, true ); - Row tmp2( array2.memptr(), n_cols, false, true ); - - eT* tmp1_mem = tmp1.memptr(); - eT* tmp2_mem = tmp2.memptr(); - - uword ii, jj; - - for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + for(uword col_id = 0; col_id < n_cols; ++col_id) { - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - const eT* col_mem = colptr(col_id); - - tmp1_mem[col_id] = col_mem[ii]; - tmp2_mem[col_id] = col_mem[jj]; - } - - F(tmp1); - F(tmp2); + const eT* col_mem = colptr(col_id); - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - eT* col_mem = colptr(col_id); - - col_mem[ii] = tmp1_mem[col_id]; - col_mem[jj] = tmp2_mem[col_id]; - } + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; } - if(ii < n_rows) + F(tmp1); + F(tmp2); + + for(uword col_id = 0; col_id < n_cols; ++col_id) { - tmp1 = (*this).row(ii); + eT* col_mem = colptr(col_id); - F(tmp1); - - (*this).row(ii) = tmp1; + col_mem[ii] = tmp1_mem[col_id]; + col_mem[jj] = tmp2_mem[col_id]; } + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); - return *this; + F(tmp1); + + (*this).row(ii) = tmp1; } + return *this; + } + + + +template +inline +const Mat& +Mat::each_row(const std::function< void(const Row&) >& F) const + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); - template - inline - const Mat& - Mat::each_row(const std::function< void(const Row&) >& F) const + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) { - arma_extra_debug_sigprint(); - - podarray array1(n_cols); - podarray array2(n_cols); - - Row tmp1( array1.memptr(), n_cols, false, true ); - Row tmp2( array2.memptr(), n_cols, false, true ); - - eT* tmp1_mem = tmp1.memptr(); - eT* tmp2_mem = tmp2.memptr(); - - uword ii, jj; - - for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + for(uword col_id = 0; col_id < n_cols; ++col_id) { - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - const eT* col_mem = colptr(col_id); - - tmp1_mem[col_id] = col_mem[ii]; - tmp2_mem[col_id] = col_mem[jj]; - } + const eT* col_mem = colptr(col_id); - F(tmp1); - F(tmp2); + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; } - if(ii < n_rows) - { - tmp1 = (*this).row(ii); - - F(tmp1); - } + F(tmp1); + F(tmp2); + } + + if(ii < n_rows) + { + tmp1 = (*this).row(ii); - return *this; + F(tmp1); } -#endif + return *this; + } @@ -4019,7 +4186,7 @@ Mat::diag(const sword in_id) const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; const uword col_offset = (in_id > 0) ? uword( in_id) : 0; - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "Mat::diag(): requested diagonal out of bounds" @@ -4043,7 +4210,7 @@ Mat::diag(const sword in_id) const const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "Mat::diag(): requested diagonal out of bounds" @@ -4066,7 +4233,7 @@ Mat::swap_rows(const uword in_row1, const uword in_row2) const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; - arma_debug_check + arma_debug_check_bounds ( (in_row1 >= local_n_rows) || (in_row2 >= local_n_rows), "Mat::swap_rows(): index out of bounds" @@ -4097,7 +4264,7 @@ Mat::swap_cols(const uword in_colA, const uword in_colB) const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; - arma_debug_check + arma_debug_check_bounds ( (in_colA >= local_n_cols) || (in_colB >= local_n_cols), "Mat::swap_cols(): index out of bounds" @@ -4141,7 +4308,7 @@ Mat::shed_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "Mat::shed_row(): index out of bounds"); + arma_debug_check_bounds( row_num >= n_rows, "Mat::shed_row(): index out of bounds" ); shed_rows(row_num, row_num); } @@ -4156,7 +4323,7 @@ Mat::shed_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "Mat::shed_col(): index out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "Mat::shed_col(): index out of bounds" ); shed_cols(col_num, col_num); } @@ -4171,7 +4338,7 @@ Mat::shed_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "Mat::shed_rows(): indices out of bounds or incorrectly used" @@ -4180,7 +4347,7 @@ Mat::shed_rows(const uword in_row1, const uword in_row2) const uword n_keep_front = in_row1; const uword n_keep_back = n_rows - (in_row2 + 1); - Mat X(n_keep_front + n_keep_back, n_cols); + Mat X(n_keep_front + n_keep_back, n_cols, arma_nozeros_indicator()); if(n_keep_front > 0) { @@ -4205,7 +4372,7 @@ Mat::shed_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "Mat::shed_cols(): indices out of bounds or incorrectly used" @@ -4214,7 +4381,7 @@ Mat::shed_cols(const uword in_col1, const uword in_col2) const uword n_keep_front = in_col1; const uword n_keep_back = n_cols - (in_col2 + 1); - Mat X(n_rows, n_keep_front + n_keep_back); + Mat X(n_rows, n_keep_front + n_keep_back, arma_nozeros_indicator()); if(n_keep_front > 0) { @@ -4245,7 +4412,7 @@ Mat::shed_rows(const Base& indices) arma_debug_check( ((tmp1.is_vec() == false) && (tmp1.is_empty() == false)), "Mat::shed_rows(): list of indices must be a vector" ); - if(tmp1.is_empty()) { return; } + if(tmp1.is_empty()) { return; } const Col tmp2(const_cast(tmp1.memptr()), tmp1.n_elem, false, false); @@ -4260,11 +4427,11 @@ Mat::shed_rows(const Base& indices) { for(uword i=0; i= n_rows), "Mat::shed_rows(): indices out of bounds" ); + arma_debug_check_bounds( (rows_to_shed_mem[i] >= n_rows), "Mat::shed_rows(): indices out of bounds" ); } } - Col tmp3(n_rows); + Col tmp3(n_rows, arma_nozeros_indicator()); uword* tmp3_mem = tmp3.memptr(); @@ -4315,7 +4482,7 @@ Mat::shed_cols(const Base& indices) arma_debug_check( ((tmp1.is_vec() == false) && (tmp1.is_empty() == false)), "Mat::shed_cols(): list of indices must be a vector" ); - if(tmp1.is_empty()) { return; } + if(tmp1.is_empty()) { return; } const Col tmp2(const_cast(tmp1.memptr()), tmp1.n_elem, false, false); @@ -4330,11 +4497,11 @@ Mat::shed_cols(const Base& indices) { for(uword i=0; i= n_cols), "Mat::shed_cols(): indices out of bounds" ); + arma_debug_check_bounds( (cols_to_shed_mem[i] >= n_cols), "Mat::shed_cols(): indices out of bounds" ); } } - Col tmp3(n_cols); + Col tmp3(n_cols, arma_nozeros_indicator()); uword* tmp3_mem = tmp3.memptr(); @@ -4371,8 +4538,6 @@ Mat::shed_cols(const Base& indices) -//! insert N rows at the specified row position, -//! optionally setting the elements of the inserted rows to zero template inline void @@ -4380,6 +4545,20 @@ Mat::insert_rows(const uword row_num, const uword N, const bool set_to_zero) { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_rows(row_num, N); + } + + + +template +inline +void +Mat::insert_rows(const uword row_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_rows = n_rows; const uword t_n_cols = n_cols; @@ -4387,35 +4566,29 @@ Mat::insert_rows(const uword row_num, const uword N, const bool set_to_zero) const uword B_n_rows = t_n_rows - row_num; // insertion at row_num == n_rows is in effect an append operation - arma_debug_check( (row_num > t_n_rows), "Mat::insert_rows(): index out of bounds"); + arma_debug_check_bounds( (row_num > t_n_rows), "Mat::insert_rows(): index out of bounds" ); + + if(N == 0) { return; } + + Mat out(t_n_rows + N, t_n_cols, arma_nozeros_indicator()); - if(N > 0) + if(A_n_rows > 0) { - Mat out(t_n_rows + N, t_n_cols); - - if(A_n_rows > 0) - { - out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); - } - - if(B_n_rows > 0) - { - out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); - } - - if(set_to_zero) - { - out.rows(row_num, row_num + N - 1).zeros(); - } - - steal_mem(out); + out.rows(0, A_n_rows-1) = rows(0, A_n_rows-1); + } + + if(B_n_rows > 0) + { + out.rows(row_num + N, t_n_rows + N - 1) = rows(row_num, t_n_rows-1); } + + out.rows(row_num, row_num + N - 1).zeros(); + + steal_mem(out); } -//! insert N columns at the specified column position, -//! optionally setting the elements of the inserted columns to zero template inline void @@ -4423,6 +4596,20 @@ Mat::insert_cols(const uword col_num, const uword N, const bool set_to_zero) { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Mat::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_rows = n_rows; const uword t_n_cols = n_cols; @@ -4430,29 +4617,25 @@ Mat::insert_cols(const uword col_num, const uword N, const bool set_to_zero) const uword B_n_cols = t_n_cols - col_num; // insertion at col_num == n_cols is in effect an append operation - arma_debug_check( (col_num > t_n_cols), "Mat::insert_cols(): index out of bounds"); + arma_debug_check_bounds( (col_num > t_n_cols), "Mat::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } - if(N > 0) + Mat out(t_n_rows, t_n_cols + N, arma_nozeros_indicator()); + + if(A_n_cols > 0) { - Mat out(t_n_rows, t_n_cols + N); - - if(A_n_cols > 0) - { - out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); - } - - if(B_n_cols > 0) - { - out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); - } - - if(set_to_zero) - { - out.cols(col_num, col_num + N - 1).zeros(); - } - - steal_mem(out); + out.cols(0, A_n_cols-1) = cols(0, A_n_cols-1); } + + if(B_n_cols > 0) + { + out.cols(col_num + N, t_n_cols + N - 1) = cols(col_num, t_n_cols-1); + } + + out.cols(col_num, col_num + N - 1).zeros(); + + steal_mem(out); } @@ -4480,7 +4663,10 @@ Mat::insert_rows(const uword row_num, const Base& X) const uword B_n_rows = t_n_rows - row_num; bool err_state = false; - char* err_msg = 0; + char* err_msg = nullptr; + + const char* error_message_1 = "Mat::insert_rows(): index out of bounds"; + const char* error_message_2 = "Mat::insert_rows(): given object has an incompatible number of columns"; // insertion at row_num == n_rows is in effect an append operation @@ -4489,7 +4675,7 @@ Mat::insert_rows(const uword row_num, const Base& X) err_state, err_msg, (row_num > t_n_rows), - "Mat::insert_rows(): index out of bounds" + error_message_1 ); arma_debug_set_error @@ -4497,14 +4683,14 @@ Mat::insert_rows(const uword row_num, const Base& X) err_state, err_msg, ( (C_n_cols != t_n_cols) && ( (t_n_rows > 0) || (t_n_cols > 0) ) && ( (C_n_rows > 0) || (C_n_cols > 0) ) ), - "Mat::insert_rows(): given object has an incompatible number of columns" + error_message_2 ); - arma_debug_check(err_state, err_msg); + arma_debug_check_bounds(err_state, err_msg); if(C_n_rows > 0) { - Mat out( t_n_rows + C_n_rows, (std::max)(t_n_cols, C_n_cols) ); + Mat out( t_n_rows + C_n_rows, (std::max)(t_n_cols, C_n_cols), arma_nozeros_indicator() ); if(t_n_cols > 0) { @@ -4553,7 +4739,10 @@ Mat::insert_cols(const uword col_num, const Base& X) const uword B_n_cols = t_n_cols - col_num; bool err_state = false; - char* err_msg = 0; + char* err_msg = nullptr; + + const char* error_message_1 = "Mat::insert_cols(): index out of bounds"; + const char* error_message_2 = "Mat::insert_cols(): given object has an incompatible number of rows"; // insertion at col_num == n_cols is in effect an append operation @@ -4562,7 +4751,7 @@ Mat::insert_cols(const uword col_num, const Base& X) err_state, err_msg, (col_num > t_n_cols), - "Mat::insert_cols(): index out of bounds" + error_message_1 ); arma_debug_set_error @@ -4570,14 +4759,14 @@ Mat::insert_cols(const uword col_num, const Base& X) err_state, err_msg, ( (C_n_rows != t_n_rows) && ( (t_n_rows > 0) || (t_n_cols > 0) ) && ( (C_n_rows > 0) || (C_n_cols > 0) ) ), - "Mat::insert_cols(): given object has an incompatible number of rows" + error_message_2 ); - arma_debug_check(err_state, err_msg); + arma_debug_check_bounds(err_state, err_msg); if(C_n_cols > 0) { - Mat out( (std::max)(t_n_rows, C_n_rows), t_n_cols + C_n_cols ); + Mat out( (std::max)(t_n_rows, C_n_rows), t_n_cols + C_n_cols, arma_nozeros_indicator() ); if(t_n_rows > 0) { @@ -4610,6 +4799,7 @@ Mat::Mat(const Gen& X) : n_rows(X.n_rows) , n_cols(X.n_cols) , n_elem(n_rows*n_cols) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -4729,7 +4919,7 @@ Mat::operator/=(const Gen& X) -//! create a matrix from Op, i.e. run the previously delayed unary operations +//! create a matrix from Op, ie. run the previously delayed unary operations template template inline @@ -4737,6 +4927,7 @@ Mat::Mat(const Op& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -4750,7 +4941,7 @@ Mat::Mat(const Op& X) -//! create a matrix from Op, i.e. run the previously delayed unary operations +//! create a matrix from Op, ie. run the previously delayed unary operations template template inline @@ -4858,7 +5049,7 @@ Mat::operator/=(const Op& X) -//! create a matrix from eOp, i.e. run the previously delayed unary operations +//! create a matrix from eOp, ie. run the previously delayed unary operations template template inline @@ -4866,6 +5057,7 @@ Mat::Mat(const eOp& X) : n_rows(X.get_n_rows()) , n_cols(X.get_n_cols()) , n_elem(X.get_n_elem()) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -4881,7 +5073,7 @@ Mat::Mat(const eOp& X) -//! create a matrix from eOp, i.e. run the previously delayed unary operations +//! create a matrix from eOp, ie. run the previously delayed unary operations template template inline @@ -4894,20 +5086,11 @@ Mat::operator=(const eOp& X) const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); - if(bad_alias == false) - { - init_warm(X.get_n_rows(), X.get_n_cols()); - - eop_type::apply(*this, X); - } - else - { - arma_extra_debug_print("bad_alias = true"); - - Mat tmp(X); - - steal_mem(tmp); - } + if(bad_alias) { Mat tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols()); + + eop_type::apply(*this, X); return *this; } @@ -4921,9 +5104,13 @@ Mat& Mat::operator+=(const eOp& X) { arma_extra_debug_sigprint(); - + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator+=(tmp); } + eop_type::apply_inplace_plus(*this, X); return *this; @@ -4938,9 +5125,13 @@ Mat& Mat::operator-=(const eOp& X) { arma_extra_debug_sigprint(); - + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator-=(tmp); } + eop_type::apply_inplace_minus(*this, X); return *this; @@ -4972,9 +5163,13 @@ Mat& Mat::operator%=(const eOp& X) { arma_extra_debug_sigprint(); - + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator%=(tmp); } + eop_type::apply_inplace_schur(*this, X); return *this; @@ -4989,9 +5184,13 @@ Mat& Mat::operator/=(const eOp& X) { arma_extra_debug_sigprint(); - + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + const bool bad_alias = (eOp::proxy_type::has_subview && X.P.is_alias(*this)); + + if(bad_alias) { const Mat tmp(X); return (*this).operator/=(tmp); } + eop_type::apply_inplace_div(*this, X); return *this; @@ -5006,6 +5205,7 @@ Mat::Mat(const mtOp& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -5114,6 +5314,7 @@ Mat::Mat(const CubeToMatOp& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -5236,6 +5437,7 @@ Mat::Mat(const SpToDOp& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -5249,7 +5451,7 @@ Mat::Mat(const SpToDOp& X) -//! create a matrix from an SpToDOp, i.e. run the previously delayed unary operations +//! create a matrix from an SpToDOp, ie. run the previously delayed unary operations template template inline @@ -5357,7 +5559,7 @@ Mat::operator/=(const SpToDOp& X) -//! create a matrix from Glue, i.e. run the previously delayed binary operations +//! create a matrix from Glue, ie. run the previously delayed binary operations template template inline @@ -5365,6 +5567,7 @@ Mat::Mat(const Glue& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -5379,7 +5582,7 @@ Mat::Mat(const Glue& X) -//! create a matrix from Glue, i.e. run the previously delayed binary operations +//! create a matrix from Glue, ie. run the previously delayed binary operations template template inline @@ -5523,7 +5726,7 @@ Mat::operator-=(const Glue& X) -//! create a matrix from eGlue, i.e. run the previously delayed binary operations +//! create a matrix from eGlue, ie. run the previously delayed binary operations template template inline @@ -5531,6 +5734,7 @@ Mat::Mat(const eGlue& X) : n_rows(X.get_n_rows()) , n_cols(X.get_n_cols()) , n_elem(X.get_n_elem()) + , n_alloc() , vec_state(0) , mem_state(0) , mem() @@ -5547,7 +5751,7 @@ Mat::Mat(const eGlue& X) -//! create a matrix from eGlue, i.e. run the previously delayed binary operations +//! create a matrix from eGlue, ie. run the previously delayed binary operations template template inline @@ -5566,20 +5770,11 @@ Mat::operator=(const eGlue& X) (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) ); - if(bad_alias == false) - { - init_warm(X.get_n_rows(), X.get_n_cols()); - - eglue_type::apply(*this, X); - } - else - { - arma_extra_debug_print("bad_alias = true"); - - Mat tmp(X); - - steal_mem(tmp); - } + if(bad_alias) { Mat tmp(X); steal_mem(tmp); return *this; } + + init_warm(X.get_n_rows(), X.get_n_cols()); + + eglue_type::apply(*this, X); return *this; } @@ -5598,6 +5793,15 @@ Mat::operator+=(const eGlue& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator+=(tmp); } + eglue_type::apply_inplace_plus(*this, X); return *this; @@ -5617,6 +5821,15 @@ Mat::operator-=(const eGlue& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator-=(tmp); } + eglue_type::apply_inplace_minus(*this, X); return *this; @@ -5636,6 +5849,7 @@ Mat::operator*=(const eGlue& X) arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); glue_times::apply_inplace(*this, X); + return *this; } @@ -5652,7 +5866,17 @@ Mat::operator%=(const eGlue& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator%=(tmp); } + eglue_type::apply_inplace_schur(*this, X); + return *this; } @@ -5669,7 +5893,17 @@ Mat::operator/=(const eGlue& X) arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + const bool bad_alias = + ( + (eGlue::proxy1_type::has_subview && X.P1.is_alias(*this)) + || + (eGlue::proxy2_type::has_subview && X.P2.is_alias(*this)) + ); + + if(bad_alias) { const Mat tmp(X); return (*this).operator/=(tmp); } + eglue_type::apply_inplace_div(*this, X); + return *this; } @@ -5682,6 +5916,7 @@ Mat::Mat(const mtGlue& X) : n_rows(0) , n_cols(0) , n_elem(0) + , n_alloc(0) , vec_state(0) , mem_state(0) , mem() @@ -5785,14 +6020,144 @@ Mat::operator/=(const mtGlue& X) +template +template +inline +Mat::Mat(const SpToDGlue& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_alloc(0) + , vec_state(0) + , mem_state(0) + , mem() + { + arma_extra_debug_sigprint_this(this); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + } + + + +template +template +inline +Mat& +Mat::operator=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_type::apply(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator+=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator+=(m); + } + + + +template +template +inline +Mat& +Mat::operator-=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator-=(m); + } + + + +template +template +inline +Mat& +Mat::operator*=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + glue_times::apply_inplace(*this, X); + + return *this; + } + + + +template +template +inline +Mat& +Mat::operator%=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator%=(m); + } + + + +template +template +inline +Mat& +Mat::operator/=(const SpToDGlue& X) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_same_type< eT, typename T1::elem_type >::no )); + arma_type_check(( is_same_type< eT, typename T2::elem_type >::no )); + + const Mat m(X); + + return (*this).operator/=(m); + } + + + //! linear element accessor (treats the matrix as a vector); no bounds check; assumes memory is aligned template arma_inline -arma_warn_unused const eT& Mat::at_alt(const uword ii) const { const eT* mem_aligned = mem; + memory::mark_as_aligned(mem_aligned); return mem_aligned[ii]; @@ -5803,11 +6168,11 @@ Mat::at_alt(const uword ii) const //! linear element accessor (treats the matrix as a vector); bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused eT& Mat::operator() (const uword ii) { - arma_debug_check( (ii >= n_elem), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= n_elem), "Mat::operator(): index out of bounds" ); + return access::rw(mem[ii]); } @@ -5816,11 +6181,11 @@ Mat::operator() (const uword ii) //! linear element accessor (treats the matrix as a vector); bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused const eT& Mat::operator() (const uword ii) const { - arma_debug_check( (ii >= n_elem), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= n_elem), "Mat::operator(): index out of bounds" ); + return mem[ii]; } @@ -5828,7 +6193,6 @@ Mat::operator() (const uword ii) const //! linear element accessor (treats the matrix as a vector); no bounds check. template arma_inline -arma_warn_unused eT& Mat::operator[] (const uword ii) { @@ -5840,7 +6204,6 @@ Mat::operator[] (const uword ii) //! linear element accessor (treats the matrix as a vector); no bounds check template arma_inline -arma_warn_unused const eT& Mat::operator[] (const uword ii) const { @@ -5852,7 +6215,6 @@ Mat::operator[] (const uword ii) const //! linear element accessor (treats the matrix as a vector); no bounds check. template arma_inline -arma_warn_unused eT& Mat::at(const uword ii) { @@ -5864,7 +6226,6 @@ Mat::at(const uword ii) //! linear element accessor (treats the matrix as a vector); no bounds check template arma_inline -arma_warn_unused const eT& Mat::at(const uword ii) const { @@ -5876,11 +6237,11 @@ Mat::at(const uword ii) const //! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused eT& Mat::operator() (const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds" ); + return access::rw(mem[in_row + in_col*n_rows]); } @@ -5889,11 +6250,11 @@ Mat::operator() (const uword in_row, const uword in_col) //! element accessor; bounds checking not done when ARMA_NO_DEBUG is defined template arma_inline -arma_warn_unused const eT& Mat::operator() (const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "Mat::operator(): index out of bounds" ); + return mem[in_row + in_col*n_rows]; } @@ -5902,7 +6263,6 @@ Mat::operator() (const uword in_row, const uword in_col) const //! element accessor; no bounds check template arma_inline -arma_warn_unused eT& Mat::at(const uword in_row, const uword in_col) { @@ -5914,7 +6274,6 @@ Mat::at(const uword in_row, const uword in_col) //! element accessor; no bounds check template arma_inline -arma_warn_unused const eT& Mat::at(const uword in_row, const uword in_col) const { @@ -5923,6 +6282,32 @@ Mat::at(const uword in_row, const uword in_col) const +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + eT& + Mat::operator[] (const uword in_row, const uword in_col) + { + return access::rw( mem[in_row + in_col*n_rows] ); + } + + + + //! element accessor; no bounds check + template + arma_inline + const eT& + Mat::operator[] (const uword in_row, const uword in_col) const + { + return mem[in_row + in_col*n_rows]; + } + +#endif + + + //! prefix ++ template arma_inline @@ -5930,6 +6315,7 @@ const Mat& Mat::operator++() { Mat_aux::prefix_pp(*this); + return *this; } @@ -5953,6 +6339,7 @@ const Mat& Mat::operator--() { Mat_aux::prefix_mm(*this); + return *this; } @@ -5972,7 +6359,6 @@ Mat::operator--(int) //! returns true if the matrix has no elements template arma_inline -arma_warn_unused bool Mat::is_empty() const { @@ -5984,7 +6370,6 @@ Mat::is_empty() const //! returns true if the object can be interpreted as a column or row vector template arma_inline -arma_warn_unused bool Mat::is_vec() const { @@ -5996,7 +6381,6 @@ Mat::is_vec() const //! returns true if the object can be interpreted as a row vector template arma_inline -arma_warn_unused bool Mat::is_rowvec() const { @@ -6008,7 +6392,6 @@ Mat::is_rowvec() const //! returns true if the object can be interpreted as a column vector template arma_inline -arma_warn_unused bool Mat::is_colvec() const { @@ -6020,7 +6403,6 @@ Mat::is_colvec() const //! returns true if the object has the same number of non-zero rows and columnns template arma_inline -arma_warn_unused bool Mat::is_square() const { @@ -6029,23 +6411,22 @@ Mat::is_square() const -//! returns true if all of the elements are finite template inline -arma_warn_unused bool -Mat::is_finite() const +Mat::internal_is_finite() const { - return arrayops::is_finite( memptr(), n_elem ); + arma_extra_debug_sigprint(); + + return arrayops::is_finite(memptr(), n_elem); } template inline -arma_warn_unused bool -Mat::has_inf() const +Mat::internal_has_inf() const { arma_extra_debug_sigprint(); @@ -6056,9 +6437,8 @@ Mat::has_inf() const template inline -arma_warn_unused bool -Mat::has_nan() const +Mat::internal_has_nan() const { arma_extra_debug_sigprint(); @@ -6069,7 +6449,18 @@ Mat::has_nan() const template inline -arma_warn_unused +bool +Mat::internal_has_nonfinite() const + { + arma_extra_debug_sigprint(); + + return (arrayops::is_finite(memptr(), n_elem) == false); + } + + + +template +inline bool Mat::is_sorted(const char* direction) const { @@ -6082,13 +6473,12 @@ Mat::is_sorted(const char* direction) const template inline -arma_warn_unused bool Mat::is_sorted(const char* direction, const uword dim) const { arma_extra_debug_sigprint(); - const char sig1 = (direction != NULL) ? direction[0] : char(0); + const char sig1 = (direction != nullptr) ? direction[0] : char(0); // direction is one of: // "ascend" @@ -6158,7 +6548,6 @@ Mat::is_sorted(const char* direction, const uword dim) const template template inline -arma_warn_unused bool Mat::is_sorted_helper(const comparator& comp, const uword dim) const { @@ -6228,7 +6617,6 @@ Mat::is_sorted_helper(const comparator& comp, const uword dim) const //! returns true if the given index is currently in range template arma_inline -arma_warn_unused bool Mat::in_range(const uword ii) const { @@ -6240,7 +6628,6 @@ Mat::in_range(const uword ii) const //! returns true if the given start and end indices are currently in range template arma_inline -arma_warn_unused bool Mat::in_range(const span& x) const { @@ -6264,7 +6651,6 @@ Mat::in_range(const span& x) const //! returns true if the given location is currently in range template arma_inline -arma_warn_unused bool Mat::in_range(const uword in_row, const uword in_col) const { @@ -6275,7 +6661,6 @@ Mat::in_range(const uword in_row, const uword in_col) const template arma_inline -arma_warn_unused bool Mat::in_range(const span& row_span, const uword in_col) const { @@ -6298,7 +6683,6 @@ Mat::in_range(const span& row_span, const uword in_col) const template arma_inline -arma_warn_unused bool Mat::in_range(const uword in_row, const span& col_span) const { @@ -6321,7 +6705,6 @@ Mat::in_range(const uword in_row, const span& col_span) const template arma_inline -arma_warn_unused bool Mat::in_range(const span& row_span, const span& col_span) const { @@ -6343,7 +6726,6 @@ Mat::in_range(const span& row_span, const span& col_span) const template arma_inline -arma_warn_unused bool Mat::in_range(const uword in_row, const uword in_col, const SizeMat& s) const { @@ -6365,7 +6747,6 @@ Mat::in_range(const uword in_row, const uword in_col, const SizeMat& s) cons //! returns a pointer to array of eTs for a specified column; no bounds check template arma_inline -arma_warn_unused eT* Mat::colptr(const uword in_col) { @@ -6377,7 +6758,6 @@ Mat::colptr(const uword in_col) //! returns a pointer to array of eTs for a specified column; no bounds check template arma_inline -arma_warn_unused const eT* Mat::colptr(const uword in_col) const { @@ -6389,7 +6769,6 @@ Mat::colptr(const uword in_col) const //! returns a pointer to array of eTs used by the matrix template arma_inline -arma_warn_unused eT* Mat::memptr() { @@ -6401,7 +6780,6 @@ Mat::memptr() //! returns a pointer to array of eTs used by the matrix template arma_inline -arma_warn_unused const eT* Mat::memptr() const { @@ -6410,132 +6788,31 @@ Mat::memptr() const -//! print contents of the matrix (to the cout stream), -//! optionally preceding with a user specified line of text. -//! the precision and cell width are modified. -//! on return, the stream's state are restored to their original values. -template -arma_cold -inline -void -Mat::impl_print(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print(get_cout_stream(), *this, true); - } - - - -//! print contents of the matrix to a user specified stream, -//! optionally preceding with a user specified line of text. -//! the precision and cell width are modified. -//! on return, the stream's state are restored to their original values. -template -arma_cold -inline -void -Mat::impl_print(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print(user_stream, *this, true); - } - - - -//! print contents of the matrix (to the cout stream), -//! optionally preceding with a user specified line of text. -//! the stream's state are used as is and are not modified -//! (i.e. the precision and cell width are not modified). -template -arma_cold -inline -void -Mat::impl_raw_print(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print(get_cout_stream(), *this, false); - } - - - -//! print contents of the matrix to a user specified stream, -//! optionally preceding with a user specified line of text. -//! the stream's state are used as is and are not modified. -//! (i.e. the precision and cell width are not modified). -template -arma_cold -inline -void -Mat::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print(user_stream, *this, false); - } - - - //! change the matrix to have user specified dimensions (data is not preserved) template inline -void -Mat::set_size(const uword in_elem) +Mat& +Mat::set_size(const uword new_n_elem) { arma_extra_debug_sigprint(); switch(vec_state) { case 0: + // fallthrough case 1: - init_warm(in_elem, 1); + init_warm(new_n_elem, 1); break; case 2: - init_warm(1, in_elem); + init_warm(1, new_n_elem); break; default: ; } + + return *this; } @@ -6543,24 +6820,28 @@ Mat::set_size(const uword in_elem) //! change the matrix to have user specified dimensions (data is not preserved) template inline -void -Mat::set_size(const uword in_rows, const uword in_cols) +Mat& +Mat::set_size(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - init_warm(in_rows, in_cols); + init_warm(new_n_rows, new_n_cols); + + return *this; } template inline -void +Mat& Mat::set_size(const SizeMat& s) { arma_extra_debug_sigprint(); init_warm(s.n_rows, s.n_cols); + + return *this; } @@ -6568,25 +6849,28 @@ Mat::set_size(const SizeMat& s) //! change the matrix to have user specified dimensions (data is preserved) template inline -void -Mat::resize(const uword in_elem) +Mat& +Mat::resize(const uword new_n_elem) { arma_extra_debug_sigprint(); switch(vec_state) { case 0: + // fallthrough case 1: - (*this).resize(in_elem, 1); + (*this).resize(new_n_elem, 1); break; case 2: - (*this).resize(1, in_elem); + (*this).resize(1, new_n_elem); break; default: ; } + + return *this; } @@ -6594,24 +6878,28 @@ Mat::resize(const uword in_elem) //! change the matrix to have user specified dimensions (data is preserved) template inline -void -Mat::resize(const uword in_rows, const uword in_cols) +Mat& +Mat::resize(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - *this = arma::resize(*this, in_rows, in_cols); + op_resize::apply_mat_inplace((*this), new_n_rows, new_n_cols); + + return *this; } template inline -void +Mat& Mat::resize(const SizeMat& s) { arma_extra_debug_sigprint(); - *this = arma::resize(*this, s.n_rows, s.n_cols); + op_resize::apply_mat_inplace((*this), s.n_rows, s.n_cols); + + return *this; } @@ -6619,38 +6907,55 @@ Mat::resize(const SizeMat& s) //! change the matrix to have user specified dimensions (data is preserved) template inline -void -Mat::reshape(const uword in_rows, const uword in_cols) +Mat& +Mat::reshape(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, in_rows, in_cols); + op_reshape::apply_mat_inplace((*this), new_n_rows, new_n_cols); + + return *this; } template inline -void +Mat& Mat::reshape(const SizeMat& s) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, s.n_rows, s.n_cols); + op_reshape::apply_mat_inplace((*this), s.n_rows, s.n_cols); + + return *this; } //! NOTE: don't use this form; it's deprecated and will be removed template -arma_deprecated inline void -Mat::reshape(const uword in_rows, const uword in_cols, const uword dim) +Mat::reshape(const uword new_n_rows, const uword new_n_cols, const uword dim) { arma_extra_debug_sigprint(); - *this = arma::reshape(*this, in_rows, in_cols, dim); + arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); + + if(dim == 0) + { + op_reshape::apply_mat_inplace((*this), new_n_rows, new_n_cols); + } + else + if(dim == 1) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, (*this)); + + op_reshape::apply_mat_noalias((*this), tmp, new_n_rows, new_n_cols); + } } @@ -6659,7 +6964,7 @@ Mat::reshape(const uword in_rows, const uword in_cols, const uword dim) template template inline -void +Mat& Mat::copy_size(const Base& X) { arma_extra_debug_sigprint(); @@ -6670,6 +6975,8 @@ Mat::copy_size(const Base& X) const uword X_n_cols = P.get_n_cols(); init_warm(X_n_rows, X_n_cols); + + return *this; } @@ -6678,7 +6985,7 @@ Mat::copy_size(const Base& X) template template inline -const Mat& +Mat& Mat::for_each(functor F) { arma_extra_debug_sigprint(); @@ -6739,7 +7046,7 @@ Mat::for_each(functor F) const template template inline -const Mat& +Mat& Mat::transform(functor F) { arma_extra_debug_sigprint(); @@ -6776,7 +7083,7 @@ Mat::transform(functor F) template template inline -const Mat& +Mat& Mat::imbue(functor F) { arma_extra_debug_sigprint(); @@ -6808,7 +7115,7 @@ Mat::imbue(functor F) template inline -const Mat& +Mat& Mat::replace(const eT old_val, const eT new_val) { arma_extra_debug_sigprint(); @@ -6822,7 +7129,7 @@ Mat::replace(const eT old_val, const eT new_val) template inline -const Mat& +Mat& Mat::clean(const typename get_pod_type::result threshold) { arma_extra_debug_sigprint(); @@ -6834,10 +7141,34 @@ Mat::clean(const typename get_pod_type::result threshold) +template +inline +Mat& +Mat::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Mat::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "Mat::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "Mat::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + arrayops::clamp(memptr(), n_elem, min_val, max_val); + + return *this; + } + + + //! fill the matrix with the specified value template inline -const Mat& +Mat& Mat::fill(const eT val) { arma_extra_debug_sigprint(); @@ -6849,20 +7180,20 @@ Mat::fill(const eT val) -//! fill the matrix with the specified value +//! fill the matrix with the specified pattern template template inline -const Mat& +Mat& Mat::fill(const fill::fill_class&) { arma_extra_debug_sigprint(); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).eye(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } return *this; } @@ -6871,7 +7202,7 @@ Mat::fill(const fill::fill_class&) template inline -const Mat& +Mat& Mat::zeros() { arma_extra_debug_sigprint(); @@ -6885,12 +7216,12 @@ Mat::zeros() template inline -const Mat& -Mat::zeros(const uword in_elem) +Mat& +Mat::zeros(const uword new_n_elem) { arma_extra_debug_sigprint(); - set_size(in_elem); + set_size(new_n_elem); return (*this).zeros(); } @@ -6899,12 +7230,12 @@ Mat::zeros(const uword in_elem) template inline -const Mat& -Mat::zeros(const uword in_n_rows, const uword in_n_cols) +Mat& +Mat::zeros(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - set_size(in_n_rows, in_n_cols); + set_size(new_n_rows, new_n_cols); return (*this).zeros(); } @@ -6913,7 +7244,7 @@ Mat::zeros(const uword in_n_rows, const uword in_n_cols) template inline -const Mat& +Mat& Mat::zeros(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -6925,7 +7256,7 @@ Mat::zeros(const SizeMat& s) template inline -const Mat& +Mat& Mat::ones() { arma_extra_debug_sigprint(); @@ -6937,12 +7268,12 @@ Mat::ones() template inline -const Mat& -Mat::ones(const uword in_elem) +Mat& +Mat::ones(const uword new_n_elem) { arma_extra_debug_sigprint(); - set_size(in_elem); + set_size(new_n_elem); return fill(eT(1)); } @@ -6951,12 +7282,12 @@ Mat::ones(const uword in_elem) template inline -const Mat& -Mat::ones(const uword in_rows, const uword in_cols) +Mat& +Mat::ones(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols); + set_size(new_n_rows, new_n_cols); return fill(eT(1)); } @@ -6965,7 +7296,7 @@ Mat::ones(const uword in_rows, const uword in_cols) template inline -const Mat& +Mat& Mat::ones(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -6977,7 +7308,7 @@ Mat::ones(const SizeMat& s) template inline -const Mat& +Mat& Mat::randu() { arma_extra_debug_sigprint(); @@ -6991,12 +7322,12 @@ Mat::randu() template inline -const Mat& -Mat::randu(const uword in_elem) +Mat& +Mat::randu(const uword new_n_elem) { arma_extra_debug_sigprint(); - set_size(in_elem); + set_size(new_n_elem); return (*this).randu(); } @@ -7005,12 +7336,12 @@ Mat::randu(const uword in_elem) template inline -const Mat& -Mat::randu(const uword in_rows, const uword in_cols) +Mat& +Mat::randu(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols); + set_size(new_n_rows, new_n_cols); return (*this).randu(); } @@ -7019,7 +7350,7 @@ Mat::randu(const uword in_rows, const uword in_cols) template inline -const Mat& +Mat& Mat::randu(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -7031,7 +7362,7 @@ Mat::randu(const SizeMat& s) template inline -const Mat& +Mat& Mat::randn() { arma_extra_debug_sigprint(); @@ -7045,12 +7376,12 @@ Mat::randn() template inline -const Mat& -Mat::randn(const uword in_elem) +Mat& +Mat::randn(const uword new_n_elem) { arma_extra_debug_sigprint(); - set_size(in_elem); + set_size(new_n_elem); return (*this).randn(); } @@ -7059,12 +7390,12 @@ Mat::randn(const uword in_elem) template inline -const Mat& -Mat::randn(const uword in_rows, const uword in_cols) +Mat& +Mat::randn(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols); + set_size(new_n_rows, new_n_cols); return (*this).randn(); } @@ -7073,7 +7404,7 @@ Mat::randn(const uword in_rows, const uword in_cols) template inline -const Mat& +Mat& Mat::randn(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -7085,7 +7416,7 @@ Mat::randn(const SizeMat& s) template inline -const Mat& +Mat& Mat::eye() { arma_extra_debug_sigprint(); @@ -7094,10 +7425,7 @@ Mat::eye() const uword N = (std::min)(n_rows, n_cols); - for(uword ii=0; ii::eye() template inline -const Mat& -Mat::eye(const uword in_rows, const uword in_cols) +Mat& +Mat::eye(const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - set_size(in_rows, in_cols); + set_size(new_n_rows, new_n_cols); return (*this).eye(); } @@ -7120,7 +7448,7 @@ Mat::eye(const uword in_rows, const uword in_cols) template inline -const Mat& +Mat& Mat::eye(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -7132,33 +7460,21 @@ Mat::eye(const SizeMat& s) template inline -arma_cold void Mat::reset() { arma_extra_debug_sigprint(); - switch(vec_state) - { - default: - init_warm(0, 0); - break; - - case 1: - init_warm(0, 1); - break; - - case 2: - init_warm(1, 0); - break; - } + const uword new_n_rows = (vec_state == 2) ? 1 : 0; + const uword new_n_cols = (vec_state == 1) ? 1 : 0; + + init_warm(new_n_rows, new_n_cols); } template inline -arma_cold void Mat::soft_reset() { @@ -7171,7 +7487,7 @@ Mat::soft_reset() } else { - fill(Datum::nan); + zeros(); } } @@ -7205,7 +7521,6 @@ Mat::set_imag(const Base::pod_type,T1>& X) template inline -arma_warn_unused eT Mat::min() const { @@ -7225,7 +7540,6 @@ Mat::min() const template inline -arma_warn_unused eT Mat::max() const { @@ -7346,9 +7660,8 @@ Mat::max(uword& row_of_max_val, uword& col_of_max_val) const //! save the matrix to a file template inline -arma_cold bool -Mat::save(const std::string name, const file_type type, const bool print_status) const +Mat::save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); @@ -7365,7 +7678,15 @@ Mat::save(const std::string name, const file_type type, const bool print_sta break; case csv_ascii: - save_okay = diskio::save_csv_ascii(*this, name); + return (*this).save(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).save(csv_name(name), type); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, name); break; case raw_binary: @@ -7389,11 +7710,11 @@ Mat::save(const std::string name, const file_type type, const bool print_sta break; default: - if(print_status) { arma_debug_warn("Mat::save(): unsupported file type"); } + arma_debug_warn_level(1, "Mat::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("Mat::save(): couldn't write to ", name); } + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): write failed; file: ", name); } return save_okay; } @@ -7402,9 +7723,8 @@ Mat::save(const std::string name, const file_type type, const bool print_sta template inline -arma_cold bool -Mat::save(const hdf5_name& spec, const file_type type, const bool print_status) const +Mat::save(const hdf5_name& spec, const file_type type) const { arma_extra_debug_sigprint(); @@ -7412,7 +7732,7 @@ Mat::save(const hdf5_name& spec, const file_type type, const bool print_stat if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) { - arma_debug_check(true, "Mat::save(): unsupported file type for hdf5_name()"); + arma_stop_runtime_error("Mat::save(): unsupported file type for hdf5_name()"); return false; } @@ -7422,11 +7742,12 @@ Mat::save(const hdf5_name& spec, const file_type type, const bool print_stat if(append && replace) { - arma_debug_check(true, "Mat::save(): only one of 'append' or 'replace' options can be used"); + arma_stop_runtime_error("Mat::save(): only one of 'append' or 'replace' options can be used"); return false; } bool save_okay = false; + std::string err_msg; if(do_trans) @@ -7442,18 +7763,93 @@ Mat::save(const hdf5_name& spec, const file_type type, const bool print_stat save_okay = diskio::save_hdf5_binary(*this, spec, err_msg); } - if((print_status == true) && (save_okay == false)) + if(save_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("Mat::save(): ", err_msg, spec.filename); + arma_debug_warn_level(3, "Mat::save(): ", err_msg, "; file: ", spec.filename); } else { - arma_debug_warn("Mat::save(): couldn't write to ", spec.filename); + arma_debug_warn_level(3, "Mat::save(): write failed; file: ", spec.filename); + } + } + + return save_okay; + } + + + +template +inline +bool +Mat::save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("Mat::save(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + + arma_extra_debug_print("Mat::save(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + if(with_header) + { + if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) ) + { + arma_debug_warn_level(1, "Mat::save(): given header must have a vector layout"); + return false; + } + + for(uword i=0; i < spec.header_ro.n_elem; ++i) + { + const std::string& token = spec.header_ro.at(i); + + if(token.find(separator) != std::string::npos) + { + arma_debug_warn_level(1, "Mat::save(): token within the header contains the separator character: '", token, "'"); + return false; + } + } + + const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(spec.header_ro.n_elem != save_n_cols) + { + arma_debug_warn_level(1, "Mat::save(): size mismatch between header and matrix"); + return false; } } + bool save_okay = false; + + if(do_trans) + { + const Mat tmp = (*this).st(); + + save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header, separator); + } + else + { + save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header, separator); + } + + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): write failed; file: ", spec.filename); } + return save_okay; } @@ -7462,9 +7858,8 @@ Mat::save(const hdf5_name& spec, const file_type type, const bool print_stat //! save the matrix to a stream template inline -arma_cold bool -Mat::save(std::ostream& os, const file_type type, const bool print_status) const +Mat::save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); @@ -7481,7 +7876,15 @@ Mat::save(std::ostream& os, const file_type type, const bool print_status) c break; case csv_ascii: - save_okay = diskio::save_csv_ascii(*this, os); + save_okay = diskio::save_csv_ascii(*this, os, char(',')); + break; + + case ssv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(';')); + break; + + case coord_ascii: + save_okay = diskio::save_coord_ascii(*this, os); break; case raw_binary: @@ -7497,11 +7900,11 @@ Mat::save(std::ostream& os, const file_type type, const bool print_status) c break; default: - if(print_status) { arma_debug_warn("Mat::save(): unsupported file type"); } + arma_debug_warn_level(1, "Mat::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("Mat::save(): couldn't write to the given stream"); } + if(save_okay == false) { arma_debug_warn_level(3, "Mat::save(): stream write failed"); } return save_okay; } @@ -7511,9 +7914,8 @@ Mat::save(std::ostream& os, const file_type type, const bool print_status) c //! load a matrix from a file template inline -arma_cold bool -Mat::load(const std::string name, const file_type type, const bool print_status) +Mat::load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); @@ -7535,7 +7937,15 @@ Mat::load(const std::string name, const file_type type, const bool print_sta break; case csv_ascii: - load_okay = diskio::load_csv_ascii(*this, name, err_msg); + return (*this).load(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).load(csv_name(name), type); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, name, err_msg); break; case raw_binary: @@ -7559,26 +7969,23 @@ Mat::load(const std::string name, const file_type type, const bool print_sta break; default: - if(print_status) { arma_debug_warn("Mat::load(): unsupported file type"); } + arma_debug_warn_level(1, "Mat::load(): unsupported file type"); load_okay = false; } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("Mat::load(): ", err_msg, name); + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", name); } else { - arma_debug_warn("Mat::load(): couldn't read ", name); + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", name); } } - if(load_okay == false) - { - (*this).soft_reset(); - } + if(load_okay == false) { (*this).soft_reset(); } return load_okay; } @@ -7587,16 +7994,14 @@ Mat::load(const std::string name, const file_type type, const bool print_sta template inline -arma_cold bool -Mat::load(const hdf5_name& spec, const file_type type, const bool print_status) +Mat::load(const hdf5_name& spec, const file_type type) { arma_extra_debug_sigprint(); if( (type != hdf5_binary) && (type != hdf5_binary_trans) ) { - if(print_status) { arma_debug_warn("Mat::load(): unsupported file type for hdf5_name()"); } - (*this).soft_reset(); + arma_stop_runtime_error("Mat::load(): unsupported file type for hdf5_name()"); return false; } @@ -7619,21 +8024,105 @@ Mat::load(const hdf5_name& spec, const file_type type, const bool print_stat } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", spec.filename); + } + else + { + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", spec.filename); + } + } + + if(load_okay == false) { (*this).soft_reset(); } + + return load_okay; + } + + + +template +inline +bool +Mat::load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("Mat::load(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + const bool strict = bool(spec.opts.flags & csv_opts::flag_strict ); + + arma_extra_debug_print("Mat::load(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + if(strict ) { arma_extra_debug_print("strict"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + bool load_okay = false; + std::string err_msg; + + if(do_trans) + { + Mat tmp_mat; + + load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header, separator, strict); + + if(load_okay) + { + (*this) = tmp_mat.st(); + + if(with_header) + { + // field::set_size() preserves data if the number of elements hasn't changed + spec.header_rw.set_size(spec.header_rw.n_elem, 1); + } + } + } + else + { + load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header, separator, strict); + } + + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("Mat::load(): ", err_msg, spec.filename); + arma_debug_warn_level(3, "Mat::load(): ", err_msg, "; file: ", spec.filename); } else { - arma_debug_warn("Mat::load(): couldn't read ", spec.filename); + arma_debug_warn_level(3, "Mat::load(): read failed; file: ", spec.filename); + } + } + else + { + const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(with_header && (spec.header_rw.n_elem != load_n_cols)) + { + arma_debug_warn_level(3, "Mat::load(): size mismatch between header and matrix"); } } if(load_okay == false) { (*this).soft_reset(); + + if(with_header) { spec.header_rw.reset(); } } return load_okay; @@ -7644,9 +8133,8 @@ Mat::load(const hdf5_name& spec, const file_type type, const bool print_stat //! load a matrix from a stream template inline -arma_cold bool -Mat::load(std::istream& is, const file_type type, const bool print_status) +Mat::load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); @@ -7668,7 +8156,15 @@ Mat::load(std::istream& is, const file_type type, const bool print_status) break; case csv_ascii: - load_okay = diskio::load_csv_ascii(*this, is, err_msg); + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(','), false); + break; + + case ssv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(';'), false); + break; + + case coord_ascii: + load_okay = diskio::load_coord_ascii(*this, is, err_msg); break; case raw_binary: @@ -7684,110 +8180,121 @@ Mat::load(std::istream& is, const file_type type, const bool print_status) break; default: - if(print_status) { arma_debug_warn("Mat::load(): unsupported file type"); } + arma_debug_warn_level(1, "Mat::load(): unsupported file type"); load_okay = false; } - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("Mat::load(): ", err_msg, "the given stream"); + arma_debug_warn_level(3, "Mat::load(): ", err_msg); } else { - arma_debug_warn("Mat::load(): couldn't load from the given stream"); + arma_debug_warn_level(3, "Mat::load(): stream read failed"); } } - if(load_okay == false) - { - (*this).soft_reset(); - } + if(load_okay == false) { (*this).soft_reset(); } return load_okay; } -//! save the matrix to a file, without printing any error messages template inline -arma_cold bool Mat::quiet_save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(name, type, false); + return (*this).save(name, type); } template inline -arma_cold bool Mat::quiet_save(const hdf5_name& spec, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(spec, type, false); + return (*this).save(spec, type); + } + + + +template +inline +bool +Mat::quiet_save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + return (*this).save(spec, type); } -//! save the matrix to a stream, without printing any error messages template inline -arma_cold bool Mat::quiet_save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(os, type, false); + return (*this).save(os, type); } -//! load a matrix from a file, without printing any error messages template inline -arma_cold bool Mat::quiet_load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(name, type, false); + return (*this).load(name, type); } template inline -arma_cold bool Mat::quiet_load(const hdf5_name& spec, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(spec, type, false); + return (*this).load(spec, type); + } + + + +template +inline +bool +Mat::quiet_load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + return (*this).load(spec, type); } -//! load a matrix from a stream, without printing any error messages template inline -arma_cold bool Mat::quiet_load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(is, type, false); + return (*this).load(is, type); } @@ -7795,13 +8302,13 @@ Mat::quiet_load(std::istream& is, const file_type type) template inline Mat::row_iterator::row_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) , current_row(0 ) , current_col(0 ) { arma_extra_debug_sigprint(); - // Technically this iterator is invalid (it does not point to a valid element) + + // NOTE: this instance of row_iterator is invalid (it does not point to a valid element) } @@ -7810,7 +8317,6 @@ template inline Mat::row_iterator::row_iterator(const row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) { @@ -7821,11 +8327,10 @@ Mat::row_iterator::row_iterator(const row_iterator& X) template inline -Mat::row_iterator::row_iterator(Mat& in_M, const uword in_row) - : M (&in_M ) - , current_ptr(&(in_M.at(in_row,0))) - , current_row(in_row ) - , current_col(0 ) +Mat::row_iterator::row_iterator(Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_row(in_row) + , current_col(in_col) { arma_extra_debug_sigprint(); } @@ -7834,11 +8339,10 @@ Mat::row_iterator::row_iterator(Mat& in_M, const uword in_row) template inline -arma_warn_unused eT& Mat::row_iterator::operator*() { - return (*current_ptr); + return M->at(current_row,current_col); } @@ -7854,12 +8358,6 @@ Mat::row_iterator::operator++() { current_col = 0; current_row++; - - current_ptr = &(M->at(current_row, 0)); - } - else - { - current_ptr += M->n_rows; } return *this; @@ -7869,7 +8367,6 @@ Mat::row_iterator::operator++() template inline -arma_warn_unused typename Mat::row_iterator Mat::row_iterator::operator++(int) { @@ -7890,8 +8387,6 @@ Mat::row_iterator::operator--() if(current_col > 0) { current_col--; - - current_ptr -= M->n_rows; } else { @@ -7899,8 +8394,6 @@ Mat::row_iterator::operator--() { current_col = M->n_cols - 1; current_row--; - - current_ptr = &(M->at(current_row, current_col)); } } @@ -7911,7 +8404,6 @@ Mat::row_iterator::operator--() template inline -arma_warn_unused typename Mat::row_iterator Mat::row_iterator::operator--(int) { @@ -7926,44 +8418,40 @@ Mat::row_iterator::operator--(int) template inline -arma_warn_unused bool Mat::row_iterator::operator!=(const typename Mat::row_iterator& X) const { - return (current_ptr != X.current_ptr); + return ( (current_row != X.current_row) || (current_col != X.current_col) ); } template inline -arma_warn_unused bool Mat::row_iterator::operator==(const typename Mat::row_iterator& X) const { - return (current_ptr == X.current_ptr); + return ( (current_row == X.current_row) && (current_col == X.current_col) ); } template inline -arma_warn_unused bool Mat::row_iterator::operator!=(const typename Mat::const_row_iterator& X) const { - return (current_ptr != X.current_ptr); + return ( (current_row != X.current_row) || (current_col != X.current_col) ); } template inline -arma_warn_unused bool Mat::row_iterator::operator==(const typename Mat::const_row_iterator& X) const { - return (current_ptr == X.current_ptr); + return ( (current_row == X.current_row) && (current_col == X.current_col) ); } @@ -7971,13 +8459,13 @@ Mat::row_iterator::operator==(const typename Mat::const_row_iterator& X) template inline Mat::const_row_iterator::const_row_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) , current_row(0 ) , current_col(0 ) { arma_extra_debug_sigprint(); - // Technically this iterator is invalid (it does not point to a valid element) + + // NOTE: this instance of const_row_iterator is invalid (it does not point to a valid element) } @@ -7986,7 +8474,6 @@ template inline Mat::const_row_iterator::const_row_iterator(const typename Mat::row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) { @@ -7999,7 +8486,6 @@ template inline Mat::const_row_iterator::const_row_iterator(const typename Mat::const_row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) { @@ -8010,11 +8496,10 @@ Mat::const_row_iterator::const_row_iterator(const typename Mat::const_ro template inline -Mat::const_row_iterator::const_row_iterator(const Mat& in_M, const uword in_row) - : M (&in_M ) - , current_ptr(&(in_M.at(in_row,0))) - , current_row(in_row ) - , current_col(0 ) +Mat::const_row_iterator::const_row_iterator(const Mat& in_M, const uword in_row, const uword in_col) + : M (&in_M ) + , current_row(in_row) + , current_col(in_col) { arma_extra_debug_sigprint(); } @@ -8023,11 +8508,10 @@ Mat::const_row_iterator::const_row_iterator(const Mat& in_M, const uword template inline -arma_warn_unused const eT& Mat::const_row_iterator::operator*() const { - return (*current_ptr); + return M->at(current_row,current_col); } @@ -8043,12 +8527,6 @@ Mat::const_row_iterator::operator++() { current_col = 0; current_row++; - - current_ptr = &(M->at(current_row, 0)); - } - else - { - current_ptr += M->n_rows; } return *this; @@ -8058,7 +8536,6 @@ Mat::const_row_iterator::operator++() template inline -arma_warn_unused typename Mat::const_row_iterator Mat::const_row_iterator::operator++(int) { @@ -8079,8 +8556,6 @@ Mat::const_row_iterator::operator--() if(current_col > 0) { current_col--; - - current_ptr -= M->n_rows; } else { @@ -8088,8 +8563,6 @@ Mat::const_row_iterator::operator--() { current_col = M->n_cols - 1; current_row--; - - current_ptr = &(M->at(current_row, current_col)); } } @@ -8100,7 +8573,6 @@ Mat::const_row_iterator::operator--() template inline -arma_warn_unused typename Mat::const_row_iterator Mat::const_row_iterator::operator--(int) { @@ -8115,44 +8587,40 @@ Mat::const_row_iterator::operator--(int) template inline -arma_warn_unused bool Mat::const_row_iterator::operator!=(const typename Mat::row_iterator& X) const { - return (current_ptr != X.current_ptr); + return ( (current_row != X.current_row) || (current_col != X.current_col) ); } template inline -arma_warn_unused bool Mat::const_row_iterator::operator==(const typename Mat::row_iterator& X) const { - return (current_ptr == X.current_ptr); + return ( (current_row == X.current_row) && (current_col == X.current_col) ); } template inline -arma_warn_unused bool Mat::const_row_iterator::operator!=(const typename Mat::const_row_iterator& X) const { - return (current_ptr != X.current_ptr); + return ( (current_row != X.current_row) || (current_col != X.current_col) ); } template inline -arma_warn_unused bool Mat::const_row_iterator::operator==(const typename Mat::const_row_iterator& X) const { - return (current_ptr == X.current_ptr); + return ( (current_row == X.current_row) && (current_col == X.current_col) ); } @@ -8160,8 +8628,8 @@ Mat::const_row_iterator::operator==(const typename Mat::const_row_iterat template inline Mat::row_col_iterator::row_col_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) + , current_ptr(nullptr) , current_col(0 ) , current_row(0 ) { @@ -8199,7 +8667,6 @@ Mat::row_col_iterator::row_col_iterator(Mat& in_M, const uword in_row, c template inline -arma_warn_unused eT& Mat::row_col_iterator::operator*() { @@ -8233,7 +8700,6 @@ Mat::row_col_iterator::operator++() template inline -arma_warn_unused typename Mat::row_col_iterator Mat::row_col_iterator::operator++(int) { @@ -8270,7 +8736,6 @@ Mat::row_col_iterator::operator--() template inline -arma_warn_unused typename Mat::row_col_iterator Mat::row_col_iterator::operator--(int) { @@ -8285,7 +8750,6 @@ Mat::row_col_iterator::operator--(int) template inline -arma_warn_unused uword Mat::row_col_iterator::row() const { @@ -8296,7 +8760,6 @@ Mat::row_col_iterator::row() const template inline -arma_warn_unused uword Mat::row_col_iterator::col() const { @@ -8307,7 +8770,6 @@ Mat::row_col_iterator::col() const template inline -arma_warn_unused bool Mat::row_col_iterator::operator==(const row_col_iterator& rhs) const { @@ -8318,7 +8780,6 @@ Mat::row_col_iterator::operator==(const row_col_iterator& rhs) const template inline -arma_warn_unused bool Mat::row_col_iterator::operator!=(const row_col_iterator& rhs) const { @@ -8329,7 +8790,6 @@ Mat::row_col_iterator::operator!=(const row_col_iterator& rhs) const template inline -arma_warn_unused bool Mat::row_col_iterator::operator==(const const_row_col_iterator& rhs) const { @@ -8340,7 +8800,6 @@ Mat::row_col_iterator::operator==(const const_row_col_iterator& rhs) const template inline -arma_warn_unused bool Mat::row_col_iterator::operator!=(const const_row_col_iterator& rhs) const { @@ -8352,8 +8811,8 @@ Mat::row_col_iterator::operator!=(const const_row_col_iterator& rhs) const template inline Mat::const_row_col_iterator::const_row_col_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) + , current_ptr(nullptr) , current_col(0 ) , current_row(0 ) { @@ -8404,7 +8863,6 @@ Mat::const_row_col_iterator::const_row_col_iterator(const Mat& in_M, con template inline -arma_warn_unused const eT& Mat::const_row_col_iterator::operator*() const { @@ -8438,7 +8896,6 @@ Mat::const_row_col_iterator::operator++() template inline -arma_warn_unused typename Mat::const_row_col_iterator Mat::const_row_col_iterator::operator++(int) { @@ -8476,7 +8933,6 @@ Mat::const_row_col_iterator::operator--() template inline -arma_warn_unused typename Mat::const_row_col_iterator Mat::const_row_col_iterator::operator--(int) { @@ -8491,7 +8947,6 @@ Mat::const_row_col_iterator::operator--(int) template inline -arma_warn_unused uword Mat::const_row_col_iterator::row() const { @@ -8502,7 +8957,6 @@ Mat::const_row_col_iterator::row() const template inline -arma_warn_unused uword Mat::const_row_col_iterator::col() const { @@ -8513,7 +8967,6 @@ Mat::const_row_col_iterator::col() const template inline -arma_warn_unused bool Mat::const_row_col_iterator::operator==(const const_row_col_iterator& rhs) const { @@ -8524,7 +8977,6 @@ Mat::const_row_col_iterator::operator==(const const_row_col_iterator& rhs) c template inline -arma_warn_unused bool Mat::const_row_col_iterator::operator!=(const const_row_col_iterator& rhs) const { @@ -8535,7 +8987,6 @@ Mat::const_row_col_iterator::operator!=(const const_row_col_iterator& rhs) c template inline -arma_warn_unused bool Mat::const_row_col_iterator::operator==(const row_col_iterator& rhs) const { @@ -8546,7 +8997,6 @@ Mat::const_row_col_iterator::operator==(const row_col_iterator& rhs) const template inline -arma_warn_unused bool Mat::const_row_col_iterator::operator!=(const row_col_iterator& rhs) const { @@ -8634,7 +9084,7 @@ Mat::begin_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( (col_num >= n_cols), "Mat::begin_col(): index out of bounds"); + arma_debug_check_bounds( (col_num >= n_cols), "Mat::begin_col(): index out of bounds" ); return colptr(col_num); } @@ -8648,7 +9098,7 @@ Mat::begin_col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (col_num >= n_cols), "Mat::begin_col(): index out of bounds"); + arma_debug_check_bounds( (col_num >= n_cols), "Mat::begin_col(): index out of bounds" ); return colptr(col_num); } @@ -8662,7 +9112,7 @@ Mat::end_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( (col_num >= n_cols), "Mat::end_col(): index out of bounds"); + arma_debug_check_bounds( (col_num >= n_cols), "Mat::end_col(): index out of bounds" ); return colptr(col_num) + n_rows; } @@ -8676,7 +9126,7 @@ Mat::end_col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (col_num >= n_cols), "Mat::end_col(): index out of bounds"); + arma_debug_check_bounds( (col_num >= n_cols), "Mat::end_col(): index out of bounds" ); return colptr(col_num) + n_rows; } @@ -8690,9 +9140,9 @@ Mat::begin_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); - return typename Mat::row_iterator(*this, row_num); + return typename Mat::row_iterator(*this, row_num, uword(0)); } @@ -8704,9 +9154,9 @@ Mat::begin_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "Mat::begin_row(): index out of bounds" ); - return typename Mat::const_row_iterator(*this, row_num); + return typename Mat::const_row_iterator(*this, row_num, uword(0)); } @@ -8718,9 +9168,9 @@ Mat::end_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); - return typename Mat::row_iterator(*this, row_num + 1); + return typename Mat::row_iterator(*this, (row_num + uword(1)), 0); } @@ -8732,9 +9182,9 @@ Mat::end_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "Mat::end_row(): index out of bounds" ); - return typename Mat::const_row_iterator(*this, row_num + 1); + return typename Mat::const_row_iterator(*this, (row_num + uword(1)), 0); } @@ -8865,6 +9315,15 @@ Mat::fixed::fixed() : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) { arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Mat::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } } @@ -8885,6 +9344,19 @@ Mat::fixed::fixed(const fixed +template +inline +Mat::fixed::fixed(const fill::scalar_holder f) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + template template template @@ -8894,11 +9366,11 @@ Mat::fixed::fixed(const fill::fill_class::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).eye(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } } @@ -8972,71 +9444,67 @@ Mat::fixed::fixed(const std::string& text) -#if defined(ARMA_USE_CXX11) - - template - template - inline - Mat::fixed::fixed(const std::initializer_list& list) - : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) - { - arma_extra_debug_sigprint_this(this); - - (*this).operator=(list); - } +template +template +inline +Mat::fixed::fixed(const std::initializer_list& list) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + (*this).operator=(list); + } + + + +template +template +inline +Mat& +Mat::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + const uword N = uword(list.size()); - template - template - inline - Mat& - Mat::fixed::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - const uword N = uword(list.size()); - - arma_debug_check( (N > fixed_n_elem), "Mat::fixed: initialiser list is too long" ); - - eT* this_mem = (*this).memptr(); - - arrayops::copy( this_mem, list.begin(), N ); - - for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } - - return *this; - } + arma_debug_check( (N > fixed_n_elem), "Mat::fixed: initialiser list is too long" ); + eT* this_mem = (*this).memptr(); + arrayops::copy( this_mem, list.begin(), N ); - template - template - inline - Mat::fixed::fixed(const std::initializer_list< std::initializer_list >& list) - : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) - { - arma_extra_debug_sigprint_this(this); - - Mat::init(list); - } + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + return *this; + } + + + +template +template +inline +Mat::fixed::fixed(const std::initializer_list< std::initializer_list >& list) + : Mat( arma_fixed_indicator(), fixed_n_rows, fixed_n_cols, 0, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + Mat::init(list); + } + + + +template +template +inline +Mat& +Mat::fixed::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); - template - template - inline - Mat& - Mat::fixed::operator=(const std::initializer_list< std::initializer_list >& list) - { - arma_extra_debug_sigprint(); - - Mat::init(list); - - return *this; - } + Mat::init(list); -#endif + return *this; + } @@ -9076,20 +9544,11 @@ Mat::fixed::operator=(const fixed::proxy_type::has_subview && X.P.is_alias(*this)); - if(bad_alias == false) - { - arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); - - eop_type::apply(*this, X); - } - else - { - arma_extra_debug_print("bad_alias = true"); - - Mat tmp(X); - - (*this) = tmp; - } + if(bad_alias) { const Mat tmp(X); (*this) = tmp; return *this; } + + arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); + + eop_type::apply(*this, X); return *this; } @@ -9115,20 +9574,11 @@ Mat::fixed::operator=(const fixed::proxy2_type::has_subview && X.P2.is_alias(*this)) ); - if(bad_alias == false) - { - arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); - - eglue_type::apply(*this, X); - } - else - { - arma_extra_debug_print("bad_alias = true"); - - Mat tmp(X); - - (*this) = tmp; - } + if(bad_alias) { const Mat tmp(X); (*this) = tmp; return *this; } + + arma_debug_assert_same_size(fixed_n_rows, fixed_n_cols, X.get_n_rows(), X.get_n_cols(), "Mat::fixed::operator="); + + eglue_type::apply(*this, X); return *this; } @@ -9173,7 +9623,6 @@ Mat::fixed::st() const template template arma_inline -arma_warn_unused const eT& Mat::fixed::at_alt(const uword ii) const { @@ -9195,7 +9644,6 @@ Mat::fixed::at_alt(const uword ii) const template template arma_inline -arma_warn_unused eT& Mat::fixed::operator[] (const uword ii) { @@ -9207,7 +9655,6 @@ Mat::fixed::operator[] (const uword ii) template template arma_inline -arma_warn_unused const eT& Mat::fixed::operator[] (const uword ii) const { @@ -9219,7 +9666,6 @@ Mat::fixed::operator[] (const uword ii) const template template arma_inline -arma_warn_unused eT& Mat::fixed::at(const uword ii) { @@ -9231,7 +9677,6 @@ Mat::fixed::at(const uword ii) template template arma_inline -arma_warn_unused const eT& Mat::fixed::at(const uword ii) const { @@ -9243,11 +9688,10 @@ Mat::fixed::at(const uword ii) const template template arma_inline -arma_warn_unused eT& Mat::fixed::operator() (const uword ii) { - arma_debug_check( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; } @@ -9257,21 +9701,49 @@ Mat::fixed::operator() (const uword ii) template template arma_inline -arma_warn_unused const eT& Mat::fixed::operator() (const uword ii) const { - arma_debug_check( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Mat::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : mem_local[ii]; } +#if defined(__cpp_multidimensional_subscript) + + template + template + arma_inline + eT& + Mat::fixed::operator[] (const uword in_row, const uword in_col) + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + + + + template + template + arma_inline + const eT& + Mat::fixed::operator[] (const uword in_row, const uword in_col) const + { + const uword iq = in_row + in_col*fixed_n_rows; + + return (use_extra) ? mem_local_extra[iq] : mem_local[iq]; + } + +#endif + + + template template arma_inline -arma_warn_unused eT& Mat::fixed::at(const uword in_row, const uword in_col) { @@ -9285,7 +9757,6 @@ Mat::fixed::at(const uword in_row, const uword i template template arma_inline -arma_warn_unused const eT& Mat::fixed::at(const uword in_row, const uword in_col) const { @@ -9299,11 +9770,10 @@ Mat::fixed::at(const uword in_row, const uword i template template arma_inline -arma_warn_unused eT& Mat::fixed::operator() (const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds" ); const uword iq = in_row + in_col*fixed_n_rows; @@ -9315,11 +9785,10 @@ Mat::fixed::operator() (const uword in_row, cons template template arma_inline -arma_warn_unused const eT& Mat::fixed::operator() (const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= fixed_n_rows) || (in_col >= fixed_n_cols)), "Mat::operator(): index out of bounds" ); const uword iq = in_row + in_col*fixed_n_rows; @@ -9331,7 +9800,6 @@ Mat::fixed::operator() (const uword in_row, cons template template arma_inline -arma_warn_unused eT* Mat::fixed::colptr(const uword in_col) { @@ -9345,7 +9813,6 @@ Mat::fixed::colptr(const uword in_col) template template arma_inline -arma_warn_unused const eT* Mat::fixed::colptr(const uword in_col) const { @@ -9359,7 +9826,6 @@ Mat::fixed::colptr(const uword in_col) const template template arma_inline -arma_warn_unused eT* Mat::fixed::memptr() { @@ -9371,7 +9837,6 @@ Mat::fixed::memptr() template template arma_inline -arma_warn_unused const eT* Mat::fixed::memptr() const { @@ -9383,7 +9848,6 @@ Mat::fixed::memptr() const template template arma_inline -arma_warn_unused bool Mat::fixed::is_vec() const { @@ -9640,17 +10104,14 @@ Mat_aux::set_real(Mat< std::complex >& out, const Base& X) const uword N = out.n_elem; - for(uword i=0; i( A[i], out_mem[i].imag() ); - } + for(uword i=0; i( P.at(row,col), (*out_mem).imag() ); + (*out_mem).real(P.at(row,col)); out_mem++; } } @@ -9684,17 +10145,14 @@ Mat_aux::set_imag(Mat< std::complex >& out, const Base& X) const uword N = out.n_elem; - for(uword i=0; i( out_mem[i].real(), A[i] ); - } + for(uword i=0; i( (*out_mem).real(), P.at(row,col) ); + (*out_mem).imag(P.at(row,col)); out_mem++; } } @@ -9702,7 +10160,7 @@ Mat_aux::set_imag(Mat< std::complex >& out, const Base& X) -#ifdef ARMA_EXTRA_MAT_MEAT +#if defined(ARMA_EXTRA_MAT_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_MAT_MEAT) #endif diff --git a/src/armadillo_bits/OpCube_bones.hpp b/src/armadillo_bits/OpCube_bones.hpp index 7ebc8d02..be4f618f 100644 --- a/src/armadillo_bits/OpCube_bones.hpp +++ b/src/armadillo_bits/OpCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,23 +20,19 @@ //! @{ -//! Analog of the Op class, intended for cubes - template -class OpCube : public BaseCube > +class OpCube : public BaseCube< typename T1::elem_type, OpCube > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - inline explicit OpCube(const BaseCube& in_m); inline OpCube(const BaseCube& in_m, const elem_type in_aux); inline OpCube(const BaseCube& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); inline OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); inline OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); - inline OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c, const uword in_aux_uword_d, const char junk); inline ~OpCube(); arma_aligned const T1& m; //!< the operand; must be derived from BaseCube @@ -42,8 +40,6 @@ class OpCube : public BaseCube > arma_aligned uword aux_uword_a; //!< auxiliary data, uword format arma_aligned uword aux_uword_b; //!< auxiliary data, uword format arma_aligned uword aux_uword_c; //!< auxiliary data, uword format - arma_aligned uword aux_uword_d; //!< auxiliary data, uword format - }; diff --git a/src/armadillo_bits/OpCube_meat.hpp b/src/armadillo_bits/OpCube_meat.hpp index ab50d5c7..a20d6b9f 100644 --- a/src/armadillo_bits/OpCube_meat.hpp +++ b/src/armadillo_bits/OpCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -74,19 +76,6 @@ OpCube::OpCube(const BaseCube& in_m, co -template -OpCube::OpCube(const BaseCube& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c, const uword in_aux_uword_d, const char) - : m(in_m.get_ref()) - , aux_uword_a(in_aux_uword_a) - , aux_uword_b(in_aux_uword_b) - , aux_uword_c(in_aux_uword_c) - , aux_uword_d(in_aux_uword_d) - { - arma_extra_debug_sigprint(); - } - - - template OpCube::~OpCube() { diff --git a/src/armadillo_bits/Op_bones.hpp b/src/armadillo_bits/Op_bones.hpp index 79f32173..fa8c3efd 100644 --- a/src/armadillo_bits/Op_bones.hpp +++ b/src/armadillo_bits/Op_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,24 +28,24 @@ struct Op_traits {}; template struct Op_traits { - static const bool is_row = op_type::template traits::is_row; - static const bool is_col = op_type::template traits::is_col; - static const bool is_xvec = op_type::template traits::is_xvec; + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; }; template struct Op_traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; template class Op - : public Base > - , public Op_traits::value > + : public Base< typename T1::elem_type, Op > + , public Op_traits::value> { public: @@ -54,14 +56,12 @@ class Op inline Op(const T1& in_m, const elem_type in_aux); inline Op(const T1& in_m, const elem_type in_aux, const uword in_aux_uword_a, const uword in_aux_uword_b); inline Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b); - inline Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c, const char junk); inline ~Op(); arma_aligned const T1& m; //!< the operand; must be derived from Base arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 arma_aligned uword aux_uword_a; //!< auxiliary data, uword format arma_aligned uword aux_uword_b; //!< auxiliary data, uword format - arma_aligned uword aux_uword_c; //!< auxiliary data, uword format }; diff --git a/src/armadillo_bits/Op_meat.hpp b/src/armadillo_bits/Op_meat.hpp index 3479a9e6..cd08ff93 100644 --- a/src/armadillo_bits/Op_meat.hpp +++ b/src/armadillo_bits/Op_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -65,19 +67,6 @@ Op::Op(const T1& in_m, const uword in_aux_uword_a, const uword in_a -template -inline -Op::Op(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c, const char) - : m(in_m) - , aux_uword_a(in_aux_uword_a) - , aux_uword_b(in_aux_uword_b) - , aux_uword_c(in_aux_uword_c) - { - arma_extra_debug_sigprint(); - } - - - template inline Op::~Op() diff --git a/src/armadillo_bits/Proxy.hpp b/src/armadillo_bits/Proxy.hpp index 8c51ef0b..ca3f713e 100644 --- a/src/armadillo_bits/Proxy.hpp +++ b/src/armadillo_bits/Proxy.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,6 @@ // use_at = boolean indicating whether at(row,col) must be used to get elements // use_mp = boolean indicating whether OpenMP can be used while processing elements // has_subview = boolean indicating whether the Q object has a subview -// fake_mat = boolean indicating whether the Q object is a matrix using memory from another object // // is_row = boolean indicating whether the Q object can be treated a row vector // is_col = boolean indicating whether the Q object can be treated a column vector @@ -74,14 +75,13 @@ struct Proxy_fixed typedef const elem_type* ea_type; typedef const T1& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; arma_aligned const T1& Q; @@ -91,13 +91,22 @@ struct Proxy_fixed arma_extra_debug_sigprint(); } - arma_inline static uword get_n_rows() { return T1::n_rows; } - arma_inline static uword get_n_cols() { return T1::n_cols; } - arma_inline static uword get_n_elem() { return T1::n_elem; } + //// this may require T1::n_elem etc to be declared as static constexpr inline variables (C++17) + //// see also the notes in Mat::fixed + //// https://en.cppreference.com/w/cpp/language/static + //// https://en.cppreference.com/w/cpp/language/inline + // + // static constexpr uword get_n_rows() { return T1::n_rows; } + // static constexpr uword get_n_cols() { return T1::n_cols; } + // static constexpr uword get_n_elem() { return T1::n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline uword get_n_rows() const { return is_row ? 1 : T1::n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : T1::n_cols; } + arma_inline uword get_n_elem() const { return T1::n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -132,11 +141,10 @@ struct Proxy_redirect { typedef Proxy_fixed result; }; template -class Proxy : public Proxy_redirect::value >::result +struct Proxy : public Proxy_redirect::value>::result { - public: inline Proxy(const T1& A) - : Proxy_redirect< T1, is_Mat_fixed::value >::result(A) + : Proxy_redirect::value>::result(A) { } }; @@ -144,24 +152,21 @@ class Proxy : public Proxy_redirect::value >::result template -class Proxy< Mat > +struct Proxy< Mat > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const eT* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const Mat& Q; @@ -175,9 +180,9 @@ class Proxy< Mat > arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -194,24 +199,21 @@ class Proxy< Mat > template -class Proxy< Col > +struct Proxy< Col > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Col stored_type; typedef const eT* ea_type; typedef const Col& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const Col& Q; @@ -222,12 +224,12 @@ class Proxy< Col > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -244,24 +246,21 @@ class Proxy< Col > template -class Proxy< Row > +struct Proxy< Row > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Row stored_type; typedef const eT* ea_type; typedef const Row& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const Row& Q; @@ -271,13 +270,13 @@ class Proxy< Row > arma_extra_debug_sigprint(); } - arma_inline uword get_n_rows() const { return 1; } + constexpr uword get_n_rows() const { return 1; } arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword, const uword col) const { return Q[col]; } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -294,24 +293,21 @@ class Proxy< Row > template -class Proxy< Gen > +struct Proxy< Gen > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Gen stored_type; typedef const Gen& ea_type; typedef const Gen& aligned_ea_type; - static const bool use_at = Gen::use_at; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = Gen::use_at; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = Gen::is_row; - static const bool is_col = Gen::is_col; - static const bool is_xvec = Gen::is_xvec; + static constexpr bool is_row = Gen::is_row; + static constexpr bool is_col = Gen::is_col; + static constexpr bool is_xvec = Gen::is_xvec; arma_aligned const Gen& Q; @@ -325,143 +321,40 @@ class Proxy< Gen > arma_inline uword get_n_cols() const { return (is_col ? 1 : Q.n_cols); } arma_inline uword get_n_elem() const { return (is_row ? 1 : Q.n_rows) * (is_col ? 1 : Q.n_cols); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } - - template - arma_inline bool has_overlap(const subview&) const { return false; } - - arma_inline bool is_aligned() const { return Gen::is_simple; } - }; - - - -template -class Proxy< Gen > - { - public: - - typedef typename T1::elem_type elem_type; - typedef typename get_pod_type::result pod_type; - typedef Mat stored_type; - typedef const elem_type* ea_type; - typedef const Mat& aligned_ea_type; - - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; - - static const bool is_row = Gen::is_row; - static const bool is_col = Gen::is_col; - static const bool is_xvec = Gen::is_xvec; - - arma_aligned const Mat Q; - - inline explicit Proxy(const Gen& A) - : Q(A) - { - arma_extra_debug_sigprint(); - } - - arma_inline uword get_n_rows() const { return (is_row ? 1 : Q.n_rows); } - arma_inline uword get_n_cols() const { return (is_col ? 1 : Q.n_cols); } - arma_inline uword get_n_elem() const { return (is_row ? 1 : Q.n_rows) * (is_col ? 1 : Q.n_cols); } - - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } - - arma_inline ea_type get_ea() const { return Q.memptr(); } - arma_inline aligned_ea_type get_aligned_ea() const { return Q; } - - template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } - arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } - }; - - - -template -class Proxy< Gen > - { - public: - - typedef typename T1::elem_type elem_type; - typedef typename get_pod_type::result pod_type; - typedef Mat stored_type; - typedef const elem_type* ea_type; - typedef const Mat& aligned_ea_type; - - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; - - static const bool is_row = Gen::is_row; - static const bool is_col = Gen::is_col; - static const bool is_xvec = Gen::is_xvec; - - arma_aligned const Mat Q; - - inline explicit Proxy(const Gen& A) - : Q(A) - { - arma_extra_debug_sigprint(); - } - - arma_inline uword get_n_rows() const { return (is_row ? 1 : Q.n_rows); } - arma_inline uword get_n_cols() const { return (is_col ? 1 : Q.n_cols); } - arma_inline uword get_n_elem() const { return (is_row ? 1 : Q.n_rows) * (is_col ? 1 : Q.n_cols); } - - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } - - arma_inline ea_type get_ea() const { return Q.memptr(); } - arma_inline aligned_ea_type get_aligned_ea() const { return Q; } - - template - arma_inline bool is_alias(const Mat&) const { return false; } - - template - arma_inline bool has_overlap(const subview&) const { return false; } - - arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + constexpr bool is_aligned() const { return Gen::is_simple; } }; template -class Proxy< eOp > +struct Proxy< eOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef eOp stored_type; typedef const eOp& ea_type; typedef const eOp& aligned_ea_type; - static const bool use_at = eOp::use_at; - static const bool use_mp = eOp::use_mp; - static const bool has_subview = eOp::has_subview; - static const bool fake_mat = eOp::fake_mat; + static constexpr bool use_at = eOp::use_at; + static constexpr bool use_mp = eOp::use_mp; + static constexpr bool has_subview = eOp::has_subview; - static const bool is_row = eOp::is_row; - static const bool is_col = eOp::is_col; - static const bool is_xvec = eOp::is_xvec; + static constexpr bool is_row = eOp::is_row; + static constexpr bool is_col = eOp::is_col; + static constexpr bool is_xvec = eOp::is_xvec; arma_aligned const eOp& Q; @@ -475,9 +368,9 @@ class Proxy< eOp > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } arma_inline uword get_n_elem() const { return Q.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -494,24 +387,21 @@ class Proxy< eOp > template -class Proxy< eGlue > +struct Proxy< eGlue > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef eGlue stored_type; typedef const eGlue& ea_type; typedef const eGlue& aligned_ea_type; - static const bool use_at = eGlue::use_at; - static const bool use_mp = eGlue::use_mp; - static const bool has_subview = eGlue::has_subview; - static const bool fake_mat = eGlue::fake_mat; + static constexpr bool use_at = eGlue::use_at; + static constexpr bool use_mp = eGlue::use_mp; + static constexpr bool has_subview = eGlue::has_subview; - static const bool is_row = eGlue::is_row; - static const bool is_col = eGlue::is_col; - static const bool is_xvec = eGlue::is_xvec; + static constexpr bool is_row = eGlue::is_row; + static constexpr bool is_col = eGlue::is_col; + static constexpr bool is_xvec = eGlue::is_xvec; arma_aligned const eGlue& Q; @@ -525,9 +415,9 @@ class Proxy< eGlue > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } arma_inline uword get_n_elem() const { return Q.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -544,24 +434,21 @@ class Proxy< eGlue > template -class Proxy< Op > +struct Proxy< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = Op::is_row; - static const bool is_col = Op::is_col; - static const bool is_xvec = Op::is_xvec; + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; arma_aligned const Mat Q; @@ -575,18 +462,18 @@ class Proxy< Op > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -594,24 +481,21 @@ class Proxy< Op > template -class Proxy< Glue > +struct Proxy< Glue > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool has_subview = false; - static const bool use_mp = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = Glue::is_row; - static const bool is_col = Glue::is_col; - static const bool is_xvec = Glue::is_xvec; + static constexpr bool is_row = Glue::is_row; + static constexpr bool is_col = Glue::is_col; + static constexpr bool is_xvec = Glue::is_xvec; arma_aligned const Mat Q; @@ -625,43 +509,156 @@ class Proxy< Glue > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; -template -class Proxy< mtOp > +template +struct Proxy< Glue > { - public: + typedef Glue this_Glue_type; + typedef Proxy< Glue > this_Proxy_type; + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef this_Glue_type stored_type; + typedef const this_Proxy_type& ea_type; + typedef const this_Proxy_type& aligned_ea_type; + + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); + + static constexpr bool is_row = this_Glue_type::is_row; + static constexpr bool is_col = this_Glue_type::is_col; + static constexpr bool is_xvec = this_Glue_type::is_xvec; + + arma_aligned const this_Glue_type& Q; + arma_aligned const Proxy P1; + arma_aligned const Proxy P2; + + arma_lt_comparator comparator; + + inline explicit Proxy(const this_Glue_type& X) + : Q (X ) + , P1(X.A) + , P2(X.B) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(P1, P2, "element-wise min()"); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : P1.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : P1.get_n_cols(); } + arma_inline uword get_n_elem() const { return P1.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { const elem_type A = P1[i]; const elem_type B = P2[i]; return comparator(A,B) ? A : B; } + arma_inline elem_type at (const uword r, const uword c) const { const elem_type A = P1.at(r,c); const elem_type B = P2.at(r,c); return comparator(A,B) ? A : B; } + arma_inline elem_type at_alt (const uword i) const { const elem_type A = P1.at_alt(i); const elem_type B = P2.at_alt(i); return comparator(A,B) ? A : B; } + + arma_inline ea_type get_ea() const { return *this; } + arma_inline aligned_ea_type get_aligned_ea() const { return *this; } + + template + arma_inline bool is_alias(const Mat& X) const { return (P1.is_alias(X) || P2.is_alias(X)); } + template + arma_inline bool has_overlap(const subview& X) const { return (P1.has_overlap(X) || P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return (P1.is_aligned() && P2.is_aligned()); } + }; + + + +template +struct Proxy< Glue > + { + typedef Glue this_Glue_type; + typedef Proxy< Glue > this_Proxy_type; + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef this_Glue_type stored_type; + typedef const this_Proxy_type& ea_type; + typedef const this_Proxy_type& aligned_ea_type; + + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); + + static constexpr bool is_row = this_Glue_type::is_row; + static constexpr bool is_col = this_Glue_type::is_col; + static constexpr bool is_xvec = this_Glue_type::is_xvec; + + arma_aligned const this_Glue_type& Q; + arma_aligned const Proxy P1; + arma_aligned const Proxy P2; + + arma_gt_comparator comparator; + + inline explicit Proxy(const this_Glue_type& X) + : Q (X ) + , P1(X.A) + , P2(X.B) + { + arma_extra_debug_sigprint(); + + arma_debug_assert_same_size(P1, P2, "element-wise max()"); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : P1.get_n_rows(); } + arma_inline uword get_n_cols() const { return is_col ? 1 : P1.get_n_cols(); } + arma_inline uword get_n_elem() const { return P1.get_n_elem(); } + + arma_inline elem_type operator[] (const uword i) const { const elem_type A = P1[i]; const elem_type B = P2[i]; return comparator(A,B) ? A : B; } + arma_inline elem_type at (const uword r, const uword c) const { const elem_type A = P1.at(r,c); const elem_type B = P2.at(r,c); return comparator(A,B) ? A : B; } + arma_inline elem_type at_alt (const uword i) const { const elem_type A = P1.at_alt(i); const elem_type B = P2.at_alt(i); return comparator(A,B) ? A : B; } + + arma_inline ea_type get_ea() const { return *this; } + arma_inline aligned_ea_type get_aligned_ea() const { return *this; } + + template + arma_inline bool is_alias(const Mat& X) const { return (P1.is_alias(X) || P2.is_alias(X)); } + + template + arma_inline bool has_overlap(const subview& X) const { return (P1.has_overlap(X) || P2.has_overlap(X)); } + + arma_inline bool is_aligned() const { return (P1.is_aligned() && P2.is_aligned()); } + }; + + + +template +struct Proxy< mtOp > + { typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = mtOp::is_row; - static const bool is_col = mtOp::is_col; - static const bool is_xvec = mtOp::is_xvec; + static constexpr bool is_row = mtOp::is_row; + static constexpr bool is_col = mtOp::is_col; + static constexpr bool is_xvec = mtOp::is_xvec; arma_aligned const Mat Q; @@ -675,18 +672,18 @@ class Proxy< mtOp > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row,col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -694,24 +691,21 @@ class Proxy< mtOp > template -class Proxy< mtGlue > +struct Proxy< mtGlue > { - public: - typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = mtGlue::is_row; - static const bool is_col = mtGlue::is_col; - static const bool is_xvec = mtGlue::is_xvec; + static constexpr bool is_row = mtGlue::is_row; + static constexpr bool is_col = mtGlue::is_col; + static constexpr bool is_xvec = mtGlue::is_xvec; arma_aligned const Mat Q; @@ -725,18 +719,18 @@ class Proxy< mtGlue > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row,col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -744,24 +738,21 @@ class Proxy< mtGlue > template -class Proxy< CubeToMatOp > +struct Proxy< CubeToMatOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = CubeToMatOp::is_row; - static const bool is_col = CubeToMatOp::is_col; - static const bool is_xvec = CubeToMatOp::is_xvec; + static constexpr bool is_row = CubeToMatOp::is_row; + static constexpr bool is_col = CubeToMatOp::is_col; + static constexpr bool is_xvec = CubeToMatOp::is_xvec; arma_aligned const Mat Q; @@ -775,18 +766,18 @@ class Proxy< CubeToMatOp > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -794,24 +785,21 @@ class Proxy< CubeToMatOp > template -class Proxy< CubeToMatOp > +struct Proxy< CubeToMatOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = true; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const unwrap_cube U; arma_aligned const Mat Q; @@ -824,21 +812,21 @@ class Proxy< CubeToMatOp > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -846,24 +834,21 @@ class Proxy< CubeToMatOp > template -class Proxy< SpToDOp > +struct Proxy< SpToDOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = SpToDOp::is_row; - static const bool is_col = SpToDOp::is_col; - static const bool is_xvec = SpToDOp::is_xvec; + static constexpr bool is_row = SpToDOp::is_row; + static constexpr bool is_col = SpToDOp::is_col; + static constexpr bool is_xvec = SpToDOp::is_xvec; arma_aligned const Mat Q; @@ -877,18 +862,18 @@ class Proxy< SpToDOp > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -896,24 +881,21 @@ class Proxy< SpToDOp > template -class Proxy< SpToDOp > +struct Proxy< SpToDOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = true; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const unwrap_spmat U; arma_aligned const Mat Q; @@ -926,21 +908,68 @@ class Proxy< SpToDOp > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< SpToDGlue > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = SpToDGlue::is_row; + static constexpr bool is_col = SpToDGlue::is_col; + static constexpr bool is_xvec = SpToDGlue::is_xvec; + + arma_aligned const Mat Q; + + inline explicit Proxy(const SpToDGlue& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return is_row ? 1 : Q.n_rows; } + arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -948,24 +977,21 @@ class Proxy< SpToDOp > template -class Proxy< subview > +struct Proxy< subview > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview stored_type; typedef const subview& ea_type; typedef const subview& aligned_ea_type; - static const bool use_at = true; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const subview& Q; @@ -979,9 +1005,9 @@ class Proxy< subview > arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -992,30 +1018,27 @@ class Proxy< subview > template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< subview_col > +struct Proxy< subview_col > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_col stored_type; typedef const eT* ea_type; typedef const subview_col& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_col& Q; @@ -1026,12 +1049,12 @@ class Proxy< subview_col > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.colmem; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1048,24 +1071,70 @@ class Proxy< subview_col > template -class Proxy< subview_row > +struct Proxy< subview_cols > { - public: + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const eT* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const subview_cols& sv; + arma_aligned const Mat Q; + + inline explicit Proxy(const subview_cols& A) + : sv(A) + , Q ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + arma_inline bool is_alias(const Mat& X) const { return (is_same_type::value) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false; } + + template + arma_inline bool has_overlap(const subview& X) const { return sv.check_overlap(X); } + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + +template +struct Proxy< subview_row > + { typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_row stored_type; typedef const subview_row& ea_type; typedef const subview_row& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const subview_row& Q; @@ -1075,13 +1144,13 @@ class Proxy< subview_row > arma_extra_debug_sigprint(); } - arma_inline uword get_n_rows() const { return 1; } + constexpr uword get_n_rows() const { return 1; } arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword, const uword col) const { return Q[col]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1092,30 +1161,27 @@ class Proxy< subview_row > template arma_inline bool has_overlap(const subview& X) const { return Q.check_overlap(X); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< subview_elem1 > +struct Proxy< subview_elem1 > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_elem1 stored_type; typedef const Proxy< subview_elem1 >& ea_type; typedef const Proxy< subview_elem1 >& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_elem1& Q; arma_aligned const Proxy R; @@ -1129,16 +1195,16 @@ class Proxy< subview_elem1 > const bool R_is_vec = ((R.get_n_rows() == 1) || (R.get_n_cols() == 1)); const bool R_is_empty = (R.get_n_elem() == 0); - arma_debug_check( ((R_is_vec == false) && (R_is_empty == false)), "Mat::elem(): given object is not a vector" ); + arma_debug_check( ((R_is_vec == false) && (R_is_empty == false)), "Mat::elem(): given object must be a vector" ); } arma_inline uword get_n_rows() const { return R.get_n_elem(); } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return R.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i, 0) : R[i ]; arma_debug_check( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } - arma_inline elem_type at (const uword row, const uword) const { const uword ii = (Proxy::use_at) ? R.at(row,0) : R[row]; arma_debug_check( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } - arma_inline elem_type at_alt (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i, 0) : R[i ]; arma_debug_check( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + arma_inline elem_type operator[] (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i,0) : R[i]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + arma_inline elem_type at (const uword r, const uword) const { const uword ii = (Proxy::use_at) ? R.at(r,0) : R[r]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } + arma_inline elem_type at_alt (const uword i) const { const uword ii = (Proxy::use_at) ? R.at(i,0) : R[i]; arma_debug_check_bounds( (ii >= Q.m.n_elem), "Mat::elem(): index out of bounds" ); return Q.m[ii]; } arma_inline ea_type get_ea() const { return (*this); } arma_inline aligned_ea_type get_aligned_ea() const { return (*this); } @@ -1149,30 +1215,27 @@ class Proxy< subview_elem1 > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< subview_elem2 > +struct Proxy< subview_elem2 > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const eT* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const Mat Q; @@ -1186,18 +1249,18 @@ class Proxy< subview_elem2 > arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -1205,24 +1268,21 @@ class Proxy< subview_elem2 > template -class Proxy< diagview > +struct Proxy< diagview > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef diagview stored_type; typedef const diagview& ea_type; typedef const diagview& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const diagview& Q; @@ -1233,12 +1293,12 @@ class Proxy< diagview > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q.at(row, 0); } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1249,13 +1309,13 @@ class Proxy< diagview > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy_diagvec_mat +struct Proxy_diagvec_mat { inline Proxy_diagvec_mat(const T1&) {} }; @@ -1263,41 +1323,38 @@ class Proxy_diagvec_mat template -class Proxy_diagvec_mat< Op > +struct Proxy_diagvec_mat< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef diagview stored_type; typedef const diagview& ea_type; typedef const diagview& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const Mat& R; arma_aligned const diagview Q; inline explicit Proxy_diagvec_mat(const Op& A) - : R(A.m), Q( R.diag( (A.aux_uword_b > 0) ? -sword(A.aux_uword_a) : sword(A.aux_uword_a) ) ) + : R(A.m), Q( R.diag() ) { arma_extra_debug_sigprint(); } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q.at(row, 0); } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1308,13 +1365,13 @@ class Proxy_diagvec_mat< Op > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy_diagvec_expr +struct Proxy_diagvec_expr { inline Proxy_diagvec_expr(const T1&) {} }; @@ -1322,24 +1379,21 @@ class Proxy_diagvec_expr template -class Proxy_diagvec_expr< Op > +struct Proxy_diagvec_expr< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const Mat Q; @@ -1350,21 +1404,21 @@ class Proxy_diagvec_expr< Op > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q.at(row, 0); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -1383,11 +1437,9 @@ struct Proxy_diagvec_redirect< Op, false> { typedef Proxy_diagve template -class Proxy< Op > +struct Proxy< Op > : public Proxy_diagvec_redirect< Op, is_Mat::value >::result { - public: - typedef typename Proxy_diagvec_redirect< Op, is_Mat::value >::result Proxy_diagvec; inline explicit Proxy(const Op& A) @@ -1399,6 +1451,53 @@ class Proxy< Op > +template +struct Proxy< Op > + { + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + typedef Mat stored_type; + typedef const elem_type* ea_type; + typedef const Mat& aligned_ea_type; + + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const Mat Q; + + inline explicit Proxy(const Op& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + + arma_inline ea_type get_ea() const { return Q.memptr(); } + arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + + template + constexpr bool is_alias(const Mat&) const { return false; } + + template + constexpr bool has_overlap(const subview&) const { return false; } + + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + }; + + + template struct Proxy_xtrans_default { @@ -1410,22 +1509,19 @@ struct Proxy_xtrans_default template struct Proxy_xtrans_default< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef xtrans_mat stored_type; typedef const xtrans_mat& ea_type; typedef const xtrans_mat& aligned_ea_type; - static const bool use_at = true; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const unwrap U; const xtrans_mat Q; @@ -1446,7 +1542,7 @@ struct Proxy_xtrans_default< Op > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; @@ -1454,22 +1550,19 @@ struct Proxy_xtrans_default< Op > template struct Proxy_xtrans_default< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef xtrans_mat stored_type; typedef const xtrans_mat& ea_type; typedef const xtrans_mat& aligned_ea_type; - static const bool use_at = true; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const unwrap U; const xtrans_mat Q; @@ -1490,7 +1583,7 @@ struct Proxy_xtrans_default< Op > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; @@ -1512,15 +1605,14 @@ struct Proxy_xtrans_vector< Op > typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = quasi_unwrap::has_subview; - static const bool fake_mat = true; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = quasi_unwrap::has_subview; // NOTE: the Op class takes care of swapping row and col for op_htrans - static const bool is_row = Op::is_row; - static const bool is_col = Op::is_col; - static const bool is_xvec = Op::is_xvec; + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; arma_aligned const quasi_unwrap U; // avoid copy if T1 is a Row, Col or subview_col arma_aligned const Mat Q; @@ -1555,15 +1647,14 @@ struct Proxy_xtrans_vector< Op > typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = quasi_unwrap::has_subview; - static const bool fake_mat = true; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = quasi_unwrap::has_subview; // NOTE: the Op class takes care of swapping row and col for op_strans - static const bool is_row = Op::is_row; - static const bool is_col = Op::is_col; - static const bool is_xvec = Op::is_xvec; + static constexpr bool is_row = Op::is_row; + static constexpr bool is_col = Op::is_col; + static constexpr bool is_xvec = Op::is_xvec; arma_aligned const quasi_unwrap U; // avoid copy if T1 is a Row, Col or subview_col arma_aligned const Mat Q; @@ -1601,7 +1692,7 @@ struct Proxy_xtrans_redirect { typedef Proxy_xtrans_vector resul template -class Proxy< Op > +struct Proxy< Op > : public Proxy_xtrans_redirect < @@ -1609,8 +1700,6 @@ class Proxy< Op > ((is_cx::no) && ((Op::is_row) || (Op::is_col)) ) >::result { - public: - typedef typename Proxy_xtrans_redirect @@ -1626,14 +1715,13 @@ class Proxy< Op > typedef typename Proxy_xtrans::ea_type ea_type; typedef typename Proxy_xtrans::aligned_ea_type aligned_ea_type; - static const bool use_at = Proxy_xtrans::use_at; - static const bool use_mp = Proxy_xtrans::use_mp; - static const bool has_subview = Proxy_xtrans::has_subview; - static const bool fake_mat = Proxy_xtrans::fake_mat; + static constexpr bool use_at = Proxy_xtrans::use_at; + static constexpr bool use_mp = Proxy_xtrans::use_mp; + static constexpr bool has_subview = Proxy_xtrans::has_subview; - static const bool is_row = Proxy_xtrans::is_row; - static const bool is_col = Proxy_xtrans::is_col; - static const bool is_xvec = Proxy_xtrans::is_xvec; + static constexpr bool is_row = Proxy_xtrans::is_row; + static constexpr bool is_col = Proxy_xtrans::is_col; + static constexpr bool is_xvec = Proxy_xtrans::is_xvec; using Proxy_xtrans::Q; @@ -1647,9 +1735,9 @@ class Proxy< Op > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Proxy_xtrans::get_ea(); } arma_inline aligned_ea_type get_aligned_ea() const { return Proxy_xtrans::get_aligned_ea(); } @@ -1666,7 +1754,7 @@ class Proxy< Op > template -class Proxy< Op > +struct Proxy< Op > : public Proxy_xtrans_redirect < @@ -1674,8 +1762,6 @@ class Proxy< Op > ( (Op::is_row) || (Op::is_col) ) >::result { - public: - typedef typename Proxy_xtrans_redirect @@ -1691,14 +1777,13 @@ class Proxy< Op > typedef typename Proxy_xtrans::ea_type ea_type; typedef typename Proxy_xtrans::aligned_ea_type aligned_ea_type; - static const bool use_at = Proxy_xtrans::use_at; - static const bool use_mp = Proxy_xtrans::use_mp; - static const bool has_subview = Proxy_xtrans::has_subview; - static const bool fake_mat = Proxy_xtrans::fake_mat; + static constexpr bool use_at = Proxy_xtrans::use_at; + static constexpr bool use_mp = Proxy_xtrans::use_mp; + static constexpr bool has_subview = Proxy_xtrans::has_subview; - static const bool is_row = Proxy_xtrans::is_row; - static const bool is_col = Proxy_xtrans::is_col; - static const bool is_xvec = Proxy_xtrans::is_xvec; + static constexpr bool is_row = Proxy_xtrans::is_row; + static constexpr bool is_col = Proxy_xtrans::is_col; + static constexpr bool is_xvec = Proxy_xtrans::is_xvec; using Proxy_xtrans::Q; @@ -1712,9 +1797,9 @@ class Proxy< Op > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Proxy_xtrans::get_ea(); } arma_inline aligned_ea_type get_aligned_ea() const { return Proxy_xtrans::get_aligned_ea(); } @@ -1739,14 +1824,13 @@ struct Proxy_subview_row_htrans_cx typedef const subview_row_htrans& ea_type; typedef const subview_row_htrans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row_htrans Q; @@ -1774,14 +1858,13 @@ struct Proxy_subview_row_htrans_non_cx typedef const subview_row_strans& ea_type; typedef const subview_row_strans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row_strans Q; @@ -1812,7 +1895,7 @@ struct Proxy_subview_row_htrans_redirect { typedef Proxy_subview_row_ template -class Proxy< Op, op_htrans> > +struct Proxy< Op, op_htrans> > : public Proxy_subview_row_htrans_redirect < @@ -1820,8 +1903,6 @@ class Proxy< Op, op_htrans> > is_cx::yes >::result { - public: - typedef typename Proxy_subview_row_htrans_redirect @@ -1837,14 +1918,13 @@ class Proxy< Op, op_htrans> > typedef typename Proxy_sv_row_ht::ea_type ea_type; typedef typename Proxy_sv_row_ht::ea_type aligned_ea_type; - static const bool use_at = Proxy_sv_row_ht::use_at; - static const bool use_mp = Proxy_sv_row_ht::use_mp; - static const bool has_subview = Proxy_sv_row_ht::has_subview; - static const bool fake_mat = Proxy_sv_row_ht::fake_mat; + static constexpr bool use_at = Proxy_sv_row_ht::use_at; + static constexpr bool use_mp = Proxy_sv_row_ht::use_mp; + static constexpr bool has_subview = Proxy_sv_row_ht::has_subview; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; using Proxy_sv_row_ht::Q; @@ -1855,12 +1935,12 @@ class Proxy< Op, op_htrans> > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1871,30 +1951,27 @@ class Proxy< Op, op_htrans> > template arma_inline bool has_overlap(const subview& X) const { return Proxy_sv_row_ht::has_overlap(X); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< Op, op_strans> > +struct Proxy< Op, op_strans> > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_row_strans stored_type; typedef const subview_row_strans& ea_type; typedef const subview_row_strans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row_strans Q; @@ -1905,12 +1982,12 @@ class Proxy< Op, op_strans> > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1921,16 +1998,14 @@ class Proxy< Op, op_strans> > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< Op< Row< std::complex >, op_htrans> > +struct Proxy< Op< Row< std::complex >, op_htrans> > { - public: - typedef typename std::complex eT; typedef typename std::complex elem_type; @@ -1939,14 +2014,13 @@ class Proxy< Op< Row< std::complex >, op_htrans> > typedef const xvec_htrans& ea_type; typedef const xvec_htrans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; const xvec_htrans Q; const Row& src; @@ -1959,12 +2033,12 @@ class Proxy< Op< Row< std::complex >, op_htrans> > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -1975,16 +2049,14 @@ class Proxy< Op< Row< std::complex >, op_htrans> > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< Op< Col< std::complex >, op_htrans> > +struct Proxy< Op< Col< std::complex >, op_htrans> > { - public: - typedef typename std::complex eT; typedef typename std::complex elem_type; @@ -1993,14 +2065,13 @@ class Proxy< Op< Col< std::complex >, op_htrans> > typedef const xvec_htrans& ea_type; typedef const xvec_htrans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const xvec_htrans Q; const Col& src; @@ -2012,13 +2083,13 @@ class Proxy< Op< Col< std::complex >, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline uword get_n_rows() const { return 1; } + constexpr uword get_n_rows() const { return 1; } arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword, const uword col) const { return Q[col]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2029,16 +2100,14 @@ class Proxy< Op< Col< std::complex >, op_htrans> > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< Op< subview_col< std::complex >, op_htrans> > +struct Proxy< Op< subview_col< std::complex >, op_htrans> > { - public: - typedef typename std::complex eT; typedef typename std::complex elem_type; @@ -2047,14 +2116,13 @@ class Proxy< Op< subview_col< std::complex >, op_htrans> > typedef const xvec_htrans& ea_type; typedef const xvec_htrans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const xvec_htrans Q; const subview_col& src; @@ -2066,13 +2134,13 @@ class Proxy< Op< subview_col< std::complex >, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline uword get_n_rows() const { return 1; } + constexpr uword get_n_rows() const { return 1; } arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword, const uword col) const { return Q[col]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword, const uword c) const { return Q[c]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2083,31 +2151,28 @@ class Proxy< Op< subview_col< std::complex >, op_htrans> > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< Op > +struct Proxy< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef eOp< Op, eop_scalar_times> stored_type; typedef const eOp< Op, eop_scalar_times>& ea_type; typedef const eOp< Op, eop_scalar_times>& aligned_ea_type; - static const bool use_at = eOp< Op, eop_scalar_times>::use_at; - static const bool use_mp = eOp< Op, eop_scalar_times>::use_mp; - static const bool has_subview = eOp< Op, eop_scalar_times>::has_subview; - static const bool fake_mat = eOp< Op, eop_scalar_times>::fake_mat; + static constexpr bool use_at = eOp< Op, eop_scalar_times>::use_at; + static constexpr bool use_mp = eOp< Op, eop_scalar_times>::use_mp; + static constexpr bool has_subview = eOp< Op, eop_scalar_times>::has_subview; // NOTE: the Op class takes care of swapping row and col for op_htrans - static const bool is_row = eOp< Op, eop_scalar_times>::is_row; - static const bool is_col = eOp< Op, eop_scalar_times>::is_col; - static const bool is_xvec = eOp< Op, eop_scalar_times>::is_xvec; + static constexpr bool is_row = eOp< Op, eop_scalar_times>::is_row; + static constexpr bool is_col = eOp< Op, eop_scalar_times>::is_col; + static constexpr bool is_xvec = eOp< Op, eop_scalar_times>::is_xvec; arma_aligned const Op R; arma_aligned const eOp< Op, eop_scalar_times > Q; @@ -2123,9 +2188,9 @@ class Proxy< Op > arma_inline uword get_n_cols() const { return is_col ? 1 : Q.get_n_cols(); } arma_inline uword get_n_elem() const { return Q.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r, c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2142,24 +2207,21 @@ class Proxy< Op > template -class Proxy< subview_row_strans > +struct Proxy< subview_row_strans > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_row_strans stored_type; typedef const subview_row_strans& ea_type; typedef const subview_row_strans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row_strans& Q; @@ -2170,12 +2232,12 @@ class Proxy< subview_row_strans > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2186,30 +2248,27 @@ class Proxy< subview_row_strans > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< subview_row_htrans > +struct Proxy< subview_row_htrans > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_row_htrans stored_type; typedef const subview_row_htrans& ea_type; typedef const subview_row_htrans& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row_htrans& Q; @@ -2220,12 +2279,12 @@ class Proxy< subview_row_htrans > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2236,30 +2295,27 @@ class Proxy< subview_row_htrans > template arma_inline bool has_overlap(const subview& X) const { return is_alias(X.m); } - arma_inline bool is_aligned() const { return false; } + constexpr bool is_aligned() const { return false; } }; template -class Proxy< xtrans_mat > +struct Proxy< xtrans_mat > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const eT* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const Mat Q; @@ -2273,18 +2329,18 @@ class Proxy< xtrans_mat > arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row,col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -2292,24 +2348,21 @@ class Proxy< xtrans_mat > template -class Proxy< xvec_htrans > +struct Proxy< xvec_htrans > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const eT* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - static const bool fake_mat = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = true; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; arma_aligned const Mat Q; @@ -2323,18 +2376,18 @@ class Proxy< xvec_htrans > arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row,col); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c) const { return Q.at(r,c); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } template - arma_inline bool has_overlap(const subview&) const { return false; } + constexpr bool has_overlap(const subview&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -2342,7 +2395,7 @@ class Proxy< xvec_htrans > template -class Proxy_vectorise_col_mat +struct Proxy_vectorise_col_mat { inline Proxy_vectorise_col_mat(const T1&) {} }; @@ -2350,24 +2403,21 @@ class Proxy_vectorise_col_mat template -class Proxy_vectorise_col_mat< Op > +struct Proxy_vectorise_col_mat< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Mat stored_type; typedef const elem_type* ea_type; typedef const Mat& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = true; - static const bool fake_mat = true; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const unwrap U; arma_aligned const Mat Q; @@ -2380,12 +2430,12 @@ class Proxy_vectorise_col_mat< Op > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword) const { return Q[row]; } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword) const { return Q[r]; } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -2402,7 +2452,7 @@ class Proxy_vectorise_col_mat< Op > template -class Proxy_vectorise_col_expr +struct Proxy_vectorise_col_expr { inline Proxy_vectorise_col_expr(const T1&) {} }; @@ -2410,24 +2460,21 @@ class Proxy_vectorise_col_expr template -class Proxy_vectorise_col_expr< Op > +struct Proxy_vectorise_col_expr< Op > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Op stored_type; typedef typename Proxy::ea_type ea_type; typedef typename Proxy::aligned_ea_type aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = Proxy::use_mp; - static const bool has_subview = Proxy::has_subview; - static const bool fake_mat = Proxy::fake_mat; + static constexpr bool use_at = false; + static constexpr bool use_mp = Proxy::use_mp; + static constexpr bool has_subview = Proxy::has_subview; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const Op& Q; arma_aligned const Proxy R; @@ -2440,12 +2487,12 @@ class Proxy_vectorise_col_expr< Op > } arma_inline uword get_n_rows() const { return R.get_n_elem(); } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return R.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return R[i]; } - arma_inline elem_type at (const uword row, const uword) const { return R.at(row, 0); } - arma_inline elem_type at_alt (const uword i) const { return R.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return R[i]; } + arma_inline elem_type at (const uword r, const uword) const { return R.at(r, 0); } + arma_inline elem_type at_alt (const uword i) const { return R.at_alt(i); } arma_inline ea_type get_ea() const { return R.get_ea(); } arma_inline aligned_ea_type get_aligned_ea() const { return R.get_aligned_ea(); } @@ -2473,11 +2520,9 @@ struct Proxy_vectorise_col_redirect< Op, false> { typedef template -class Proxy< Op > +struct Proxy< Op > : public Proxy_vectorise_col_redirect< Op, (Proxy::use_at) >::result { - public: - typedef typename Proxy_vectorise_col_redirect< Op, (Proxy::use_at) >::result Proxy_vectorise_col; inline explicit Proxy(const Op& A) diff --git a/src/armadillo_bits/ProxyCube.hpp b/src/armadillo_bits/ProxyCube.hpp index 2bcd3c12..ef639284 100644 --- a/src/armadillo_bits/ProxyCube.hpp +++ b/src/armadillo_bits/ProxyCube.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,9 +22,8 @@ template -class ProxyCube +struct ProxyCube { - public: inline ProxyCube(const T1&) { arma_type_check(( is_arma_cube_type::value == false )); @@ -35,19 +36,17 @@ class ProxyCube // which can provide access to elements via operator[] template -class ProxyCube< Cube > +struct ProxyCube< Cube > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const eT* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube& Q; @@ -63,9 +62,9 @@ class ProxyCube< Cube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -73,25 +72,26 @@ class ProxyCube< Cube > template arma_inline bool is_alias(const Cube& X) const { return (void_ptr(&Q) == void_ptr(&X)); } + template + arma_inline bool has_overlap(const subview_cube& X) const { return is_alias(X.m); } + arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; template -class ProxyCube< GenCube > +struct ProxyCube< GenCube > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef GenCube stored_type; typedef const GenCube& ea_type; typedef const GenCube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const GenCube& Q; @@ -107,123 +107,36 @@ class ProxyCube< GenCube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_rows*Q.n_cols*Q.n_slices; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q[i]; } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q[i]; } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } - - arma_inline bool is_aligned() const { return GenCube::is_simple; } - }; - - - -template -class ProxyCube< GenCube > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - typedef Cube stored_type; - typedef const eT* ea_type; - typedef const Cube& aligned_ea_type; - - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - - arma_aligned const Cube Q; - - inline explicit ProxyCube(const GenCube& A) - : Q(A) - { - arma_extra_debug_sigprint(); - } - - arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return Q.n_cols; } - arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } - arma_inline uword get_n_slices() const { return Q.n_slices; } - arma_inline uword get_n_elem() const { return Q.n_elem; } - - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } - - arma_inline ea_type get_ea() const { return Q.memptr(); } - arma_inline aligned_ea_type get_aligned_ea() const { return Q; } + constexpr bool is_alias(const Cube&) const { return false; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool has_overlap(const subview_cube&) const { return false; } - arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } - }; - - - -template -class ProxyCube< GenCube > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - typedef Cube stored_type; - typedef const eT* ea_type; - typedef const Cube& aligned_ea_type; - - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; - - arma_aligned const Cube Q; - - inline explicit ProxyCube(const GenCube& A) - : Q(A) - { - arma_extra_debug_sigprint(); - } - - arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return Q.n_cols; } - arma_inline uword get_n_elem_slice() const { return Q.n_elem_slice; } - arma_inline uword get_n_slices() const { return Q.n_slices; } - arma_inline uword get_n_elem() const { return Q.n_elem; } - - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } - - arma_inline ea_type get_ea() const { return Q.memptr(); } - arma_inline aligned_ea_type get_aligned_ea() const { return Q; } - - template - arma_inline bool is_alias(const Cube&) const { return false; } - - arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } + constexpr bool is_aligned() const { return GenCube::is_simple; } }; template -class ProxyCube< OpCube > +struct ProxyCube< OpCube > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const elem_type* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube Q; @@ -239,15 +152,18 @@ class ProxyCube< OpCube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -255,19 +171,17 @@ class ProxyCube< OpCube > template -class ProxyCube< GlueCube > +struct ProxyCube< GlueCube > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const elem_type* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube Q; @@ -283,15 +197,18 @@ class ProxyCube< GlueCube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -299,19 +216,17 @@ class ProxyCube< GlueCube > template -class ProxyCube< subview_cube > +struct ProxyCube< subview_cube > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef subview_cube stored_type; typedef const subview_cube& ea_type; typedef const subview_cube& aligned_ea_type; - static const bool use_at = true; - static const bool use_mp = false; - static const bool has_subview = true; + static constexpr bool use_at = true; + static constexpr bool use_mp = false; + static constexpr bool has_subview = true; arma_aligned const subview_cube& Q; @@ -327,9 +242,9 @@ class ProxyCube< subview_cube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -337,25 +252,26 @@ class ProxyCube< subview_cube > template arma_inline bool is_alias(const Cube& X) const { return (void_ptr(&(Q.m)) == void_ptr(&X)); } - arma_inline bool is_aligned() const { return false; } + template + arma_inline bool has_overlap(const subview_cube& X) const { return Q.check_overlap(X); } + + constexpr bool is_aligned() const { return false; } }; template -class ProxyCube< subview_cube_slices > +struct ProxyCube< subview_cube_slices > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const eT* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube Q; @@ -371,15 +287,18 @@ class ProxyCube< subview_cube_slices > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -387,19 +306,17 @@ class ProxyCube< subview_cube_slices > template -class ProxyCube< eOpCube > +struct ProxyCube< eOpCube > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef eOpCube stored_type; typedef const eOpCube& ea_type; typedef const eOpCube& aligned_ea_type; - static const bool use_at = eOpCube::use_at; - static const bool use_mp = eOpCube::use_mp; - static const bool has_subview = eOpCube::has_subview; + static constexpr bool use_at = eOpCube::use_at; + static constexpr bool use_mp = eOpCube::use_mp; + static constexpr bool has_subview = eOpCube::has_subview; arma_aligned const eOpCube& Q; @@ -415,9 +332,9 @@ class ProxyCube< eOpCube > arma_inline uword get_n_slices() const { return Q.get_n_slices(); } arma_inline uword get_n_elem() const { return Q.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -425,25 +342,26 @@ class ProxyCube< eOpCube > template arma_inline bool is_alias(const Cube& X) const { return Q.P.is_alias(X); } + template + arma_inline bool has_overlap(const subview_cube& X) const { return Q.P.has_overlap(X); } + arma_inline bool is_aligned() const { return Q.P.is_aligned(); } }; template -class ProxyCube< eGlueCube > +struct ProxyCube< eGlueCube > { - public: - typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; typedef eGlueCube stored_type; typedef const eGlueCube& ea_type; typedef const eGlueCube& aligned_ea_type; - static const bool use_at = eGlueCube::use_at; - static const bool use_mp = eGlueCube::use_mp; - static const bool has_subview = eGlueCube::has_subview; + static constexpr bool use_at = eGlueCube::use_at; + static constexpr bool use_mp = eGlueCube::use_mp; + static constexpr bool has_subview = eGlueCube::has_subview; arma_aligned const eGlueCube& Q; @@ -459,9 +377,9 @@ class ProxyCube< eGlueCube > arma_inline uword get_n_slices() const { return Q.get_n_slices(); } arma_inline uword get_n_elem() const { return Q.get_n_elem(); } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q; } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } @@ -469,25 +387,26 @@ class ProxyCube< eGlueCube > template arma_inline bool is_alias(const Cube& X) const { return (Q.P1.is_alias(X) || Q.P2.is_alias(X)); } + template + arma_inline bool has_overlap(const subview_cube& X) const { return (Q.P1.has_overlap(X) || Q.P2.has_overlap(X)); } + arma_inline bool is_aligned() const { return Q.P1.is_aligned() && Q.P2.is_aligned(); } }; template -class ProxyCube< mtOpCube > +struct ProxyCube< mtOpCube > { - public: - typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const elem_type* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube Q; @@ -503,15 +422,18 @@ class ProxyCube< mtOpCube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; @@ -519,19 +441,17 @@ class ProxyCube< mtOpCube > template -class ProxyCube< mtGlueCube > +struct ProxyCube< mtGlueCube > { - public: - typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef Cube stored_type; typedef const elem_type* ea_type; typedef const Cube& aligned_ea_type; - static const bool use_at = false; - static const bool use_mp = false; - static const bool has_subview = false; + static constexpr bool use_at = false; + static constexpr bool use_mp = false; + static constexpr bool has_subview = false; arma_aligned const Cube Q; @@ -547,15 +467,18 @@ class ProxyCube< mtGlueCube > arma_inline uword get_n_slices() const { return Q.n_slices; } arma_inline uword get_n_elem() const { return Q.n_elem; } - arma_inline elem_type operator[] (const uword i) const { return Q[i]; } - arma_inline elem_type at (const uword row, const uword col, const uword slice) const { return Q.at(row, col, slice); } - arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } + arma_inline elem_type operator[] (const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword r, const uword c, const uword s) const { return Q.at(r, c, s); } + arma_inline elem_type at_alt (const uword i) const { return Q.at_alt(i); } arma_inline ea_type get_ea() const { return Q.memptr(); } arma_inline aligned_ea_type get_aligned_ea() const { return Q; } template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } + + template + constexpr bool has_overlap(const subview_cube&) const { return false; } arma_inline bool is_aligned() const { return memory::is_aligned(Q.memptr()); } }; diff --git a/src/armadillo_bits/Row_bones.hpp b/src/armadillo_bits/Row_bones.hpp index fcba5662..96dff225 100644 --- a/src/armadillo_bits/Row_bones.hpp +++ b/src/armadillo_bits/Row_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,20 +29,29 @@ class Row : public Mat typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_col = false; - static const bool is_row = true; - static const bool is_xvec = false; + static constexpr bool is_col = false; + static constexpr bool is_row = true; + static constexpr bool is_xvec = false; + + inline Row(); + inline Row(const Row& X); - inline Row(); - inline Row(const Row& X); inline explicit Row(const uword N); inline explicit Row(const uword in_rows, const uword in_cols); inline explicit Row(const SizeMat& s); + template inline explicit Row(const uword N, const arma_initmode_indicator&); + template inline explicit Row(const uword in_rows, const uword in_cols, const arma_initmode_indicator&); + template inline explicit Row(const SizeMat& s, const arma_initmode_indicator&); + template inline Row(const uword n_elem, const fill::fill_class& f); template inline Row(const uword in_rows, const uword in_cols, const fill::fill_class& f); template inline Row(const SizeMat& s, const fill::fill_class& f); + inline Row(const uword N, const fill::scalar_holder f); + inline Row(const uword in_rows, const uword in_cols, const fill::scalar_holder f); + inline Row(const SizeMat& s, const fill::scalar_holder f); + inline Row(const char* text); inline Row& operator=(const char* text); @@ -50,13 +61,14 @@ class Row : public Mat inline Row(const std::vector& x); inline Row& operator=(const std::vector& x); - #if defined(ARMA_USE_CXX11) inline Row(const std::initializer_list& list); inline Row& operator=(const std::initializer_list& list); inline Row(Row&& m); inline Row& operator=(Row&& m); - #endif + + // inline Row(Mat&& m); + // inline Row& operator=(Mat&& m); inline Row& operator=(const eT val); inline Row& operator=(const Row& X); @@ -79,13 +91,13 @@ class Row : public Mat inline Row(const subview_cube& X); inline Row& operator=(const subview_cube& X); - inline mat_injector operator<<(const eT val); + arma_frown("use braced initialiser list instead") inline mat_injector operator<<(const eT val); - arma_inline const Op,op_htrans> t() const; - arma_inline const Op,op_htrans> ht() const; - arma_inline const Op,op_strans> st() const; + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; - arma_inline const Op,op_strans> as_col() const; + arma_warn_unused arma_inline const Op,op_strans> as_col() const; arma_inline subview_row col(const uword col_num); arma_inline const subview_row col(const uword col_num) const; @@ -129,15 +141,17 @@ class Row : public Mat template inline void shed_cols(const Base& indices); - inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero = true); + arma_deprecated inline void insert_cols(const uword col_num, const uword N, const bool set_to_zero); + inline void insert_cols(const uword col_num, const uword N); + template inline void insert_cols(const uword col_num, const Base& X); - arma_inline arma_warn_unused eT& at(const uword i); - arma_inline arma_warn_unused const eT& at(const uword i) const; + arma_warn_unused arma_inline eT& at(const uword i); + arma_warn_unused arma_inline const eT& at(const uword i) const; - arma_inline arma_warn_unused eT& at(const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at(const uword in_row, const uword in_col) const; typedef eT* row_iterator; @@ -160,7 +174,7 @@ class Row : public Mat public: - #ifdef ARMA_EXTRA_ROW_PROTO + #if defined(ARMA_EXTRA_ROW_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_ROW_PROTO) #endif }; @@ -173,7 +187,7 @@ class Row::fixed : public Row { private: - static const bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); + static constexpr bool use_extra = (fixed_n_elem > arma_config::mat_prealloc); arma_align_mem eT mem_local_extra[ (use_extra) ? fixed_n_elem : 1 ]; @@ -185,9 +199,9 @@ class Row::fixed : public Row typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_col = false; - static const bool is_row = true; - static const bool is_xvec = false; + static constexpr bool is_col = false; + static constexpr bool is_row = true; + static constexpr bool is_xvec = false; static const uword n_rows; // value provided below the class definition static const uword n_cols; // value provided below the class definition @@ -197,6 +211,7 @@ class Row::fixed : public Row arma_inline fixed(const fixed& X); inline fixed(const subview_cube& X); + inline fixed(const fill::scalar_holder f); template inline fixed(const fill::fill_class& f); template inline fixed(const Base& A); template inline fixed(const Base& A, const Base& B); @@ -215,10 +230,8 @@ class Row::fixed : public Row using Row::operator(); - #if defined(ARMA_USE_CXX11) - inline fixed(const std::initializer_list& list); - inline Row& operator=(const std::initializer_list& list); - #endif + inline fixed(const std::initializer_list& list); + inline Row& operator=(const std::initializer_list& list); arma_inline Row& operator=(const fixed& X); @@ -227,30 +240,30 @@ class Row::fixed : public Row template inline Row& operator=(const eGlue& X); #endif - arma_inline const Op< Row_fixed_type, op_htrans > t() const; - arma_inline const Op< Row_fixed_type, op_htrans > ht() const; - arma_inline const Op< Row_fixed_type, op_strans > st() const; + arma_warn_unused arma_inline const Op< Row_fixed_type, op_htrans > t() const; + arma_warn_unused arma_inline const Op< Row_fixed_type, op_htrans > ht() const; + arma_warn_unused arma_inline const Op< Row_fixed_type, op_strans > st() const; - arma_inline arma_warn_unused const eT& at_alt (const uword i) const; + arma_warn_unused arma_inline const eT& at_alt (const uword i) const; - arma_inline arma_warn_unused eT& operator[] (const uword i); - arma_inline arma_warn_unused const eT& operator[] (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword i); - arma_inline arma_warn_unused const eT& at (const uword i) const; - arma_inline arma_warn_unused eT& operator() (const uword i); - arma_inline arma_warn_unused const eT& operator() (const uword i) const; + arma_warn_unused arma_inline eT& operator[] (const uword i); + arma_warn_unused arma_inline const eT& operator[] (const uword i) const; + arma_warn_unused arma_inline eT& at (const uword i); + arma_warn_unused arma_inline const eT& at (const uword i) const; + arma_warn_unused arma_inline eT& operator() (const uword i); + arma_warn_unused arma_inline const eT& operator() (const uword i) const; - arma_inline arma_warn_unused eT& at (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& at (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT& operator() (const uword in_row, const uword in_col); - arma_inline arma_warn_unused const eT& operator() (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& at (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline eT& operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline const eT& operator() (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused eT* memptr(); - arma_inline arma_warn_unused const eT* memptr() const; + arma_warn_unused arma_inline eT* memptr(); + arma_warn_unused arma_inline const eT* memptr() const; - arma_hot inline const Row& fill(const eT val); - arma_hot inline const Row& zeros(); - arma_hot inline const Row& ones(); + inline const Row& fill(const eT val); + inline const Row& zeros(); + inline const Row& ones(); }; diff --git a/src/armadillo_bits/Row_meat.hpp b/src/armadillo_bits/Row_meat.hpp index a0310a8d..f61a1792 100644 --- a/src/armadillo_bits/Row_meat.hpp +++ b/src/armadillo_bits/Row_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -48,6 +50,12 @@ Row::Row(const uword in_n_elem) : Mat(arma_vec_indicator(), 1, in_n_elem, 2) { arma_extra_debug_sigprint(); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -60,6 +68,12 @@ Row::Row(const uword in_n_rows, const uword in_n_cols) arma_extra_debug_sigprint(); Mat::init_warm(in_n_rows, in_n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -72,6 +86,70 @@ Row::Row(const SizeMat& s) arma_extra_debug_sigprint(); Mat::init_warm(s.n_rows, s.n_cols); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const uword in_n_elem, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } + } + + + +//! internal use only +template +template +inline +Row::Row(const SizeMat& s, const arma_initmode_indicator&) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + if(do_zeros) + { + arma_extra_debug_print("Row::constructor: zeroing memory"); + arrayops::fill_zeros(Mat::memptr(), Mat::n_elem); + } } @@ -119,6 +197,46 @@ Row::Row(const SizeMat& s, const fill::fill_class& f) +template +inline +Row::Row(const uword in_n_elem, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 1, in_n_elem, 2) + { + arma_extra_debug_sigprint(); + + (*this).fill(f.scalar); + } + + + +template +inline +Row::Row(const uword in_n_rows, const uword in_n_cols, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(in_n_rows, in_n_cols); + + (*this).fill(f.scalar); + } + + + +template +inline +Row::Row(const SizeMat& s, const fill::scalar_holder f) + : Mat(arma_vec_indicator(), 0, 0, 2) + { + arma_extra_debug_sigprint(); + + Mat::init_warm(s.n_rows, s.n_cols); + + (*this).fill(f.scalar); + } + + + template inline Row::Row(const char* text) @@ -193,10 +311,9 @@ Row::Row(const std::vector& x) { arma_extra_debug_sigprint_this(this); - if(x.size() > 0) - { - arrayops::copy( Mat::memptr(), &(x[0]), uword(x.size()) ); - } + const uword N = uword(x.size()); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } } @@ -209,114 +326,160 @@ Row::operator=(const std::vector& x) { arma_extra_debug_sigprint(); - Mat::init_warm(1, uword(x.size())); + const uword N = uword(x.size()); - if(x.size() > 0) - { - arrayops::copy( Mat::memptr(), &(x[0]), uword(x.size()) ); - } + Mat::init_warm(1, N); + + if(N > 0) { arrayops::copy( Mat::memptr(), &(x[0]), N ); } return *this; } -#if defined(ARMA_USE_CXX11) +template +inline +Row::Row(const std::initializer_list& list) + : Mat(arma_vec_indicator(), 1, uword(list.size()), 2) + { + arma_extra_debug_sigprint_this(this); - template - inline - Row::Row(const std::initializer_list& list) - : Mat(arma_vec_indicator(), 2) - { - arma_extra_debug_sigprint(); - - (*this).operator=(list); - } + const uword N = uword(list.size()); + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + } + + + +template +inline +Row& +Row::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + const uword N = uword(list.size()); - template - inline - Row& - Row::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - Mat tmp(list); - - arma_debug_check( ((tmp.n_elem > 0) && (tmp.is_vec() == false)), "Mat::init(): requested size is not compatible with row vector layout" ); - - access::rw(tmp.n_rows) = 1; - access::rw(tmp.n_cols) = tmp.n_elem; - - (*this).steal_mem(tmp); - - return *this; - } + Mat::init_warm(1, N); + if(N > 0) { arrayops::copy( Mat::memptr(), list.begin(), N ); } + return *this; + } + + + +template +inline +Row::Row(Row&& X) + : Mat(arma_vec_indicator(), 2) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - template - inline - Row::Row(Row&& X) - : Mat(arma_vec_indicator(), 2) + access::rw(Mat::n_rows) = 1; + access::rw(Mat::n_cols) = X.n_cols; + access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::n_alloc) = X.n_alloc; + + if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - - access::rw(Mat::n_rows) = 1; - access::rw(Mat::n_cols) = X.n_cols; - access::rw(Mat::n_elem) = X.n_elem; + access::rw(Mat::mem_state) = X.mem_state; + access::rw(Mat::mem) = X.mem; - if( ((X.mem_state == 0) && (X.n_elem > arma_config::mat_prealloc)) || (X.mem_state == 1) || (X.mem_state == 2) ) - { - access::rw(Mat::mem_state) = X.mem_state; - access::rw(Mat::mem) = X.mem; - - access::rw(X.n_rows) = 1; - access::rw(X.n_cols) = 0; - access::rw(X.n_elem) = 0; - access::rw(X.mem_state) = 0; - access::rw(X.mem) = 0; - } - else - { - (*this).init_cold(); - - arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); - - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) ) - { - access::rw(X.n_rows) = 1; - access::rw(X.n_cols) = 0; - access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; - } - } + access::rw(X.n_rows) = 1; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.n_alloc) = 0; + access::rw(X.mem_state) = 0; + access::rw(X.mem) = nullptr; } - - - - template - inline - Row& - Row::operator=(Row&& X) + else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + (*this).init_cold(); - (*this).steal_mem(X); + arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); - if( (X.mem_state == 0) && (X.n_elem <= arma_config::mat_prealloc) && (this != &X) ) + if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) { - access::rw(X.n_rows) = 1; - access::rw(X.n_cols) = 0; - access::rw(X.n_elem) = 0; - access::rw(X.mem) = 0; + access::rw(X.n_rows) = 1; + access::rw(X.n_cols) = 0; + access::rw(X.n_elem) = 0; + access::rw(X.mem) = nullptr; } - - return *this; } + } + + + +template +inline +Row& +Row::operator=(Row&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); -#endif + (*this).steal_mem(X, true); + + return *this; + } + + + +// template +// inline +// Row::Row(Mat&& X) +// : Mat(arma_vec_indicator(), 2) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_rows != 1) { const Mat& XX = X; Mat::operator=(XX); return; } +// +// access::rw(Mat::n_rows) = 1; +// access::rw(Mat::n_cols) = X.n_cols; +// access::rw(Mat::n_elem) = X.n_elem; +// access::rw(Mat::n_alloc) = X.n_alloc; +// +// if( (X.n_alloc > arma_config::mat_prealloc) || (X.mem_state == 1) || (X.mem_state == 2) ) +// { +// access::rw(Mat::mem_state) = X.mem_state; +// access::rw(Mat::mem) = X.mem; +// +// access::rw(X.n_cols) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.n_alloc) = 0; +// access::rw(X.mem_state) = 0; +// access::rw(X.mem) = nullptr; +// } +// else // condition: (X.n_alloc <= arma_config::mat_prealloc) || (X.mem_state == 0) || (X.mem_state == 3) +// { +// (*this).init_cold(); +// +// arrayops::copy( (*this).memptr(), X.mem, X.n_elem ); +// +// if( (X.mem_state == 0) && (X.n_alloc <= arma_config::mat_prealloc) ) +// { +// access::rw(X.n_cols) = 0; +// access::rw(X.n_elem) = 0; +// access::rw(X.mem) = nullptr; +// } +// } +// } +// +// +// +// template +// inline +// Row& +// Row::operator=(Mat&& X) +// { +// arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); +// +// if(X.n_rows != 1) { const Mat& XX = X; Mat::operator=(XX); return *this; } +// +// (*this).steal_mem(X, true); +// +// return *this; +// } @@ -561,7 +724,7 @@ Row::col(const uword in_col1) { arma_extra_debug_sigprint(); - arma_debug_check( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used" ); return subview_row(*this, 0, in_col1, 1); } @@ -575,7 +738,7 @@ Row::col(const uword in_col1) const { arma_extra_debug_sigprint(); - arma_debug_check( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( (in_col1 >= Mat::n_cols), "Row::col(): indices out of bounds or incorrectly used" ); return subview_row(*this, 0, in_col1, 1); } @@ -589,7 +752,7 @@ Row::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -605,7 +768,7 @@ Row::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::cols(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -621,7 +784,7 @@ Row::subvec(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -637,7 +800,7 @@ Row::subvec(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols) ), "Row::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -685,7 +848,7 @@ Row::subvec(const span& col_span) const uword in_col2 = col_span.b; const uword subvec_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used" ); return subview_row(*this, 0, in_col1, subvec_n_cols); } @@ -707,7 +870,7 @@ Row::subvec(const span& col_span) const const uword in_col2 = col_span.b; const uword subvec_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= local_n_cols)) ), "Row::subvec(): indices out of bounds or incorrectly used" ); return subview_row(*this, 0, in_col1, subvec_n_cols); } @@ -747,7 +910,7 @@ Row::subvec(const uword start_col, const SizeMat& s) arma_debug_check( (s.n_rows != 1), "Row::subvec(): given size does not specify a row vector" ); - arma_debug_check( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); return subview_row(*this, 0, start_col, s.n_cols); } @@ -763,7 +926,7 @@ Row::subvec(const uword start_col, const SizeMat& s) const arma_debug_check( (s.n_rows != 1), "Row::subvec(): given size does not specify a row vector" ); - arma_debug_check( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_col >= Mat::n_cols) || ((start_col + s.n_cols) > Mat::n_cols) ), "Row::subvec(): size out of bounds" ); return subview_row(*this, 0, start_col, s.n_cols); } @@ -777,7 +940,7 @@ Row::head(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_cols), "Row::head(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_cols), "Row::head(): size out of bounds" ); return subview_row(*this, 0, 0, N); } @@ -791,7 +954,7 @@ Row::head(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_cols), "Row::head(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_cols), "Row::head(): size out of bounds" ); return subview_row(*this, 0, 0, N); } @@ -805,7 +968,7 @@ Row::tail(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_cols), "Row::tail(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_cols), "Row::tail(): size out of bounds" ); const uword start_col = Mat::n_cols - N; @@ -821,7 +984,7 @@ Row::tail(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > Mat::n_cols), "Row::tail(): size out of bounds"); + arma_debug_check_bounds( (N > Mat::n_cols), "Row::tail(): size out of bounds" ); const uword start_col = Mat::n_cols - N; @@ -886,7 +1049,7 @@ Row::shed_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= Mat::n_cols, "Row::shed_col(): index out of bounds"); + arma_debug_check_bounds( col_num >= Mat::n_cols, "Row::shed_col(): index out of bounds" ); shed_cols(col_num, col_num); } @@ -901,7 +1064,7 @@ Row::shed_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= Mat::n_cols), "Row::shed_cols(): indices out of bounds or incorrectly used" @@ -910,7 +1073,7 @@ Row::shed_cols(const uword in_col1, const uword in_col2) const uword n_keep_front = in_col1; const uword n_keep_back = Mat::n_cols - (in_col2 + 1); - Row X(n_keep_front + n_keep_back); + Row X(n_keep_front + n_keep_back, arma_nozeros_indicator()); eT* X_mem = X.memptr(); const eT* t_mem = (*this).memptr(); @@ -944,8 +1107,6 @@ Row::shed_cols(const Base& indices) -//! insert N cols at the specified col position, -//! optionally setting the elements of the inserted cols to zero template inline void @@ -953,38 +1114,48 @@ Row::insert_cols(const uword col_num, const uword N, const bool set_to_zero) { arma_extra_debug_sigprint(); + arma_ignore(set_to_zero); + + (*this).insert_cols(col_num, N); + } + + + +template +inline +void +Row::insert_cols(const uword col_num, const uword N) + { + arma_extra_debug_sigprint(); + const uword t_n_cols = Mat::n_cols; const uword A_n_cols = col_num; const uword B_n_cols = t_n_cols - col_num; // insertion at col_num == n_cols is in effect an append operation - arma_debug_check( (col_num > t_n_cols), "Row::insert_cols(): index out of bounds"); + arma_debug_check_bounds( (col_num > t_n_cols), "Row::insert_cols(): index out of bounds" ); + + if(N == 0) { return; } - if(N > 0) + Row out(t_n_cols + N, arma_nozeros_indicator()); + + eT* out_mem = out.memptr(); + const eT* t_mem = (*this).memptr(); + + if(A_n_cols > 0) { - Row out(t_n_cols + N); - - eT* out_mem = out.memptr(); - const eT* t_mem = (*this).memptr(); - - if(A_n_cols > 0) - { - arrayops::copy( out_mem, t_mem, A_n_cols ); - } - - if(B_n_cols > 0) - { - arrayops::copy( &(out_mem[col_num + N]), &(t_mem[col_num]), B_n_cols ); - } - - if(set_to_zero) - { - arrayops::inplace_set( &(out_mem[col_num]), eT(0), N ); - } - - Mat::steal_mem(out); + arrayops::copy( out_mem, t_mem, A_n_cols ); + } + + if(B_n_cols > 0) + { + arrayops::copy( &(out_mem[col_num + N]), &(t_mem[col_num]), B_n_cols ); } + + arrayops::fill_zeros( &(out_mem[col_num]), N ); + + Mat::steal_mem(out); } @@ -1006,7 +1177,6 @@ Row::insert_cols(const uword col_num, const Base& X) template arma_inline -arma_warn_unused eT& Row::at(const uword i) { @@ -1017,7 +1187,6 @@ Row::at(const uword i) template arma_inline -arma_warn_unused const eT& Row::at(const uword i) const { @@ -1028,7 +1197,6 @@ Row::at(const uword i) const template arma_inline -arma_warn_unused eT& Row::at(const uword, const uword in_col) { @@ -1039,7 +1207,6 @@ Row::at(const uword, const uword in_col) template arma_inline -arma_warn_unused const eT& Row::at(const uword, const uword in_col) const { @@ -1055,7 +1222,7 @@ Row::begin_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds" ); return Mat::memptr(); } @@ -1069,7 +1236,7 @@ Row::begin_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::begin_row(): index out of bounds" ); return Mat::memptr(); } @@ -1083,7 +1250,7 @@ Row::end_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds" ); return Mat::memptr() + Mat::n_cols; } @@ -1097,7 +1264,7 @@ Row::end_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= Mat::n_rows), "Row::end_row(): index out of bounds" ); return Mat::memptr() + Mat::n_cols; } @@ -1111,6 +1278,15 @@ Row::fixed::fixed() : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) { arma_extra_debug_sigprint_this(this); + + if(arma_config::zero_init) + { + arma_extra_debug_print("Row::fixed::constructor: zeroing memory"); + + eT* mem_use = (use_extra) ? &(mem_local_extra[0]) : &(Mat::mem_local[0]); + + arrayops::inplace_set_fixed( mem_use, eT(0) ); + } } @@ -1144,6 +1320,19 @@ Row::fixed::fixed(const subview_cube& X) +template +template +inline +Row::fixed::fixed(const fill::scalar_holder f) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); + + (*this).fill(f.scalar); + } + + + template template template @@ -1153,11 +1342,11 @@ Row::fixed::fixed(const fill::fill_class&) { arma_extra_debug_sigprint_this(this); - if(is_same_type::yes) (*this).zeros(); - if(is_same_type::yes) (*this).ones(); - if(is_same_type::yes) (*this).eye(); - if(is_same_type::yes) (*this).randu(); - if(is_same_type::yes) (*this).randn(); + if(is_same_type::yes) { (*this).zeros(); } + if(is_same_type::yes) { (*this).ones(); } + if(is_same_type::yes) { (*this).eye(); } + if(is_same_type::yes) { (*this).randu(); } + if(is_same_type::yes) { (*this).randn(); } } @@ -1302,43 +1491,39 @@ Row::fixed::operator=(const subview_cube& X) -#if defined(ARMA_USE_CXX11) +template +template +inline +Row::fixed::fixed(const std::initializer_list& list) + : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) + { + arma_extra_debug_sigprint_this(this); - template - template - inline - Row::fixed::fixed(const std::initializer_list& list) - : Row( arma_fixed_indicator(), fixed_n_elem, ((use_extra) ? mem_local_extra : Mat::mem_local) ) - { - arma_extra_debug_sigprint_this(this); - - (*this).operator=(list); - } + (*this).operator=(list); + } + + + +template +template +inline +Row& +Row::fixed::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + const uword N = uword(list.size()); + arma_debug_check( (N > fixed_n_elem), "Row::fixed: initialiser list is too long" ); - template - template - inline - Row& - Row::fixed::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - const uword N = uword(list.size()); - - arma_debug_check( (N > fixed_n_elem), "Row::fixed: initialiser list is too long" ); - - eT* this_mem = (*this).memptr(); - - arrayops::copy( this_mem, list.begin(), N ); - - for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } - - return *this; - } + eT* this_mem = (*this).memptr(); -#endif + arrayops::copy( this_mem, list.begin(), N ); + + for(uword iq=N; iq < fixed_n_elem; ++iq) { this_mem[iq] = eT(0); } + + return *this; + } @@ -1475,7 +1660,6 @@ Row::fixed::st() const template template arma_inline -arma_warn_unused const eT& Row::fixed::at_alt(const uword ii) const { @@ -1497,7 +1681,6 @@ Row::fixed::at_alt(const uword ii) const template template arma_inline -arma_warn_unused eT& Row::fixed::operator[] (const uword ii) { @@ -1509,7 +1692,6 @@ Row::fixed::operator[] (const uword ii) template template arma_inline -arma_warn_unused const eT& Row::fixed::operator[] (const uword ii) const { @@ -1521,7 +1703,6 @@ Row::fixed::operator[] (const uword ii) const template template arma_inline -arma_warn_unused eT& Row::fixed::at(const uword ii) { @@ -1533,7 +1714,6 @@ Row::fixed::at(const uword ii) template template arma_inline -arma_warn_unused const eT& Row::fixed::at(const uword ii) const { @@ -1545,11 +1725,10 @@ Row::fixed::at(const uword ii) const template template arma_inline -arma_warn_unused eT& Row::fixed::operator() (const uword ii) { - arma_debug_check( (ii >= fixed_n_elem), "Row::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Row::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; } @@ -1559,11 +1738,10 @@ Row::fixed::operator() (const uword ii) template template arma_inline -arma_warn_unused const eT& Row::fixed::operator() (const uword ii) const { - arma_debug_check( (ii >= fixed_n_elem), "Row::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= fixed_n_elem), "Row::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[ii] : Mat::mem_local[ii]; } @@ -1573,7 +1751,6 @@ Row::fixed::operator() (const uword ii) const template template arma_inline -arma_warn_unused eT& Row::fixed::at(const uword, const uword in_col) { @@ -1585,7 +1762,6 @@ Row::fixed::at(const uword, const uword in_col) template template arma_inline -arma_warn_unused const eT& Row::fixed::at(const uword, const uword in_col) const { @@ -1597,11 +1773,10 @@ Row::fixed::at(const uword, const uword in_col) const template template arma_inline -arma_warn_unused eT& Row::fixed::operator() (const uword in_row, const uword in_col) { - arma_debug_check( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; } @@ -1611,11 +1786,10 @@ Row::fixed::operator() (const uword in_row, const uword in_col template template arma_inline -arma_warn_unused const eT& Row::fixed::operator() (const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row > 0) || (in_col >= fixed_n_elem)), "Row::operator(): index out of bounds" ); return (use_extra) ? mem_local_extra[in_col] : Mat::mem_local[in_col]; } @@ -1625,7 +1799,6 @@ Row::fixed::operator() (const uword in_row, const uword in_col template template arma_inline -arma_warn_unused eT* Row::fixed::memptr() { @@ -1637,7 +1810,6 @@ Row::fixed::memptr() template template arma_inline -arma_warn_unused const eT* Row::fixed::memptr() const { @@ -1648,7 +1820,6 @@ Row::fixed::memptr() const template template -arma_hot inline const Row& Row::fixed::fill(const eT val) @@ -1666,7 +1837,6 @@ Row::fixed::fill(const eT val) template template -arma_hot inline const Row& Row::fixed::zeros() @@ -1684,7 +1854,6 @@ Row::fixed::zeros() template template -arma_hot inline const Row& Row::fixed::ones() @@ -1710,7 +1879,7 @@ Row::Row(const arma_fixed_indicator&, const uword in_n_elem, const eT* in_me -#ifdef ARMA_EXTRA_ROW_MEAT +#if defined(ARMA_EXTRA_ROW_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_ROW_MEAT) #endif diff --git a/src/armadillo_bits/SizeCube_bones.hpp b/src/armadillo_bits/SizeCube_bones.hpp index 1fb6219a..96b26af6 100644 --- a/src/armadillo_bits/SizeCube_bones.hpp +++ b/src/armadillo_bits/SizeCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/SizeCube_meat.hpp b/src/armadillo_bits/SizeCube_meat.hpp index 7cc88fb8..8354ca12 100644 --- a/src/armadillo_bits/SizeCube_meat.hpp +++ b/src/armadillo_bits/SizeCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -51,7 +53,7 @@ SizeCube::operator()(const uword dim) const if(dim == 1) { return n_cols; } if(dim == 2) { return n_slices; } - arma_debug_check(true, "size(): index out of bounds"); + arma_debug_check_bounds(true, "size(): index out of bounds"); return uword(1); } diff --git a/src/armadillo_bits/SizeMat_bones.hpp b/src/armadillo_bits/SizeMat_bones.hpp index 827c6278..6139d336 100644 --- a/src/armadillo_bits/SizeMat_bones.hpp +++ b/src/armadillo_bits/SizeMat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/SizeMat_meat.hpp b/src/armadillo_bits/SizeMat_meat.hpp index dcd8f59c..e00fd4f0 100644 --- a/src/armadillo_bits/SizeMat_meat.hpp +++ b/src/armadillo_bits/SizeMat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -48,7 +50,7 @@ SizeMat::operator()(const uword dim) const if(dim == 0) { return n_rows; } if(dim == 1) { return n_cols; } - arma_debug_check(true, "size(): index out of bounds"); + arma_debug_check_bounds(true, "size(): index out of bounds"); return uword(1); } diff --git a/src/armadillo_bits/SpBase_bones.hpp b/src/armadillo_bits/SpBase_bones.hpp index 34b7068e..d16bf476 100644 --- a/src/armadillo_bits/SpBase_bones.hpp +++ b/src/armadillo_bits/SpBase_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,14 +24,14 @@ template struct SpBase_eval_SpMat { - inline const derived& eval() const; + arma_warn_unused inline const derived& eval() const; }; template struct SpBase_eval_expr { - inline SpMat eval() const; //!< force the immediate evaluation of a delayed expression + arma_warn_unused inline SpMat eval() const; //!< force the immediate evaluation of a delayed expression }; @@ -52,9 +54,9 @@ struct SpBase arma_inline bool is_alias(const SpMat& X) const; - inline const SpOp t() const; //!< Hermitian transpose - inline const SpOp ht() const; //!< Hermitian transpose - inline const SpOp st() const; //!< simple transpose + arma_warn_unused inline const SpOp t() const; //!< Hermitian transpose + arma_warn_unused inline const SpOp ht() const; //!< Hermitian transpose + arma_warn_unused inline const SpOp st() const; //!< simple transpose arma_cold inline void print( const std::string extra_text = "") const; arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; @@ -68,8 +70,11 @@ struct SpBase arma_cold inline void raw_print_dense( const std::string extra_text = "") const; arma_cold inline void raw_print_dense(std::ostream& user_stream, const std::string extra_text = "") const; - inline arma_warn_unused elem_type min() const; - inline arma_warn_unused elem_type max() const; + arma_cold inline void brief_print( const std::string extra_text = "") const; + arma_cold inline void brief_print(std::ostream& user_stream, const std::string extra_text = "") const; + + arma_warn_unused inline elem_type min() const; + arma_warn_unused inline elem_type max() const; inline elem_type min(uword& index_of_min_val) const; inline elem_type max(uword& index_of_max_val) const; @@ -77,30 +82,33 @@ struct SpBase inline elem_type min(uword& row_of_min_val, uword& col_of_min_val) const; inline elem_type max(uword& row_of_max_val, uword& col_of_max_val) const; - inline arma_warn_unused uword index_min() const; - inline arma_warn_unused uword index_max() const; + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; - inline arma_warn_unused bool is_symmetric() const; - inline arma_warn_unused bool is_symmetric(const typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; - inline arma_warn_unused bool is_hermitian() const; - inline arma_warn_unused bool is_hermitian(const typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_zero(const typename get_pod_type::result tol = 0) const; - inline arma_warn_unused bool is_trimatu() const; - inline arma_warn_unused bool is_trimatl() const; - inline arma_warn_unused bool is_diagmat() const; - inline arma_warn_unused bool is_empty() const; - inline arma_warn_unused bool is_square() const; - inline arma_warn_unused bool is_vec() const; - inline arma_warn_unused bool is_colvec() const; - inline arma_warn_unused bool is_rowvec() const; - inline arma_warn_unused bool is_finite() const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline bool is_trimatu() const; + arma_warn_unused inline bool is_trimatl() const; + arma_warn_unused inline bool is_diagmat() const; + arma_warn_unused inline bool is_empty() const; + arma_warn_unused inline bool is_square() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_colvec() const; + arma_warn_unused inline bool is_rowvec() const; + arma_warn_unused inline bool is_finite() const; - inline const SpOp as_col() const; - inline const SpOp as_row() const; + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; + arma_warn_unused inline const SpOp as_col() const; + arma_warn_unused inline const SpOp as_row() const; }; diff --git a/src/armadillo_bits/SpBase_meat.hpp b/src/armadillo_bits/SpBase_meat.hpp index 4dd1894a..4ebc4243 100644 --- a/src/armadillo_bits/SpBase_meat.hpp +++ b/src/armadillo_bits/SpBase_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -69,109 +71,235 @@ SpBase::st() const template -arma_cold inline void SpBase::print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, true); } template -arma_cold inline void SpBase::print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, true); } template -arma_cold inline void SpBase::raw_print(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print(get_cout_stream(), tmp.M, false); } template -arma_cold inline void SpBase::raw_print(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_raw_print(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print(user_stream, tmp.M, false); } template -arma_cold inline void SpBase::print_dense(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_print_dense(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print_dense(get_cout_stream(), tmp.M, true); } template -arma_cold inline void SpBase::print_dense(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_print_dense(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print_dense(user_stream, tmp.M, true); } template -arma_cold inline void SpBase::raw_print_dense(const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_raw_print_dense(extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::print_dense(get_cout_stream(), tmp.M, false); } template -arma_cold inline void SpBase::raw_print_dense(std::ostream& user_stream, const std::string extra_text) const { + arma_extra_debug_sigprint(); + const unwrap_spmat tmp( (*this).get_ref() ); - tmp.M.impl_raw_print_dense(user_stream, extra_text); + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::print_dense(user_stream, tmp.M, false); + } + + + +template +inline +void +SpBase::brief_print(const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = get_cout_stream().width(); + + get_cout_stream() << extra_text << '\n'; + + get_cout_stream().width(orig_width); + } + + arma_ostream::brief_print(get_cout_stream(), tmp.M); } +template +inline +void +SpBase::brief_print(std::ostream& user_stream, const std::string extra_text) const + { + arma_extra_debug_sigprint(); + + const unwrap_spmat tmp( (*this).get_ref() ); + + if(extra_text.length() != 0) + { + const std::streamsize orig_width = user_stream.width(); + + user_stream << extra_text << '\n'; + + user_stream.width(orig_width); + } + + arma_ostream::brief_print(user_stream, tmp.M); + } + + + // // extra functions defined in SpBase_eval_SpMat @@ -204,7 +332,6 @@ SpBase_eval_expr::eval() const template inline -arma_warn_unused elem_type SpBase::min() const { @@ -215,7 +342,6 @@ SpBase::min() const template inline -arma_warn_unused elem_type SpBase::max() const { @@ -292,7 +418,6 @@ SpBase::max(uword& row_of_max_val, uword& col_of_max_val) co template inline -arma_warn_unused uword SpBase::index_min() const { @@ -316,7 +441,6 @@ SpBase::index_min() const template inline -arma_warn_unused uword SpBase::index_max() const { @@ -340,7 +464,6 @@ SpBase::index_max() const template inline -arma_warn_unused bool SpBase::is_symmetric() const { @@ -355,7 +478,6 @@ SpBase::is_symmetric() const template inline -arma_warn_unused bool SpBase::is_symmetric(const typename get_pod_type::result tol) const { @@ -370,7 +492,6 @@ SpBase::is_symmetric(const typename get_pod_type:: template inline -arma_warn_unused bool SpBase::is_hermitian() const { @@ -385,7 +506,6 @@ SpBase::is_hermitian() const template inline -arma_warn_unused bool SpBase::is_hermitian(const typename get_pod_type::result tol) const { @@ -400,7 +520,63 @@ SpBase::is_hermitian(const typename get_pod_type:: template inline -arma_warn_unused +bool +SpBase::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + arma_debug_check( (tol < T(0)), "is_zero(): parameter 'tol' must be >= 0" ); + + const SpProxy P( (*this).get_ref() ); + + if(P.get_n_elem() == 0) { return false; } + + if(P.get_n_nonzero() == 0) { return true; } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat::stored_type> U(P.Q); + + return arrayops::is_zero(U.M.values, U.M.n_nonzero, tol); + } + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + if(is_cx::yes) + { + while(it != it_end) + { + const elem_type val = (*it); + + const T val_real = access::tmp_real(val); + const T val_imag = access::tmp_imag(val); + + if(eop_aux::arma_abs(val_real) > tol) { return false; } + if(eop_aux::arma_abs(val_imag) > tol) { return false; } + + ++it; + } + } + else // not complex + { + while(it != it_end) + { + if(eop_aux::arma_abs(*it) > tol) { return false; } + + ++it; + } + } + + return true; + } + + + +template +inline bool SpBase::is_trimatu() const { @@ -426,7 +602,6 @@ SpBase::is_trimatu() const template inline -arma_warn_unused bool SpBase::is_trimatl() const { @@ -452,7 +627,6 @@ SpBase::is_trimatl() const template inline -arma_warn_unused bool SpBase::is_diagmat() const { @@ -476,7 +650,6 @@ SpBase::is_diagmat() const template inline -arma_warn_unused bool SpBase::is_empty() const { @@ -491,7 +664,6 @@ SpBase::is_empty() const template inline -arma_warn_unused bool SpBase::is_square() const { @@ -506,7 +678,6 @@ SpBase::is_square() const template inline -arma_warn_unused bool SpBase::is_vec() const { @@ -523,7 +694,6 @@ SpBase::is_vec() const template inline -arma_warn_unused bool SpBase::is_colvec() const { @@ -540,7 +710,6 @@ SpBase::is_colvec() const template inline -arma_warn_unused bool SpBase::is_rowvec() const { @@ -557,22 +726,23 @@ SpBase::is_rowvec() const template inline -arma_warn_unused bool SpBase::is_finite() const { arma_extra_debug_sigprint(); - const SpProxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } if(is_SpMat::stored_type>::value) { - const unwrap_spmat::stored_type> U(P.Q); + const unwrap_spmat U( (*this).get_ref() ); - return U.M.is_finite(); + return U.M.internal_is_finite(); } else { + const SpProxy P( (*this).get_ref() ); + typename SpProxy::const_iterator_type it = P.begin(); typename SpProxy::const_iterator_type it_end = P.end(); @@ -590,22 +760,23 @@ SpBase::is_finite() const template inline -arma_warn_unused bool SpBase::has_inf() const { arma_extra_debug_sigprint(); - const SpProxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } if(is_SpMat::stored_type>::value) { - const unwrap_spmat::stored_type> U(P.Q); + const unwrap_spmat U( (*this).get_ref() ); - return U.M.has_inf(); + return U.M.internal_has_inf(); } else { + const SpProxy P( (*this).get_ref() ); + typename SpProxy::const_iterator_type it = P.begin(); typename SpProxy::const_iterator_type it_end = P.end(); @@ -623,22 +794,23 @@ SpBase::has_inf() const template inline -arma_warn_unused bool SpBase::has_nan() const { arma_extra_debug_sigprint(); - const SpProxy P( (*this).get_ref() ); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } if(is_SpMat::stored_type>::value) { - const unwrap_spmat::stored_type> U(P.Q); + const unwrap_spmat U( (*this).get_ref() ); - return U.M.has_nan(); + return U.M.internal_has_nan(); } else { + const SpProxy P( (*this).get_ref() ); + typename SpProxy::const_iterator_type it = P.begin(); typename SpProxy::const_iterator_type it_end = P.end(); @@ -654,6 +826,40 @@ SpBase::has_nan() const +template +inline +bool +SpBase::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + if(is_SpMat::stored_type>::value) + { + const unwrap_spmat U( (*this).get_ref() ); + + return U.M.internal_has_nonfinite(); + } + else + { + const SpProxy P( (*this).get_ref() ); + + typename SpProxy::const_iterator_type it = P.begin(); + typename SpProxy::const_iterator_type it_end = P.end(); + + while(it != it_end) + { + if(arma_isfinite(*it) == false) { return true; } + ++it; + } + } + + return false; + } + + + template inline const SpOp diff --git a/src/armadillo_bits/SpCol_bones.hpp b/src/armadillo_bits/SpCol_bones.hpp index 5e6b001a..b49f1a58 100644 --- a/src/armadillo_bits/SpCol_bones.hpp +++ b/src/armadillo_bits/SpCol_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class SpCol : public SpMat typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; inline SpCol(); @@ -54,6 +56,10 @@ class SpCol : public SpMat template inline explicit SpCol(const SpBase& A, const SpBase& B); + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + inline void shed_row (const uword row_num); inline void shed_rows(const uword in_row1, const uword in_row2); @@ -70,7 +76,7 @@ class SpCol : public SpMat inline const_row_iterator end_row (const uword row_num = 0) const; - #ifdef ARMA_EXTRA_SPCOL_PROTO + #if defined(ARMA_EXTRA_SPCOL_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPCOL_PROTO) #endif }; diff --git a/src/armadillo_bits/SpCol_meat.hpp b/src/armadillo_bits/SpCol_meat.hpp index a6c4da50..9b3c824e 100644 --- a/src/armadillo_bits/SpCol_meat.hpp +++ b/src/armadillo_bits/SpCol_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -200,6 +202,36 @@ SpCol::SpCol +template +inline +const SpOp,spop_htrans> +SpCol::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpCol::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpCol::st() const + { + return SpOp,spop_strans>(*this); + } + + + //! remove specified row template inline @@ -208,7 +240,7 @@ SpCol::shed_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= SpMat::n_rows, "SpCol::shed_row(): out of bounds"); + arma_debug_check_bounds( row_num >= SpMat::n_rows, "SpCol::shed_row(): out of bounds" ); shed_rows(row_num, row_num); } @@ -223,7 +255,7 @@ SpCol::shed_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= SpMat::n_rows), "SpCol::shed_rows(): indices out of bounds or incorrectly used" @@ -239,14 +271,14 @@ SpCol::shed_rows(const uword in_row1, const uword in_row2) for(uword i = 0; i < SpMat::n_nonzero; ++i) { // Start position found? - if (SpMat::row_indices[i] >= in_row1 && !start_found) + if(SpMat::row_indices[i] >= in_row1 && !start_found) { start = i; start_found = true; } // End position found? - if (SpMat::row_indices[i] > in_row2) + if(SpMat::row_indices[i] > in_row2) { end = i; end_found = true; @@ -254,13 +286,13 @@ SpCol::shed_rows(const uword in_row1, const uword in_row2) } } - if (!end_found) + if(!end_found) { end = SpMat::n_nonzero; } // Now we can make the copy. - if (start != end) + if(start != end) { const uword elem_diff = end - start; @@ -268,14 +300,14 @@ SpCol::shed_rows(const uword in_row1, const uword in_row2) uword* new_row_indices = memory::acquire(SpMat::n_nonzero - elem_diff); // Copy before the section we are dropping (if it exists). - if (start > 0) + if(start > 0) { arrayops::copy(new_values, SpMat::values, start); arrayops::copy(new_row_indices, SpMat::row_indices, start); } // Copy after the section we are dropping (if it exists). - if (end != SpMat::n_nonzero) + if(end != SpMat::n_nonzero) { arrayops::copy(new_values + start, SpMat::values + end, (SpMat::n_nonzero - end)); arrayops::copy(new_row_indices + start, SpMat::row_indices + end, (SpMat::n_nonzero - end)); @@ -311,11 +343,11 @@ SpCol::shed_rows(const uword in_row1, const uword in_row2) // // arma_debug_check(set_to_zero == false, "SpCol::insert_rows(): cannot set nonzero values"); // -// arma_debug_check((row_num > SpMat::n_rows), "SpCol::insert_rows(): out of bounds"); +// arma_debug_check_bounds((row_num > SpMat::n_rows), "SpCol::insert_rows(): out of bounds"); // // for(uword row = 0; row < SpMat::n_rows; ++row) // { -// if (SpMat::row_indices[row] >= row_num) +// if(SpMat::row_indices[row] >= row_num) // { // access::rw(SpMat::row_indices[row]) += N; // } @@ -334,7 +366,7 @@ SpCol::begin_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds" ); SpMat::sync_csc(); @@ -350,7 +382,7 @@ SpCol::begin_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::begin_row(): index out of bounds" ); SpMat::sync_csc(); @@ -366,7 +398,7 @@ SpCol::end_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds" ); SpMat::sync_csc(); @@ -382,7 +414,7 @@ SpCol::end_row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds"); + arma_debug_check_bounds( (row_num >= SpMat::n_rows), "SpCol::end_row(): index out of bounds" ); SpMat::sync_csc(); @@ -391,7 +423,7 @@ SpCol::end_row(const uword row_num) const -#ifdef ARMA_EXTRA_SPCOL_MEAT +#if defined(ARMA_EXTRA_SPCOL_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPCOL_MEAT) #endif diff --git a/src/armadillo_bits/SpGlue_bones.hpp b/src/armadillo_bits/SpGlue_bones.hpp index 1aad5063..3c5432d8 100644 --- a/src/armadillo_bits/SpGlue_bones.hpp +++ b/src/armadillo_bits/SpGlue_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class SpGlue : public SpBase > +class SpGlue : public SpBase< typename T1::elem_type, SpGlue > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = spglue_type::template traits::is_row; - static const bool is_col = spglue_type::template traits::is_col; - static const bool is_xvec = spglue_type::template traits::is_xvec; + static constexpr bool is_row = spglue_type::template traits::is_row; + static constexpr bool is_col = spglue_type::template traits::is_col; + static constexpr bool is_xvec = spglue_type::template traits::is_xvec; inline SpGlue(const T1& in_A, const T2& in_B); inline SpGlue(const T1& in_A, const T2& in_B, const elem_type in_aux); diff --git a/src/armadillo_bits/SpGlue_meat.hpp b/src/armadillo_bits/SpGlue_meat.hpp index 09c659ce..04d40e1a 100644 --- a/src/armadillo_bits/SpGlue_meat.hpp +++ b/src/armadillo_bits/SpGlue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/SpMat_bones.hpp b/src/armadillo_bits/SpMat_bones.hpp index 75b57bfd..e96f1530 100644 --- a/src/armadillo_bits/SpMat_bones.hpp +++ b/src/armadillo_bits/SpMat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class SpMat : public SpBase< eT, SpMat > typedef eT elem_type; //!< the type of elements stored in the matrix typedef typename get_pod_type::result pod_type; //!< if eT is std::complex, pod_type is T; otherwise pod_type is eT - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const uword n_rows; //!< number of rows (read-only) const uword n_cols; //!< number of columns (read-only) @@ -92,16 +94,14 @@ class SpMat : public SpBase< eT, SpMat > inline SpMat& operator=(const std::string& text); inline SpMat(const SpMat& x); - #if defined(ARMA_USE_CXX11) inline SpMat(SpMat&& m); inline SpMat& operator=(SpMat&& m); - #endif inline explicit SpMat(const MapMat& x); - inline SpMat& operator=(const MapMat& x); + inline SpMat& operator= (const MapMat& x); template - inline SpMat(const Base& rowind, const Base& colptr, const Base& values, const uword n_rows, const uword n_cols); + inline SpMat(const Base& rowind, const Base& colptr, const Base& values, const uword n_rows, const uword n_cols, const bool check_for_zeros = true); template inline SpMat(const Base& locations, const Base& values, const bool sort_locations = true); @@ -112,12 +112,12 @@ class SpMat : public SpBase< eT, SpMat > template inline SpMat(const bool add_values, const Base& locations, const Base& values, const uword n_rows, const uword n_cols, const bool sort_locations = true, const bool check_for_zeros = true); - inline SpMat& operator=(const eT val); //! sets size to 1x1 + inline SpMat& operator= (const eT val); //! sets size to 1x1 inline SpMat& operator*=(const eT val); inline SpMat& operator/=(const eT val); // operator+=(val) and operator-=(val) are not defined as they don't make sense for sparse matrices - inline SpMat& operator=(const SpMat& m); + inline SpMat& operator= (const SpMat& m); inline SpMat& operator+=(const SpMat& m); inline SpMat& operator-=(const SpMat& m); inline SpMat& operator*=(const SpMat& m); @@ -125,7 +125,7 @@ class SpMat : public SpBase< eT, SpMat > inline SpMat& operator/=(const SpMat& m); template inline explicit SpMat(const Base& m); - template inline SpMat& operator=(const Base& m); + template inline SpMat& operator= (const Base& m); template inline SpMat& operator+=(const Base& m); template inline SpMat& operator-=(const Base& m); template inline SpMat& operator*=(const Base& m); @@ -133,7 +133,7 @@ class SpMat : public SpBase< eT, SpMat > template inline SpMat& operator%=(const Base& m); template inline explicit SpMat(const Op& expr); - template inline SpMat& operator=(const Op& expr); + template inline SpMat& operator= (const Op& expr); template inline SpMat& operator+=(const Op& expr); template inline SpMat& operator-=(const Op& expr); template inline SpMat& operator*=(const Op& expr); @@ -148,15 +148,23 @@ class SpMat : public SpBase< eT, SpMat > inline explicit SpMat(const SpBase& A, const SpBase& B); inline SpMat(const SpSubview& X); - inline SpMat& operator=(const SpSubview& X); + inline SpMat& operator= (const SpSubview& X); inline SpMat& operator+=(const SpSubview& X); inline SpMat& operator-=(const SpSubview& X); inline SpMat& operator*=(const SpSubview& X); inline SpMat& operator%=(const SpSubview& X); inline SpMat& operator/=(const SpSubview& X); + template inline SpMat(const SpSubview_col_list& X); + template inline SpMat& operator= (const SpSubview_col_list& X); + template inline SpMat& operator+=(const SpSubview_col_list& X); + template inline SpMat& operator-=(const SpSubview_col_list& X); + template inline SpMat& operator*=(const SpSubview_col_list& X); + template inline SpMat& operator%=(const SpSubview_col_list& X); + template inline SpMat& operator/=(const SpSubview_col_list& X); + inline SpMat(const spdiagview& X); - inline SpMat& operator=(const spdiagview& X); + inline SpMat& operator= (const spdiagview& X); inline SpMat& operator+=(const spdiagview& X); inline SpMat& operator-=(const spdiagview& X); inline SpMat& operator*=(const spdiagview& X); @@ -165,7 +173,7 @@ class SpMat : public SpBase< eT, SpMat > // delayed unary ops template inline SpMat(const SpOp& X); - template inline SpMat& operator=(const SpOp& X); + template inline SpMat& operator= (const SpOp& X); template inline SpMat& operator+=(const SpOp& X); template inline SpMat& operator-=(const SpOp& X); template inline SpMat& operator*=(const SpOp& X); @@ -174,7 +182,7 @@ class SpMat : public SpBase< eT, SpMat > // delayed binary ops template inline SpMat(const SpGlue& X); - template inline SpMat& operator=(const SpGlue& X); + template inline SpMat& operator= (const SpGlue& X); template inline SpMat& operator+=(const SpGlue& X); template inline SpMat& operator-=(const SpGlue& X); template inline SpMat& operator*=(const SpGlue& X); @@ -183,7 +191,7 @@ class SpMat : public SpBase< eT, SpMat > // delayed mixed-type unary ops template inline SpMat(const mtSpOp& X); - template inline SpMat& operator=(const mtSpOp& X); + template inline SpMat& operator= (const mtSpOp& X); template inline SpMat& operator+=(const mtSpOp& X); template inline SpMat& operator-=(const mtSpOp& X); template inline SpMat& operator*=(const mtSpOp& X); @@ -192,7 +200,7 @@ class SpMat : public SpBase< eT, SpMat > // delayed mixed-type binary ops template inline SpMat(const mtSpGlue& X); - template inline SpMat& operator=(const mtSpGlue& X); + template inline SpMat& operator= (const mtSpGlue& X); template inline SpMat& operator+=(const mtSpGlue& X); template inline SpMat& operator-=(const mtSpGlue& X); template inline SpMat& operator*=(const mtSpGlue& X); @@ -200,17 +208,17 @@ class SpMat : public SpBase< eT, SpMat > template inline SpMat& operator/=(const mtSpGlue& X); - arma_inline SpSubview row(const uword row_num); - arma_inline const SpSubview row(const uword row_num) const; + arma_inline SpSubview_row row(const uword row_num); + arma_inline const SpSubview_row row(const uword row_num) const; - inline SpSubview operator()(const uword row_num, const span& col_span); - inline const SpSubview operator()(const uword row_num, const span& col_span) const; + inline SpSubview_row operator()(const uword row_num, const span& col_span); + inline const SpSubview_row operator()(const uword row_num, const span& col_span) const; - arma_inline SpSubview col(const uword col_num); - arma_inline const SpSubview col(const uword col_num) const; + arma_inline SpSubview_col col(const uword col_num); + arma_inline const SpSubview_col col(const uword col_num) const; - inline SpSubview operator()(const span& row_span, const uword col_num); - inline const SpSubview operator()(const span& row_span, const uword col_num) const; + inline SpSubview_col operator()(const span& row_span, const uword col_num); + inline const SpSubview_col operator()(const span& row_span, const uword col_num) const; arma_inline SpSubview rows(const uword in_row1, const uword in_row2); arma_inline const SpSubview rows(const uword in_row1, const uword in_row2) const; @@ -247,6 +255,10 @@ class SpMat : public SpBase< eT, SpMat > inline const SpSubview tail_cols(const uword N) const; + template arma_inline SpSubview_col_list cols(const Base& ci); + template arma_inline const SpSubview_col_list cols(const Base& ci) const; + + inline spdiagview diag(const sword in_id = 0); inline const spdiagview diag(const sword in_id = 0) const; @@ -262,107 +274,104 @@ class SpMat : public SpBase< eT, SpMat > // access the i-th element; if there is nothing at element i, 0 is returned - arma_inline arma_warn_unused SpMat_MapMat_val operator[] (const uword i); - arma_inline arma_warn_unused eT operator[] (const uword i) const; - arma_inline arma_warn_unused SpMat_MapMat_val at (const uword i); - arma_inline arma_warn_unused eT at (const uword i) const; - arma_inline arma_warn_unused SpMat_MapMat_val operator() (const uword i); - arma_inline arma_warn_unused eT operator() (const uword i) const; - - // access the element at the given row and column; if there is nothing at that position, 0 is returned - arma_inline arma_warn_unused SpMat_MapMat_val at (const uword in_row, const uword in_col); - arma_inline arma_warn_unused eT at (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused SpMat_MapMat_val operator() (const uword in_row, const uword in_col); - arma_inline arma_warn_unused eT operator() (const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline SpMat_MapMat_val operator[] (const uword i); + arma_warn_unused arma_inline eT operator[] (const uword i) const; + arma_warn_unused arma_inline SpMat_MapMat_val at (const uword i); + arma_warn_unused arma_inline eT at (const uword i) const; - arma_inline arma_warn_unused bool is_empty() const; - arma_inline arma_warn_unused bool is_vec() const; - arma_inline arma_warn_unused bool is_rowvec() const; - arma_inline arma_warn_unused bool is_colvec() const; - arma_inline arma_warn_unused bool is_square() const; - inline arma_warn_unused bool is_finite() const; + arma_warn_unused arma_inline SpMat_MapMat_val operator() (const uword i); + arma_warn_unused arma_inline eT operator() (const uword i) const; - inline arma_warn_unused bool is_symmetric() const; - inline arma_warn_unused bool is_symmetric(const typename get_pod_type::result tol) const; + // access the element at the given row and column; if there is nothing at that position, 0 is returned + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline SpMat_MapMat_val operator[] (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT operator[] (const uword in_row, const uword in_col) const; + #endif - inline arma_warn_unused bool is_hermitian() const; - inline arma_warn_unused bool is_hermitian(const typename get_pod_type::result tol) const; + arma_warn_unused arma_inline SpMat_MapMat_val at (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT at (const uword in_row, const uword in_col) const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused arma_inline SpMat_MapMat_val operator() (const uword in_row, const uword in_col); + arma_warn_unused arma_inline eT operator() (const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const uword i) const; - arma_inline arma_warn_unused bool in_range(const span& x) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const span& col_span) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const span& col_span) const; + arma_warn_unused arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_vec() const; + arma_warn_unused arma_inline bool is_rowvec() const; + arma_warn_unused arma_inline bool is_colvec() const; + arma_warn_unused arma_inline bool is_square() const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + arma_warn_unused inline bool is_symmetric() const; + arma_warn_unused inline bool is_symmetric(const typename get_pod_type::result tol) const; + arma_warn_unused inline bool is_hermitian() const; + arma_warn_unused inline bool is_hermitian(const typename get_pod_type::result tol) const; - arma_cold inline void impl_print( const std::string& extra_text) const; - arma_cold inline void impl_print(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused inline bool internal_is_finite() const; + arma_warn_unused inline bool internal_has_inf() const; + arma_warn_unused inline bool internal_has_nan() const; + arma_warn_unused inline bool internal_has_nonfinite() const; - arma_cold inline void impl_raw_print( const std::string& extra_text) const; - arma_cold inline void impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; - arma_cold inline void impl_print_dense( const std::string& extra_text) const; - arma_cold inline void impl_print_dense(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; - arma_cold inline void impl_raw_print_dense( const std::string& extra_text) const; - arma_cold inline void impl_raw_print_dense(std::ostream& user_stream, const std::string& extra_text) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; - template inline void copy_size(const SpMat& m); - template inline void copy_size(const Mat& m); + template inline SpMat& copy_size(const SpMat& m); + template inline SpMat& copy_size(const Mat& m); - inline void set_size(const uword in_elem); - inline void set_size(const uword in_rows, const uword in_cols); - inline void set_size(const SizeMat& s); + inline SpMat& set_size(const uword in_elem); + inline SpMat& set_size(const uword in_rows, const uword in_cols); + inline SpMat& set_size(const SizeMat& s); - inline void resize(const uword in_rows, const uword in_cols); - inline void resize(const SizeMat& s); + inline SpMat& resize(const uword in_rows, const uword in_cols); + inline SpMat& resize(const SizeMat& s); - inline void reshape(const uword in_rows, const uword in_cols); - inline void reshape(const SizeMat& s); + inline SpMat& reshape(const uword in_rows, const uword in_cols); + inline SpMat& reshape(const SizeMat& s); inline void reshape_helper_generic(const uword in_rows, const uword in_cols); //! internal use only inline void reshape_helper_intovec(); //! internal use only - arma_deprecated inline void reshape(const uword in_rows, const uword in_cols, const uword dim); //!< NOTE: don't use this form: it will be removed - - template inline const SpMat& for_each(functor F); + template inline SpMat& for_each(functor F); template inline const SpMat& for_each(functor F) const; - template inline const SpMat& transform(functor F); + template inline SpMat& transform(functor F); - inline const SpMat& replace(const eT old_val, const eT new_val); + inline SpMat& replace(const eT old_val, const eT new_val); - inline const SpMat& clean(const pod_type threshold); + inline SpMat& clean(const pod_type threshold); - inline const SpMat& zeros(); - inline const SpMat& zeros(const uword in_elem); - inline const SpMat& zeros(const uword in_rows, const uword in_cols); - inline const SpMat& zeros(const SizeMat& s); + inline SpMat& clamp(const eT min_val, const eT max_val); - inline const SpMat& eye(); - inline const SpMat& eye(const uword in_rows, const uword in_cols); - inline const SpMat& eye(const SizeMat& s); + inline SpMat& zeros(); + inline SpMat& zeros(const uword in_elem); + inline SpMat& zeros(const uword in_rows, const uword in_cols); + inline SpMat& zeros(const SizeMat& s); - inline const SpMat& speye(); - inline const SpMat& speye(const uword in_rows, const uword in_cols); - inline const SpMat& speye(const SizeMat& s); + inline SpMat& eye(); + inline SpMat& eye(const uword in_rows, const uword in_cols); + inline SpMat& eye(const SizeMat& s); - inline const SpMat& sprandu(const uword in_rows, const uword in_cols, const double density); - inline const SpMat& sprandu(const SizeMat& s, const double density); + inline SpMat& speye(); + inline SpMat& speye(const uword in_rows, const uword in_cols); + inline SpMat& speye(const SizeMat& s); - inline const SpMat& sprandn(const uword in_rows, const uword in_cols, const double density); - inline const SpMat& sprandn(const SizeMat& s, const double density); + inline SpMat& sprandu(const uword in_rows, const uword in_cols, const double density); + inline SpMat& sprandu(const SizeMat& s, const double density); + + inline SpMat& sprandn(const uword in_rows, const uword in_cols, const double density); + inline SpMat& sprandn(const SizeMat& s, const double density); inline void reset(); + inline void reset_cache(); //! don't use this unless you're writing internal Armadillo code inline void reserve(const uword in_rows, const uword in_cols, const uword new_n_nonzero); @@ -381,17 +390,19 @@ class SpMat : public SpBase< eT, SpMat > // saving and loading // TODO: implement auto_detect for sparse matrices - inline arma_cold bool save(const std::string name, const file_type type = arma_binary, const bool print_status = true) const; - inline arma_cold bool save( std::ostream& os, const file_type type = arma_binary, const bool print_status = true) const; + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save(const csv_name& spec, const file_type type = csv_ascii) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool load(const std::string name, const file_type type = arma_binary, const bool print_status = true); - inline arma_cold bool load( std::istream& is, const file_type type = arma_binary, const bool print_status = true); + arma_cold inline bool load(const std::string name, const file_type type = arma_binary); + arma_cold inline bool load(const csv_name& spec, const file_type type = csv_ascii); + arma_cold inline bool load( std::istream& is, const file_type type = arma_binary); - inline arma_cold bool quiet_save(const std::string name, const file_type type = arma_binary) const; - inline arma_cold bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool quiet_load(const std::string name, const file_type type = arma_binary); - inline arma_cold bool quiet_load( std::istream& is, const file_type type = arma_binary); + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = arma_binary); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = arma_binary); @@ -434,30 +445,31 @@ class SpMat : public SpBase< eT, SpMat > public: inline const_iterator(); - inline const_iterator(const SpMat& in_M, uword initial_pos = 0); // assumes initial_pos is valid - //! once initialised, will be at the first nonzero value after the given position (using forward columnwise traversal) - inline const_iterator(const SpMat& in_M, uword in_row, uword in_col); - //! if you know the exact position of the iterator; in_row is a dummy argument - inline const_iterator(const SpMat& in_M, uword in_row, uword in_col, uword in_pos); - inline const_iterator(const const_iterator& other); - inline arma_hot const_iterator& operator++(); - inline arma_warn_unused const_iterator operator++(int); + inline const_iterator(const SpMat& in_M, uword initial_pos = 0); // assumes initial_pos is valid + inline const_iterator(const SpMat& in_M, uword in_row, uword in_col); // iterator will be at the first nonzero value after the given position (using forward columnwise traversal) + inline const_iterator(const SpMat& in_M, uword in_row, uword in_col, uword in_pos); // if the exact position of the iterator is known; in_row is a dummy argument + + inline const_iterator(const const_iterator& other); + inline const_iterator& operator= (const const_iterator& other) = default; + + arma_hot inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); - inline arma_hot const_iterator& operator--(); - inline arma_warn_unused const_iterator operator--(int); + arma_hot inline const_iterator& operator--(); + arma_warn_unused inline const_iterator operator--(int); - inline arma_hot bool operator==(const const_iterator& rhs) const; - inline arma_hot bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + arma_hot inline bool operator!=(const const_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpSubview::const_iterator& rhs) const; - inline arma_hot bool operator!=(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_iterator& rhs) const; - inline arma_hot bool operator==(const const_row_iterator& rhs) const; - inline arma_hot bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpSubview::const_row_iterator& rhs) const; - inline arma_hot bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; }; /** @@ -470,24 +482,27 @@ class SpMat : public SpBase< eT, SpMat > public: inline iterator() : const_iterator() { } - inline iterator(SpMat& in_M, uword initial_pos = 0) : const_iterator(in_M, initial_pos) { } - inline iterator(SpMat& in_M, uword in_row, uword in_col) : const_iterator(in_M, in_row, in_col) { } + + inline iterator(SpMat& in_M, uword initial_pos = 0) : const_iterator(in_M, initial_pos) { } + inline iterator(SpMat& in_M, uword in_row, uword in_col) : const_iterator(in_M, in_row, in_col) { } inline iterator(SpMat& in_M, uword in_row, uword in_col, uword in_pos) : const_iterator(in_M, in_row, in_col, in_pos) { } - inline iterator(const iterator& other) : const_iterator(other) { } - inline arma_hot SpValProxy > operator*(); + inline iterator (const iterator& other) : const_iterator(other) { } + inline iterator& operator=(const iterator& other) = default; + + arma_hot inline SpValProxy< SpMat > operator*(); // overloads needed for return type correctness - inline arma_hot iterator& operator++(); - inline arma_warn_unused iterator operator++(int); + arma_hot inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); - inline arma_hot iterator& operator--(); - inline arma_warn_unused iterator operator--(int); + arma_hot inline iterator& operator--(); + arma_warn_unused inline iterator operator--(int); // this has a different value_type than iterator_base - typedef SpValProxy > value_type; - typedef const SpValProxy >* pointer; - typedef const SpValProxy >& reference; + typedef SpValProxy< SpMat > value_type; + typedef const SpValProxy< SpMat >* pointer; + typedef const SpValProxy< SpMat >& reference; }; class const_row_iterator : public iterator_base @@ -496,34 +511,35 @@ class SpMat : public SpBase< eT, SpMat > inline const_row_iterator(); inline const_row_iterator(const SpMat& in_M, uword initial_pos = 0); - //! once initialised, will be at the first nonzero value after the given position (using forward row-wise traversal) inline const_row_iterator(const SpMat& in_M, uword in_row, uword in_col); - inline const_row_iterator(const const_row_iterator& other); - inline arma_hot const_row_iterator& operator++(); - inline arma_warn_unused const_row_iterator operator++(int); + inline const_row_iterator(const const_row_iterator& other); + inline const_row_iterator& operator= (const const_row_iterator& other) = default; - inline arma_hot const_row_iterator& operator--(); - inline arma_warn_unused const_row_iterator operator--(int); + arma_hot inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + arma_hot inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); uword internal_row; // hold row internally - uword actual_pos; // this holds the true position we are at in the matrix, as column-major indexing + uword actual_pos; // hold the true position we are at in the matrix, as column-major indexing arma_inline eT operator*() const { return iterator_base::M->values[actual_pos]; } arma_inline uword row() const { return internal_row; } - inline arma_hot bool operator==(const const_iterator& rhs) const; - inline arma_hot bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + arma_hot inline bool operator!=(const const_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpSubview::const_iterator& rhs) const; - inline arma_hot bool operator!=(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpSubview::const_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_iterator& rhs) const; - inline arma_hot bool operator==(const const_row_iterator& rhs) const; - inline arma_hot bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpSubview::const_row_iterator& rhs) const; - inline arma_hot bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpSubview::const_row_iterator& rhs) const; + arma_hot inline bool operator!=(const typename SpSubview::const_row_iterator& rhs) const; }; class row_iterator : public const_row_iterator @@ -531,24 +547,26 @@ class SpMat : public SpBase< eT, SpMat > public: inline row_iterator() : const_row_iterator() {} - inline row_iterator(SpMat& in_M, uword initial_pos = 0) : const_row_iterator(in_M, initial_pos) { } - //! once initialised, will be at the first nonzero value after the given position (using forward row-wise traversal) + + inline row_iterator(SpMat& in_M, uword initial_pos = 0) : const_row_iterator(in_M, initial_pos) { } inline row_iterator(SpMat& in_M, uword in_row, uword in_col) : const_row_iterator(in_M, in_row, in_col) { } - inline row_iterator(const row_iterator& other) : const_row_iterator(other) { } - inline arma_hot SpValProxy > operator*(); + inline row_iterator(const row_iterator& other) : const_row_iterator(other) { } + inline row_iterator& operator= (const row_iterator& other) = default; + + arma_hot inline SpValProxy< SpMat > operator*(); // overloads required for return type correctness - inline arma_hot row_iterator& operator++(); - inline arma_warn_unused row_iterator operator++(int); + arma_hot inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); - inline arma_hot row_iterator& operator--(); - inline arma_warn_unused row_iterator operator--(int); + arma_hot inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); // this has a different value_type than iterator_base - typedef SpValProxy > value_type; - typedef const SpValProxy >* pointer; - typedef const SpValProxy >& reference; + typedef SpValProxy< SpMat > value_type; + typedef const SpValProxy< SpMat >* pointer; + typedef const SpValProxy< SpMat >& reference; }; @@ -599,6 +617,12 @@ class SpMat : public SpBase< eT, SpMat > inline bool empty() const; inline uword size() const; + arma_warn_unused arma_inline SpMat_MapMat_val front(); + arma_warn_unused arma_inline eT front() const; + + arma_warn_unused arma_inline SpMat_MapMat_val back(); + arma_warn_unused arma_inline eT back() const; + // Resize memory. // If the new size is larger, the column pointers and new memory still need to be correctly set. // If the new size is smaller, the first new_n_nonzero elements will be copied. @@ -618,8 +642,8 @@ class SpMat : public SpBase< eT, SpMat > inline void steal_mem_simple(SpMat& X); //! don't use this unless you're writing internal Armadillo code - template< typename T1, typename Functor> arma_hot inline void init_xform (const SpBase& x, const Functor& func); - template arma_hot inline void init_xform_mt(const SpBase& x, const Functor& func); + template< typename T1, typename Functor> inline void init_xform (const SpBase& x, const Functor& func); + template inline void init_xform_mt(const SpBase& x, const Functor& func); //! don't use this unless you're writing internal Armadillo code arma_inline bool is_alias(const SpMat& X) const; @@ -627,8 +651,8 @@ class SpMat : public SpBase< eT, SpMat > protected: - inline void init(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); - inline void arma_cold init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); + inline void init(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); + arma_cold inline void init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero = 0); inline void init(const std::string& text); inline void init(const SpMat& x); @@ -645,22 +669,22 @@ class SpMat : public SpBase< eT, SpMat > private: - inline arma_hot arma_warn_unused const eT* find_value_csc(const uword in_row, const uword in_col) const; + arma_warn_unused arma_hot inline const eT* find_value_csc(const uword in_row, const uword in_col) const; - inline arma_hot arma_warn_unused eT get_value(const uword i ) const; - inline arma_hot arma_warn_unused eT get_value(const uword in_row, const uword in_col) const; + arma_warn_unused arma_hot inline eT get_value(const uword i ) const; + arma_warn_unused arma_hot inline eT get_value(const uword in_row, const uword in_col) const; - inline arma_hot arma_warn_unused eT get_value_csc(const uword i ) const; - inline arma_hot arma_warn_unused eT get_value_csc(const uword in_row, const uword in_col) const; + arma_warn_unused arma_hot inline eT get_value_csc(const uword i ) const; + arma_warn_unused arma_hot inline eT get_value_csc(const uword in_row, const uword in_col) const; - inline arma_hot arma_warn_unused bool try_set_value_csc(const uword in_row, const uword in_col, const eT in_val); - inline arma_hot arma_warn_unused bool try_add_value_csc(const uword in_row, const uword in_col, const eT in_val); - inline arma_hot arma_warn_unused bool try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val); - inline arma_hot arma_warn_unused bool try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val); - inline arma_hot arma_warn_unused bool try_div_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_set_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_add_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val); + arma_warn_unused arma_hot inline bool try_div_value_csc(const uword in_row, const uword in_col, const eT in_val); - inline arma_warn_unused eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); - inline void delete_element(const uword in_row, const uword in_col); + arma_warn_unused inline eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); + inline void delete_element(const uword in_row, const uword in_col); // cache related @@ -671,7 +695,7 @@ class SpMat : public SpBase< eT, SpMat > // 1: CSC needs to be updated from cache (ie. cache has more recent data) // 2: no update required (ie. CSC and cache contain the same data) - #if defined(ARMA_USE_CXX11) + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) arma_aligned mutable std::mutex cache_mutex; #endif @@ -692,10 +716,11 @@ class SpMat : public SpBase< eT, SpMat > friend class SpSubview_MapMat_val; friend class spdiagview; + template friend class SpSubview_col_list; public: - #ifdef ARMA_EXTRA_SPMAT_PROTO + #if defined(ARMA_EXTRA_SPMAT_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_PROTO) #endif }; diff --git a/src/armadillo_bits/SpMat_iterators_meat.hpp b/src/armadillo_bits/SpMat_iterators_meat.hpp index 3150d062..ed29640d 100644 --- a/src/armadillo_bits/SpMat_iterators_meat.hpp +++ b/src/armadillo_bits/SpMat_iterators_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,7 +28,7 @@ template inline SpMat::iterator_base::iterator_base() - : M(NULL) + : M(nullptr) , internal_col(0) , internal_pos(0) { @@ -80,6 +82,8 @@ SpMat::const_iterator::const_iterator() { } + + template inline SpMat::const_iterator::const_iterator(const SpMat& in_M, uword initial_pos) @@ -91,7 +95,7 @@ SpMat::const_iterator::const_iterator(const SpMat& in_M, uword initial_p iterator_base::internal_col = in_M.n_cols; return; } - + // Determine which column we should be in. while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) { @@ -108,13 +112,13 @@ SpMat::const_iterator::const_iterator(const SpMat& in_M, uword in_row, u { // So we have a position we want to be right after. Skip to the column. iterator_base::internal_pos = iterator_base::M->col_ptrs[iterator_base::internal_col]; - + // Now we have to make sure that is the right column. while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) { iterator_base::internal_col++; } - + // Now we have to get to the right row. while((iterator_base::M->row_indices[iterator_base::internal_pos] < in_row) && (iterator_base::internal_col == in_col)) { @@ -146,24 +150,23 @@ SpMat::const_iterator::const_iterator(const typename SpMat::const_iterat template inline -arma_hot typename SpMat::const_iterator& SpMat::const_iterator::operator++() { ++iterator_base::internal_pos; - - if (iterator_base::internal_pos == iterator_base::M->n_nonzero) + + if(iterator_base::internal_pos == iterator_base::M->n_nonzero) { iterator_base::internal_col = iterator_base::M->n_cols; return *this; } - + // Check to see if we moved a column. - while (iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) + while(iterator_base::M->col_ptrs[iterator_base::internal_col + 1] <= iterator_base::internal_pos) { ++iterator_base::internal_col; } - + return *this; } @@ -171,14 +174,13 @@ SpMat::const_iterator::operator++() template inline -arma_warn_unused typename SpMat::const_iterator SpMat::const_iterator::operator++(int) { typename SpMat::const_iterator tmp(*this); - + ++(*this); - + return tmp; } @@ -186,19 +188,17 @@ SpMat::const_iterator::operator++(int) template inline -arma_hot typename SpMat::const_iterator& SpMat::const_iterator::operator--() { --iterator_base::internal_pos; // First, see if we moved back a column. - while (iterator_base::internal_pos < iterator_base::M->col_ptrs[iterator_base::internal_col]) + while(iterator_base::internal_pos < iterator_base::M->col_ptrs[iterator_base::internal_col]) { --iterator_base::internal_col; } - - + return *this; } @@ -206,14 +206,13 @@ SpMat::const_iterator::operator--() template inline -arma_warn_unused typename SpMat::const_iterator SpMat::const_iterator::operator--(int) { typename SpMat::const_iterator tmp(*this); - + --(*this); - + return tmp; } @@ -221,7 +220,6 @@ SpMat::const_iterator::operator--(int) template inline -arma_hot bool SpMat::const_iterator::operator==(const const_iterator& rhs) const { @@ -232,7 +230,6 @@ SpMat::const_iterator::operator==(const const_iterator& rhs) const template inline -arma_hot bool SpMat::const_iterator::operator!=(const const_iterator& rhs) const { @@ -243,7 +240,6 @@ SpMat::const_iterator::operator!=(const const_iterator& rhs) const template inline -arma_hot bool SpMat::const_iterator::operator==(const typename SpSubview::const_iterator& rhs) const { @@ -254,7 +250,6 @@ SpMat::const_iterator::operator==(const typename SpSubview::const_iterat template inline -arma_hot bool SpMat::const_iterator::operator!=(const typename SpSubview::const_iterator& rhs) const { @@ -265,7 +260,6 @@ SpMat::const_iterator::operator!=(const typename SpSubview::const_iterat template inline -arma_hot bool SpMat::const_iterator::operator==(const const_row_iterator& rhs) const { @@ -276,7 +270,6 @@ SpMat::const_iterator::operator==(const const_row_iterator& rhs) const template inline -arma_hot bool SpMat::const_iterator::operator!=(const const_row_iterator& rhs) const { @@ -287,7 +280,6 @@ SpMat::const_iterator::operator!=(const const_row_iterator& rhs) const template inline -arma_hot bool SpMat::const_iterator::operator==(const typename SpSubview::const_row_iterator& rhs) const { @@ -298,7 +290,6 @@ SpMat::const_iterator::operator==(const typename SpSubview::const_row_it template inline -arma_hot bool SpMat::const_iterator::operator!=(const typename SpSubview::const_row_iterator& rhs) const { @@ -313,11 +304,10 @@ SpMat::const_iterator::operator!=(const typename SpSubview::const_row_it template inline -arma_hot -SpValProxy > +SpValProxy< SpMat > SpMat::iterator::operator*() { - return SpValProxy >( + return SpValProxy< SpMat >( iterator_base::M->row_indices[iterator_base::internal_pos], iterator_base::internal_col, access::rw(*iterator_base::M), @@ -328,11 +318,11 @@ SpMat::iterator::operator*() template inline -arma_hot typename SpMat::iterator& SpMat::iterator::operator++() { const_iterator::operator++(); + return *this; } @@ -340,14 +330,13 @@ SpMat::iterator::operator++() template inline -arma_warn_unused typename SpMat::iterator SpMat::iterator::operator++(int) { typename SpMat::iterator tmp(*this); - + const_iterator::operator++(); - + return tmp; } @@ -355,11 +344,11 @@ SpMat::iterator::operator++(int) template inline -arma_hot typename SpMat::iterator& SpMat::iterator::operator--() { const_iterator::operator--(); + return *this; } @@ -367,14 +356,13 @@ SpMat::iterator::operator--() template inline -arma_warn_unused typename SpMat::iterator SpMat::iterator::operator--(int) { typename SpMat::iterator tmp(*this); - + const_iterator::operator--(); - + return tmp; } @@ -407,73 +395,71 @@ SpMat::const_row_iterator::const_row_iterator(const SpMat& in_M, uword i , actual_pos(0) { // Corner case for the end of a matrix. - if (initial_pos == in_M.n_nonzero) + if(initial_pos == in_M.n_nonzero) { iterator_base::internal_col = 0; internal_row = in_M.n_rows; actual_pos = in_M.n_nonzero; iterator_base::internal_pos = in_M.n_nonzero; - + return; } - + // We don't count zeros in our position count, so we have to find the nonzero // value corresponding to the given initial position. We assume initial_pos // is valid. - - // This is irritating because we don't know where the elements are in each - // row. What we will do is loop across all columns looking for elements in - // row 0 (and add to our sum), then in row 1, and so forth, until we get to - // the desired position. + + // This is irritating because we don't know where the elements are in each row. + // What we will do is loop across all columns looking for elements in row 0 + // (and add to our sum), then in row 1, and so forth, until we get to the desired position. uword cur_pos = std::numeric_limits::max(); // Invalid value. uword cur_actual_pos = 0; - - for (uword row = 0; row < iterator_base::M->n_rows; ++row) + + for(uword row = 0; row < iterator_base::M->n_rows; ++row) { - for (uword col = 0; col < iterator_base::M->n_cols; ++col) + for(uword col = 0; col < iterator_base::M->n_cols; ++col) { // Find the first element with row greater than or equal to in_row. const uword col_offset = iterator_base::M->col_ptrs[col ]; const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - - if (start_ptr != end_ptr) + + if(start_ptr != end_ptr) { const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, row); - - // This is the number of elements in the column with row index less than - // in_row. + + // This is the number of elements in the column with row index less than in_row. const uword offset = uword(pos_ptr - start_ptr); - - if (iterator_base::M->row_indices[col_offset + offset] == row) + + if(iterator_base::M->row_indices[col_offset + offset] == row) { cur_actual_pos = col_offset + offset; - + // Increment position portably. - if (cur_pos == std::numeric_limits::max()) - cur_pos = 0; + if(cur_pos == std::numeric_limits::max()) + { cur_pos = 0; } else - ++cur_pos; - + { ++cur_pos; } + // Do we terminate? - if (cur_pos == initial_pos) + if(cur_pos == initial_pos) { internal_row = row; iterator_base::internal_col = col; iterator_base::internal_pos = cur_pos; actual_pos = cur_actual_pos; - + return; } } } } } - - // If we got to here, then we have gone past the end of the matrix. This - // shouldn't happen... + + // If we got to here, then we have gone past the end of the matrix. + // This shouldn't happen... iterator_base::internal_pos = iterator_base::M->n_nonzero; iterator_base::internal_col = 0; internal_row = iterator_base::M->n_rows; @@ -497,41 +483,40 @@ SpMat::const_row_iterator::const_row_iterator(const SpMat& in_M, uword i // // We'll find these simultaneously, though we will have to loop over all // columns. - + // This will hold the total number of points with rows less than in_row. uword cur_pos = 0; uword cur_min_row = iterator_base::M->n_rows; uword cur_min_col = 0; uword cur_actual_pos = 0; - - for (uword col = 0; col < iterator_base::M->n_cols; ++col) + + for(uword col = 0; col < iterator_base::M->n_cols; ++col) { // Find the first element with row greater than or equal to in_row. const uword col_offset = iterator_base::M->col_ptrs[col ]; const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, in_row); - - // This is the number of elements in the column with row index less than - // in_row. + + // This is the number of elements in the column with row index less than in_row. const uword offset = uword(pos_ptr - start_ptr); - + cur_pos += offset; - - if (pos_ptr != end_ptr) + + if(pos_ptr != end_ptr) { // This is the row index of the first element in the column with row index // greater than or equal to in_row. - if ((*pos_ptr) < cur_min_row) + if((*pos_ptr) < cur_min_row) { - // If we are in the desired row but before the desired column, we - // can't take this. - if (col >= in_col) + // If we are in the desired row but before the desired column, + // we can't take this. + if(col >= in_col) { cur_min_row = (*pos_ptr); cur_min_col = col; @@ -541,7 +526,7 @@ SpMat::const_row_iterator::const_row_iterator(const SpMat& in_M, uword i } } } - + // Now we know what the minimum row is. internal_row = cur_min_row; iterator_base::internal_col = cur_min_col; @@ -571,54 +556,53 @@ SpMat::const_row_iterator::const_row_iterator(const typename SpMat::cons */ template inline -arma_hot typename SpMat::const_row_iterator& SpMat::const_row_iterator::operator++() { // We just need to find the next nonzero element. iterator_base::internal_pos++; - + if(iterator_base::internal_pos == iterator_base::M->n_nonzero) { internal_row = iterator_base::M->n_rows; iterator_base::internal_col = 0; - + return *this; } - + // Otherwise, we need to search. We can start in the next column and use // lower_bound() to find the next element. uword next_min_row = iterator_base::M->n_rows; uword next_min_col = iterator_base::M->n_cols; uword next_actual_pos = 0; - + // Search from the current column to the end of the matrix. - for (uword col = iterator_base::internal_col + 1; col < iterator_base::M->n_cols; ++col) + for(uword col = iterator_base::internal_col + 1; col < iterator_base::M->n_cols; ++col) { // Find the first element with row greater than or equal to in_row. const uword col_offset = iterator_base::M->col_ptrs[col ]; const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - - if (start_ptr != end_ptr) + + if(start_ptr != end_ptr) { // Find the first element in the column with row greater than or equal to // the current row. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row); - - if (pos_ptr != end_ptr) + + if(pos_ptr != end_ptr) { // We found something in the column, but is the row index correct? - if ((*pos_ptr) == internal_row) + if((*pos_ptr) == internal_row) { // Exact match---so we are done. iterator_base::internal_col = col; actual_pos = col_offset + (pos_ptr - start_ptr); return *this; } - else if ((*pos_ptr) < next_min_row) + else if((*pos_ptr) < next_min_row) { // The first element in this column is in a subsequent row, but it's // the minimum row we've seen so far. @@ -626,7 +610,7 @@ SpMat::const_row_iterator::operator++() next_min_col = col; next_actual_pos = col_offset + (pos_ptr - start_ptr); } - else if ((*pos_ptr) == next_min_row && col < next_min_col) + else if((*pos_ptr) == next_min_row && col < next_min_col) { // The first element in this column is in a subsequent row that we // already have another element for, but the column index is less so @@ -637,25 +621,25 @@ SpMat::const_row_iterator::operator++() } } } - + // Restart the search in the next row. - for (uword col = 0; col <= iterator_base::internal_col; ++col) + for(uword col = 0; col <= iterator_base::internal_col; ++col) { // Find the first element with row greater than or equal to in_row + 1. const uword col_offset = iterator_base::M->col_ptrs[col ]; const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - - if (start_ptr != end_ptr) + + if(start_ptr != end_ptr) { const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + 1); - - if (pos_ptr != end_ptr) + + if(pos_ptr != end_ptr) { // We found something in the column, but is the row index correct? - if ((*pos_ptr) == internal_row + 1) + if((*pos_ptr) == internal_row + 1) { // Exact match---so we are done. iterator_base::internal_col = col; @@ -663,15 +647,15 @@ SpMat::const_row_iterator::operator++() actual_pos = col_offset + (pos_ptr - start_ptr); return *this; } - else if ((*pos_ptr) < next_min_row) + else if((*pos_ptr) < next_min_row) { - // The first element in this column is in a subsequent row, but it's - // the minimum row we've seen so far. + // The first element in this column is in a subsequent row, + // but it's the minimum row we've seen so far. next_min_row = (*pos_ptr); next_min_col = col; next_actual_pos = col_offset + (pos_ptr - start_ptr); } - else if ((*pos_ptr) == next_min_row && col < next_min_col) + else if((*pos_ptr) == next_min_row && col < next_min_col) { // The first element in this column is in a subsequent row that we // already have another element for, but the column index is less so @@ -682,11 +666,11 @@ SpMat::const_row_iterator::operator++() } } } - + iterator_base::internal_col = next_min_col; internal_row = next_min_row; actual_pos = next_actual_pos; - + return *this; // Now we are done. } @@ -697,14 +681,13 @@ SpMat::const_row_iterator::operator++() */ template inline -arma_warn_unused typename SpMat::const_row_iterator SpMat::const_row_iterator::operator++(int) { typename SpMat::const_row_iterator tmp(*this); - + ++(*this); - + return tmp; } @@ -715,49 +698,48 @@ SpMat::const_row_iterator::operator++(int) */ template inline -arma_hot typename SpMat::const_row_iterator& SpMat::const_row_iterator::operator--() { - if (iterator_base::internal_pos == 0) + if(iterator_base::internal_pos == 0) { // Do nothing; we are already at the beginning. return *this; } - + iterator_base::internal_pos--; - + // We have to search backwards. We'll do this by going backwards over columns // and seeing if we find an element in the same row. uword max_row = 0; uword max_col = 0; uword next_actual_pos = 0; - - //for (uword col = iterator_base::internal_col; col > 1; --col) - for (uword col = iterator_base::internal_col; col >= 1; --col) + + //for(uword col = iterator_base::internal_col; col > 1; --col) + for(uword col = iterator_base::internal_col; col >= 1; --col) { // Find the first element with row greater than or equal to in_row + 1. const uword col_offset = iterator_base::M->col_ptrs[col - 1]; const uword next_col_offset = iterator_base::M->col_ptrs[col ]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - - if (start_ptr != end_ptr) + + if(start_ptr != end_ptr) { // There are elements in this column. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + 1); - - if (pos_ptr != start_ptr) + + if(pos_ptr != start_ptr) { // The element before pos_ptr is the one we are interested in. - if (*(pos_ptr - 1) > max_row) + if(*(pos_ptr - 1) > max_row) { max_row = *(pos_ptr - 1); max_col = col - 1; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); } - else if (*(pos_ptr - 1) == max_row && (col - 1) > max_col) + else if(*(pos_ptr - 1) == max_row && (col - 1) > max_col) { max_col = col - 1; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); @@ -765,49 +747,49 @@ SpMat::const_row_iterator::operator--() } } } - + // Now loop around to the columns at the end of the matrix. - for (uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) + for(uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) { // Find the first element with row greater than or equal to in_row + 1. const uword col_offset = iterator_base::M->col_ptrs[col ]; const uword next_col_offset = iterator_base::M->col_ptrs[col + 1]; - + const uword* start_ptr = &iterator_base::M->row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->row_indices[next_col_offset]; - - if (start_ptr != end_ptr) + + if(start_ptr != end_ptr) { // There are elements in this column. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row); - - if (pos_ptr != start_ptr) + + if(pos_ptr != start_ptr) { // There are elements in this column with row index < internal_row. - if (*(pos_ptr - 1) > max_row) + if(*(pos_ptr - 1) > max_row) { max_row = *(pos_ptr - 1); max_col = col; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); } - else if (*(pos_ptr - 1) == max_row && col > max_col) + else if(*(pos_ptr - 1) == max_row && col > max_col) { max_col = col; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); } } } - - if (col == 0) // Catch edge case that the loop termination condition won't. + + if(col == 0) // Catch edge case that the loop termination condition won't. { break; } } - + iterator_base::internal_col = max_col; internal_row = max_row; actual_pos = next_actual_pos; - + return *this; } @@ -818,14 +800,13 @@ SpMat::const_row_iterator::operator--() */ template inline -arma_warn_unused typename SpMat::const_row_iterator SpMat::const_row_iterator::operator--(int) { typename SpMat::const_row_iterator tmp(*this); - + --(*this); - + return tmp; } @@ -833,7 +814,6 @@ SpMat::const_row_iterator::operator--(int) template inline -arma_hot bool SpMat::const_row_iterator::operator==(const const_iterator& rhs) const { @@ -844,7 +824,6 @@ SpMat::const_row_iterator::operator==(const const_iterator& rhs) const template inline -arma_hot bool SpMat::const_row_iterator::operator!=(const const_iterator& rhs) const { @@ -855,7 +834,6 @@ SpMat::const_row_iterator::operator!=(const const_iterator& rhs) const template inline -arma_hot bool SpMat::const_row_iterator::operator==(const typename SpSubview::const_iterator& rhs) const { @@ -866,7 +844,6 @@ SpMat::const_row_iterator::operator==(const typename SpSubview::const_it template inline -arma_hot bool SpMat::const_row_iterator::operator!=(const typename SpSubview::const_iterator& rhs) const { @@ -877,7 +854,6 @@ SpMat::const_row_iterator::operator!=(const typename SpSubview::const_it template inline -arma_hot bool SpMat::const_row_iterator::operator==(const const_row_iterator& rhs) const { @@ -888,7 +864,6 @@ SpMat::const_row_iterator::operator==(const const_row_iterator& rhs) const template inline -arma_hot bool SpMat::const_row_iterator::operator!=(const const_row_iterator& rhs) const { @@ -899,7 +874,6 @@ SpMat::const_row_iterator::operator!=(const const_row_iterator& rhs) const template inline -arma_hot bool SpMat::const_row_iterator::operator==(const typename SpSubview::const_row_iterator& rhs) const { @@ -910,7 +884,6 @@ SpMat::const_row_iterator::operator==(const typename SpSubview::const_ro template inline -arma_hot bool SpMat::const_row_iterator::operator!=(const typename SpSubview::const_row_iterator& rhs) const { @@ -925,11 +898,10 @@ SpMat::const_row_iterator::operator!=(const typename SpSubview::const_ro template inline -arma_hot -SpValProxy > +SpValProxy< SpMat > SpMat::row_iterator::operator*() { - return SpValProxy >( + return SpValProxy< SpMat >( const_row_iterator::internal_row, iterator_base::internal_col, access::rw(*iterator_base::M), @@ -940,11 +912,11 @@ SpMat::row_iterator::operator*() template inline -arma_hot typename SpMat::row_iterator& SpMat::row_iterator::operator++() { const_row_iterator::operator++(); + return *this; } @@ -952,14 +924,13 @@ SpMat::row_iterator::operator++() template inline -arma_warn_unused typename SpMat::row_iterator SpMat::row_iterator::operator++(int) { typename SpMat::row_iterator tmp(*this); - + const_row_iterator::operator++(); - + return tmp; } @@ -967,11 +938,11 @@ SpMat::row_iterator::operator++(int) template inline -arma_hot typename SpMat::row_iterator& SpMat::row_iterator::operator--() { const_row_iterator::operator--(); + return *this; } @@ -979,15 +950,15 @@ SpMat::row_iterator::operator--() template inline -arma_warn_unused typename SpMat::row_iterator SpMat::row_iterator::operator--(int) { typename SpMat::row_iterator tmp(*this); - + const_row_iterator::operator--(); - + return tmp; } + //! @} diff --git a/src/armadillo_bits/SpMat_meat.hpp b/src/armadillo_bits/SpMat_meat.hpp index 5b1fdb1e..b8d51cf6 100644 --- a/src/armadillo_bits/SpMat_meat.hpp +++ b/src/armadillo_bits/SpMat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,9 +31,9 @@ SpMat::SpMat() , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -67,12 +69,12 @@ SpMat::SpMat(const uword in_rows, const uword in_cols) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + init_cold(in_rows, in_cols); } @@ -86,9 +88,9 @@ SpMat::SpMat(const SizeMat& s) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -105,9 +107,9 @@ SpMat::SpMat(const arma_reserve_indicator&, const uword in_rows, const uword , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -125,9 +127,9 @@ SpMat::SpMat(const arma_layout_indicator&, const SpMat& x) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -154,12 +156,12 @@ SpMat::SpMat(const char* text) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + init(std::string(text)); } @@ -171,7 +173,7 @@ SpMat& SpMat::operator=(const char* text) { arma_extra_debug_sigprint(); - + init(std::string(text)); return *this; @@ -187,12 +189,12 @@ SpMat::SpMat(const std::string& text) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint(); - + init(text); } @@ -220,52 +222,48 @@ SpMat::SpMat(const SpMat& x) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + init(x); } -#if defined(ARMA_USE_CXX11) - - template - inline - SpMat::SpMat(SpMat&& in_mat) - : n_rows(0) - , n_cols(0) - , n_elem(0) - , n_nonzero(0) - , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) - { - arma_extra_debug_sigprint_this(this); - arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); - - (*this).steal_mem(in_mat); - } - +template +inline +SpMat::SpMat(SpMat&& in_mat) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); + (*this).steal_mem(in_mat); + } + + + +template +inline +SpMat& +SpMat::operator=(SpMat&& in_mat) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); - template - inline - SpMat& - SpMat::operator=(SpMat&& in_mat) - { - arma_extra_debug_sigprint(arma_str::format("this = %x in_mat = %x") % this % &in_mat); - - (*this).steal_mem(in_mat); - - return *this; - } + (*this).steal_mem(in_mat); -#endif + return *this; + } @@ -277,9 +275,9 @@ SpMat::SpMat(const MapMat& x) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -316,19 +314,19 @@ SpMat::SpMat(const Base& locations_expr, const Base& vals_e , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - const unwrap locs_tmp( locations_expr.get_ref() ); - const unwrap vals_tmp( vals_expr.get_ref() ); + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); const Mat& locs = locs_tmp.M; const Mat& vals = vals_tmp.M; - arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); @@ -343,15 +341,12 @@ SpMat::SpMat(const Base& locations_expr, const Base& vals_e const uword N_old = vals.n_elem; uword N_new = 0; - for(uword i = 0; i < N_old; ++i) - { - if(vals[i] != eT(0)) { ++N_new; } - } + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } if(N_new != N_old) { - Col filtered_vals(N_new); - Mat filtered_locs(2, N_new); + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); uword index = 0; for(uword i = 0; i < N_old; ++i) @@ -392,39 +387,36 @@ SpMat::SpMat(const Base& locations_expr, const Base& vals_e , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - const unwrap locs_tmp( locations_expr.get_ref() ); - const unwrap vals_tmp( vals_expr.get_ref() ); + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); const Mat& locs = locs_tmp.M; const Mat& vals = vals_tmp.M; - arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); init_cold(in_n_rows, in_n_cols); - + // Ensure that there are no zeros, unless the user asked not to. if(check_for_zeros) { const uword N_old = vals.n_elem; uword N_new = 0; - for(uword i = 0; i < N_old; ++i) - { - if(vals[i] != eT(0)) { ++N_new; } - } + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } if(N_new != N_old) { - Col filtered_vals(N_new); - Mat filtered_locs(2, N_new); + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); uword index = 0; for(uword i = 0; i < N_old; ++i) @@ -464,39 +456,36 @@ SpMat::SpMat(const bool add_values, const Base& locations_expr, co , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - const unwrap locs_tmp( locations_expr.get_ref() ); - const unwrap vals_tmp( vals_expr.get_ref() ); + const quasi_unwrap locs_tmp( locations_expr.get_ref() ); + const quasi_unwrap vals_tmp( vals_expr.get_ref() ); const Mat& locs = locs_tmp.M; const Mat& vals = vals_tmp.M; - arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); arma_debug_check( (locs.n_rows != 2), "SpMat::SpMat(): locations matrix must have two rows" ); arma_debug_check( (locs.n_cols != vals.n_elem), "SpMat::SpMat(): number of locations is different than number of values" ); init_cold(in_n_rows, in_n_cols); - + // Ensure that there are no zeros, unless the user asked not to. if(check_for_zeros) { const uword N_old = vals.n_elem; uword N_new = 0; - for(uword i = 0; i < N_old; ++i) - { - if(vals[i] != eT(0)) { ++N_new; } - } + for(uword i=0; i < N_old; ++i) { N_new += (vals[i] != eT(0)) ? uword(1) : uword(0); } if(N_new != N_old) { - Col filtered_vals(N_new); - Mat filtered_locs(2, N_new); + Col filtered_vals( N_new, arma_nozeros_indicator()); + Mat filtered_locs(2, N_new, arma_nozeros_indicator()); uword index = 0; for(uword i = 0; i < N_old; ++i) @@ -543,30 +532,31 @@ SpMat::SpMat const Base& colptr_expr, const Base& values_expr, const uword in_n_rows, - const uword in_n_cols + const uword in_n_cols, + const bool check_for_zeros ) : n_rows(0) , n_cols(0) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - const unwrap rowind_tmp( rowind_expr.get_ref() ); - const unwrap colptr_tmp( colptr_expr.get_ref() ); - const unwrap vals_tmp( values_expr.get_ref() ); + const quasi_unwrap rowind_tmp( rowind_expr.get_ref() ); + const quasi_unwrap colptr_tmp( colptr_expr.get_ref() ); + const quasi_unwrap vals_tmp( values_expr.get_ref() ); const Mat& rowind = rowind_tmp.M; const Mat& colptr = colptr_tmp.M; const Mat& vals = vals_tmp.M; - arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object is not a vector" ); - arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object is not a vector" ); - arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object is not a vector" ); + arma_debug_check( (rowind.is_vec() == false), "SpMat::SpMat(): given 'rowind' object must be a vector" ); + arma_debug_check( (colptr.is_vec() == false), "SpMat::SpMat(): given 'colptr' object must be a vector" ); + arma_debug_check( (vals.is_vec() == false), "SpMat::SpMat(): given 'values' object must be a vector" ); // Resize to correct number of elements (this also sets n_nonzero) init_cold(in_n_rows, in_n_cols, vals.n_elem); @@ -583,7 +573,7 @@ SpMat::SpMat access::rw(col_ptrs[n_cols + 1]) = std::numeric_limits::max(); // make sure no zeros are stored - remove_zeros(); + if(check_for_zeros) { remove_zeros(); } } @@ -601,9 +591,9 @@ SpMat::operator=(const eT val) init(1, 1, 1); // Sets col_ptrs to 0. // Manually set element. - access::rw(values[0]) = val; + access::rw(values[0]) = val; access::rw(row_indices[0]) = 0; - access::rw(col_ptrs[1]) = 1; + access::rw(col_ptrs[1]) = 1; } else { @@ -782,16 +772,14 @@ SpMat::operator/=(const SpMat& x) { arma_extra_debug_sigprint(); + // NOTE: use of this function is not advised; it is implemented only for completeness + arma_debug_assert_same_size(n_rows, n_cols, x.n_rows, x.n_cols, "element-wise division"); - // If you use this method, you are probably stupid or misguided, - // but for compatibility with Mat, we have implemented it anyway. for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) { - for(uword r = 0; r < n_rows; ++r) - { - at(r, c) /= x.at(r, c); - } + at(r, c) /= x.at(r, c); } return *this; @@ -808,17 +796,17 @@ SpMat::SpMat(const SpToDOp& expr) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + typedef typename T1::elem_type T; - + // Make sure the type is compatible. arma_type_check(( is_same_type< eT, T >::no )); - + op_type::apply(*this, expr); } @@ -838,9 +826,9 @@ SpMat::SpMat , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint(); @@ -876,7 +864,7 @@ SpMat::SpMat uword cur_pos = 0; - while ((x_it != x_end) || (y_it != y_end)) + while((x_it != x_end) || (y_it != y_end)) { if(x_it == y_it) // if we are at the same place { @@ -911,7 +899,7 @@ SpMat::SpMat } // Now fix the column pointers; they are supposed to be a sum. - for (uword c = 1; c <= n_cols; ++c) + for(uword c = 1; c <= n_cols; ++c) { access::rw(col_ptrs[c]) += col_ptrs[c - 1]; } @@ -929,12 +917,12 @@ SpMat::SpMat(const Base& x) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + (*this).operator=(x); } @@ -978,10 +966,7 @@ SpMat::operator=(const Base& expr) const eT* x_mem = x.memptr(); - for(uword i = 0; i < x_n_elem; ++i) - { - n += (x_mem[i] != eT(0)) ? uword(1) : uword(0); - } + for(uword i=0; i < x_n_elem; ++i) { n += (x_mem[i] != eT(0)) ? uword(1) : uword(0); } init(x_n_rows, x_n_cols, n); @@ -1048,101 +1033,18 @@ template template inline SpMat& -SpMat::operator*=(const Base& y) +SpMat::operator*=(const Base& x) { arma_extra_debug_sigprint(); - - sync_csc(); - - const Proxy p(y.get_ref()); - - arma_debug_assert_mul_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "matrix multiplication"); - - // We assume the matrix structure is such that we will end up with a sparse - // matrix. Assuming that every entry in the dense matrix is nonzero (which is - // a fairly valid assumption), each row with any nonzero elements in it (in this - // matrix) implies an entire nonzero column. Therefore, we iterate over all - // the row_indices and count the number of rows with any elements in them - // (using the quasi-linked-list idea from SYMBMM -- see spglue_times_meat.hpp). - podarray index(n_rows); - index.fill(n_rows); // Fill with invalid links. - - uword last_index = n_rows + 1; - for(uword i = 0; i < n_nonzero; ++i) - { - if(index[row_indices[i]] == n_rows) - { - index[row_indices[i]] = last_index; - last_index = row_indices[i]; - } - } - - // Now count the number of rows which have nonzero elements. - uword nonzero_rows = 0; - while(last_index != n_rows + 1) - { - ++nonzero_rows; - last_index = index[last_index]; - } - - SpMat z(arma_reserve_indicator(), n_rows, p.get_n_cols(), (nonzero_rows * p.get_n_cols())); // upper bound on size - - // Now we have to fill all the elements using a modification of the NUMBMM algorithm. - uword cur_pos = 0; - - podarray partial_sums(n_rows); - partial_sums.zeros(); - - for(uword lcol = 0; lcol < n_cols; ++lcol) - { - const_iterator it = begin(); - const_iterator it_end = end(); - - while(it != it_end) - { - const eT value = (*it); - - partial_sums[it.row()] += (value * p.at(it.col(), lcol)); - - ++it; - } - - // Now add all partial sums to the matrix. - for(uword i = 0; i < n_rows; ++i) - { - if(partial_sums[i] != eT(0)) - { - access::rw(z.values[cur_pos]) = partial_sums[i]; - access::rw(z.row_indices[cur_pos]) = i; - ++access::rw(z.col_ptrs[lcol + 1]); - //printf("colptr %d now %d\n", lcol + 1, z.col_ptrs[lcol + 1]); - ++cur_pos; - partial_sums[i] = 0; // Would it be faster to do this in batch later? - } - } - } - - // Now fix the column pointers. - for(uword c = 1; c <= z.n_cols; ++c) - { - access::rw(z.col_ptrs[c]) += z.col_ptrs[c - 1]; - } - - // Resize to final correct size. - z.mem_resize(z.col_ptrs[z.n_cols]); - // Now take the memory of the temporary matrix. - steal_mem(z); + sync_csc(); - return *this; + return (*this).operator=( (*this) * x.get_ref() ); } -/** - * Don't use this function. It's not mathematically well-defined and wastes - * cycles to trash all your data. This is dumb. - */ +// NOTE: use of this function is not advised; it is implemented only for completeness template template inline @@ -1150,7 +1052,7 @@ SpMat& SpMat::operator/=(const Base& x) { arma_extra_debug_sigprint(); - + sync_csc(); SpMat tmp = (*this) / x.get_ref(); @@ -1169,59 +1071,40 @@ SpMat& SpMat::operator%=(const Base& x) { arma_extra_debug_sigprint(); - - sync_csc(); - const Proxy p(x.get_ref()); + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; - arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise multiplication"); + arma_debug_assert_same_size(n_rows, n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); - // Count the number of elements we will need. - const_iterator it = begin(); - const_iterator it_end = end(); + sync_csc(); + invalidate_cache(); - uword new_n_nonzero = 0; + constexpr eT zero = eT(0); - while(it != it_end) + bool has_zero = false; + + for(uword c=0; c < n_cols; ++c) { - // use_at == false can't save us any work here - if(((*it) * p.at(it.row(), it.col())) != eT(0)) + const uword index_start = col_ptrs[c ]; + const uword index_end = col_ptrs[c + 1]; + + for(uword i=index_start; i < index_end; ++i) { - ++new_n_nonzero; + const uword r = row_indices[i]; + + eT& val = access::rw(values[i]); + + const eT result = val * B.at(r,c); + + val = result; + + if(result == zero) { has_zero = true; } } - ++it; } - SpMat tmp(arma_reserve_indicator(), n_rows, n_cols, new_n_nonzero); - - const_iterator c_it = begin(); - const_iterator c_it_end = end(); - - uword cur_pos = 0; + if(has_zero) { remove_zeros(); } - while(c_it != c_it_end) - { - // use_at == false can't save us any work here - const eT val = (*c_it) * p.at(c_it.row(), c_it.col()); - if(val != eT(0)) - { - access::rw(tmp.values[cur_pos]) = val; - access::rw(tmp.row_indices[cur_pos]) = c_it.row(); - ++access::rw(tmp.col_ptrs[c_it.col() + 1]); - ++cur_pos; - } - - ++c_it; - } - - // Fix column pointers. - for(uword c = 1; c <= n_cols; ++c) - { - access::rw(tmp.col_ptrs[c]) += tmp.col_ptrs[c - 1]; - } - - steal_mem(tmp); - return *this; } @@ -1236,9 +1119,9 @@ SpMat::SpMat(const Op& expr) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -1255,51 +1138,38 @@ SpMat::operator=(const Op& expr) { arma_extra_debug_sigprint(); - const Proxy P(expr.m); + const diagmat_proxy P(expr.m); + + const uword max_n_nonzero = (std::min)(P.n_rows, P.n_cols); - const uword P_n_rows = P.get_n_rows(); - const uword P_n_cols = P.get_n_cols(); + // resize memory to upper bound + init(P.n_rows, P.n_cols, max_n_nonzero); - const bool P_is_vec = (P_n_rows == 1) || (P_n_cols == 1); + uword count = 0; - if(P_is_vec) // generate diagonal sparse matrix from dense vector + for(uword i=0; i < max_n_nonzero; ++i) { - const uword N = (P_n_rows == 1) ? P_n_cols : P_n_rows; - - (*this).eye(N,N); - - eT* this_values = access::rwp(values); + const eT val = P[i]; - if(Proxy::use_at == false) - { - typename Proxy::ea_type P_ea = P.get_ea(); - - for(uword i=0; i < N; ++i) { this_values[i] = P_ea[i]; } - } - else + if(val != eT(0)) { - if(P_n_rows == 1) - { - for(uword i=0; i < N; ++i) { this_values[i] = P.at(0,i); } - } - else - { - for(uword i=0; i < N; ++i) { this_values[i] = P.at(i,0); } - } + access::rw(values[count]) = val; + access::rw(row_indices[count]) = i; + access::rw(col_ptrs[i + 1])++; + ++count; } } - else // generate diagonal sparse matrix from dense matrix + + // fix column pointers to be cumulative + for(uword i = 1; i < n_cols + 1; ++i) { - (*this).eye(P_n_rows, P_n_cols); - - eT* this_values = access::rwp(values); - - const uword N = (std::min)(P_n_rows, P_n_cols); - - for(uword i=0; i < N; ++i) { this_values[i] = P.at(i,i); } + access::rw(col_ptrs[i]) += col_ptrs[i - 1]; } - remove_zeros(); + // quick resize without reallocating memory and copying data + access::rw( n_nonzero) = count; + access::rw( values[count]) = eT(0); + access::rw(row_indices[count]) = uword(0); return *this; } @@ -1392,12 +1262,12 @@ SpMat::SpMat(const SpSubview& X) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - + (*this).operator=(X); } @@ -1415,35 +1285,62 @@ SpMat::operator=(const SpSubview& X) X.m.sync_csc(); const bool alias = (this == &(X.m)); - - if(alias == false) + + if(alias) + { + SpMat tmp(X); + + steal_mem(tmp); + } + else { init(X.n_rows, X.n_cols, X.n_nonzero); - - typename SpSubview::const_iterator it = X.begin(); - typename SpSubview::const_iterator it_end = X.end(); - - while(it != it_end) + + if(X.n_rows == X.m.n_rows) { - access::rw(row_indices[it.pos()]) = it.row(); - access::rw(values[it.pos()]) = (*it); - ++access::rw(col_ptrs[it.col() + 1]); - ++it; + const uword sv_col_start = X.aux_col1; + const uword sv_col_end = X.aux_col1 + X.n_cols - 1; + + typename SpMat::const_col_iterator m_it = X.m.begin_col_no_sync(sv_col_start); + typename SpMat::const_col_iterator m_it_end = X.m.end_col_no_sync(sv_col_end); + + uword count = 0; + + while(m_it != m_it_end) + { + const uword m_it_col_adjusted = m_it.col() - sv_col_start; + + access::rw(row_indices[count]) = m_it.row(); + access::rw(values[count]) = (*m_it); + ++access::rw(col_ptrs[m_it_col_adjusted + 1]); + + count++; + + ++m_it; + } } - + else + { + typename SpSubview::const_iterator it = X.begin(); + typename SpSubview::const_iterator it_end = X.end(); + + while(it != it_end) + { + const uword it_pos = it.pos(); + + access::rw(row_indices[it_pos]) = it.row(); + access::rw(values[it_pos]) = (*it); + ++access::rw(col_ptrs[it.col() + 1]); + ++it; + } + } + // Now sum column pointers. for(uword c = 1; c <= n_cols; ++c) { access::rw(col_ptrs[c]) += col_ptrs[c - 1]; } } - else - { - // Create it in a temporary. - SpMat tmp(X); - - steal_mem(tmp); - } return *this; } @@ -1543,32 +1440,45 @@ SpMat::operator/=(const SpSubview& x) template +template inline -SpMat::SpMat(const spdiagview& X) +SpMat::SpMat(const SpSubview_col_list& X) : n_rows(0) , n_cols(0) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); - - spdiagview::extract(*this, X); + + SpSubview_col_list::extract(*this, X); } template +template inline SpMat& -SpMat::operator=(const spdiagview& X) +SpMat::operator=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - spdiagview::extract(*this, X); + const bool alias = (this == &(X.m)); + + if(alias == false) + { + SpSubview_col_list::extract(*this, X); + } + else + { + SpMat tmp(X); + + steal_mem(tmp); + } return *this; } @@ -1576,87 +1486,199 @@ SpMat::operator=(const spdiagview& X) template +template inline SpMat& -SpMat::operator+=(const spdiagview& X) +SpMat::operator+=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - const SpMat tmp(X); + SpSubview_col_list::plus_inplace(*this, X); - return (*this).operator+=(tmp); + return *this; } template +template inline SpMat& -SpMat::operator-=(const spdiagview& X) +SpMat::operator-=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - const SpMat tmp(X); + SpSubview_col_list::minus_inplace(*this, X); - return (*this).operator-=(tmp); + return *this; } template +template inline SpMat& -SpMat::operator*=(const spdiagview& X) +SpMat::operator*=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - const SpMat tmp(X); + sync_csc(); - return (*this).operator*=(tmp); + SpMat z = (*this) * X; + + steal_mem(z); + + return *this; } template +template inline SpMat& -SpMat::operator%=(const spdiagview& X) +SpMat::operator%=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - const SpMat tmp(X); + SpSubview_col_list::schur_inplace(*this, X); - return (*this).operator%=(tmp); + return *this; } template +template inline SpMat& -SpMat::operator/=(const spdiagview& X) +SpMat::operator/=(const SpSubview_col_list& X) { arma_extra_debug_sigprint(); - const SpMat tmp(X); + SpSubview_col_list::div_inplace(*this, X); - return (*this).operator/=(tmp); + return *this; } template -template inline -SpMat::SpMat(const SpOp& X) +SpMat::SpMat(const spdiagview& X) : n_rows(0) , n_cols(0) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) // set in application of sparse operation - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) + { + arma_extra_debug_sigprint_this(this); + + spdiagview::extract(*this, X); + } + + + +template +inline +SpMat& +SpMat::operator=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + spdiagview::extract(*this, X); + + return *this; + } + + + +template +inline +SpMat& +SpMat::operator+=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator+=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator-=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator-=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator*=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator*=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator%=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator%=(tmp); + } + + + +template +inline +SpMat& +SpMat::operator/=(const spdiagview& X) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(X); + + return (*this).operator/=(tmp); + } + + + +template +template +inline +SpMat::SpMat(const SpOp& X) + : n_rows(0) + , n_cols(0) + , n_elem(0) + , n_nonzero(0) + , vec_state(0) + , values(nullptr) // set in application of sparse operation + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -1794,9 +1816,9 @@ SpMat::SpMat(const SpGlue& X) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -1934,9 +1956,9 @@ SpMat::SpMat(const mtSpOp& X) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -1960,7 +1982,7 @@ SpMat::operator=(const mtSpOp& X) sync_csc(); // in case apply() used element accessors invalidate_cache(); // in case apply() modified the CSC representation - + return *this; } @@ -1977,7 +1999,7 @@ SpMat::operator+=(const mtSpOp& X) sync_csc(); const SpMat m(X); - + return (*this).operator+=(m); } @@ -1994,7 +2016,7 @@ SpMat::operator-=(const mtSpOp& X) sync_csc(); const SpMat m(X); - + return (*this).operator-=(m); } @@ -2011,7 +2033,7 @@ SpMat::operator*=(const mtSpOp& X) sync_csc(); const SpMat m(X); - + return (*this).operator*=(m); } @@ -2028,7 +2050,7 @@ SpMat::operator%=(const mtSpOp& X) sync_csc(); const SpMat m(X); - + return (*this).operator%=(m); } @@ -2045,7 +2067,7 @@ SpMat::operator/=(const mtSpOp& X) sync_csc(); const SpMat m(X); - + return (*this).operator/=(m); } @@ -2060,9 +2082,9 @@ SpMat::SpMat(const mtSpGlue& X) , n_elem(0) , n_nonzero(0) , vec_state(0) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -2179,35 +2201,35 @@ SpMat::operator/=(const mtSpGlue& X) template arma_inline -SpSubview +SpSubview_row SpMat::row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds"); - - return SpSubview(*this, row_num, 0, 1, n_cols); + arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds"); + + return SpSubview_row(*this, row_num); } template arma_inline -const SpSubview +const SpSubview_row SpMat::row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check(row_num >= n_rows, "SpMat::row(): out of bounds"); - - return SpSubview(*this, row_num, 0, 1, n_cols); + arma_debug_check_bounds(row_num >= n_rows, "SpMat::row(): out of bounds"); + + return SpSubview_row(*this, row_num); } template inline -SpSubview +SpSubview_row SpMat::operator()(const uword row_num, const span& col_span) { arma_extra_debug_sigprint(); @@ -2220,7 +2242,7 @@ SpMat::operator()(const uword row_num, const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -2229,14 +2251,14 @@ SpMat::operator()(const uword row_num, const span& col_span) "SpMat::operator(): indices out of bounds or incorrectly used" ); - return SpSubview(*this, row_num, in_col1, 1, submat_n_cols); + return SpSubview_row(*this, row_num, in_col1, submat_n_cols); } template inline -const SpSubview +const SpSubview_row SpMat::operator()(const uword row_num, const span& col_span) const { arma_extra_debug_sigprint(); @@ -2249,7 +2271,7 @@ SpMat::operator()(const uword row_num, const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -2258,42 +2280,42 @@ SpMat::operator()(const uword row_num, const span& col_span) const "SpMat::operator(): indices out of bounds or incorrectly used" ); - return SpSubview(*this, row_num, in_col1, 1, submat_n_cols); + return SpSubview_row(*this, row_num, in_col1, submat_n_cols); } template arma_inline -SpSubview +SpSubview_col SpMat::col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds"); - - return SpSubview(*this, 0, col_num, n_rows, 1); + arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds"); + + return SpSubview_col(*this, col_num); } template arma_inline -const SpSubview +const SpSubview_col SpMat::col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check(col_num >= n_cols, "SpMat::col(): out of bounds"); - - return SpSubview(*this, 0, col_num, n_rows, 1); + arma_debug_check_bounds(col_num >= n_cols, "SpMat::col(): out of bounds"); + + return SpSubview_col(*this, col_num); } template inline -SpSubview +SpSubview_col SpMat::operator()(const span& row_span, const uword col_num) { arma_extra_debug_sigprint(); @@ -2306,7 +2328,7 @@ SpMat::operator()(const span& row_span, const uword col_num) const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -2315,14 +2337,14 @@ SpMat::operator()(const span& row_span, const uword col_num) "SpMat::operator(): indices out of bounds or incorrectly used" ); - return SpSubview(*this, in_row1, col_num, submat_n_rows, 1); + return SpSubview_col(*this, col_num, in_row1, submat_n_rows); } template inline -const SpSubview +const SpSubview_col SpMat::operator()(const span& row_span, const uword col_num) const { arma_extra_debug_sigprint(); @@ -2335,7 +2357,7 @@ SpMat::operator()(const span& row_span, const uword col_num) const const uword in_row2 = row_span.b; const uword submat_n_rows = row_all ? local_n_rows : in_row2 - in_row1 + 1; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -2344,7 +2366,7 @@ SpMat::operator()(const span& row_span, const uword col_num) const "SpMat::operator(): indices out of bounds or incorrectly used" ); - return SpSubview(*this, in_row1, col_num, submat_n_rows, 1); + return SpSubview_col(*this, col_num, in_row1, submat_n_rows); } @@ -2356,14 +2378,14 @@ SpMat::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "SpMat::rows(): indices out of bounds or incorrectly used" ); - + const uword subview_n_rows = in_row2 - in_row1 + 1; - + return SpSubview(*this, in_row1, 0, subview_n_rows, n_cols); } @@ -2376,14 +2398,14 @@ SpMat::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "SpMat::rows(): indices out of bounds or incorrectly used" ); - + const uword subview_n_rows = in_row2 - in_row1 + 1; - + return SpSubview(*this, in_row1, 0, subview_n_rows, n_cols); } @@ -2396,14 +2418,14 @@ SpMat::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "SpMat::cols(): indices out of bounds or incorrectly used" ); - + const uword subview_n_cols = in_col2 - in_col1 + 1; - + return SpSubview(*this, 0, in_col1, n_rows, subview_n_cols); } @@ -2416,14 +2438,14 @@ SpMat::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "SpMat::cols(): indices out of bounds or incorrectly used" ); - + const uword subview_n_cols = in_col2 - in_col1 + 1; - + return SpSubview(*this, 0, in_col1, n_rows, subview_n_cols); } @@ -2436,15 +2458,15 @@ SpMat::submat(const uword in_row1, const uword in_col1, const uword in_row2, { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "SpMat::submat(): indices out of bounds or incorrectly used" ); - + const uword subview_n_rows = in_row2 - in_row1 + 1; const uword subview_n_cols = in_col2 - in_col1 + 1; - + return SpSubview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); } @@ -2457,15 +2479,15 @@ SpMat::submat(const uword in_row1, const uword in_col1, const uword in_row2, { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "SpMat::submat(): indices out of bounds or incorrectly used" ); - + const uword subview_n_rows = in_row2 - in_row1 + 1; const uword subview_n_cols = in_col2 - in_col1 + 1; - + return SpSubview(*this, in_row1, in_col1, subview_n_rows, subview_n_cols); } @@ -2484,7 +2506,7 @@ SpMat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "SpMat::submat(): indices or size out of bounds" @@ -2508,7 +2530,7 @@ SpMat::submat(const uword in_row1, const uword in_col1, const SizeMat& s) co const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "SpMat::submat(): indices or size out of bounds" @@ -2540,7 +2562,7 @@ SpMat::submat(const span& row_span, const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -2575,7 +2597,7 @@ SpMat::submat(const span& row_span, const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -2644,7 +2666,7 @@ SpMat::head_rows(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "SpMat::head_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" ); return SpSubview(*this, 0, 0, N, n_cols); } @@ -2658,7 +2680,7 @@ SpMat::head_rows(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "SpMat::head_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "SpMat::head_rows(): size out of bounds" ); return SpSubview(*this, 0, 0, N, n_cols); } @@ -2672,7 +2694,7 @@ SpMat::tail_rows(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "SpMat::tail_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" ); const uword start_row = n_rows - N; @@ -2688,7 +2710,7 @@ SpMat::tail_rows(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_rows), "SpMat::tail_rows(): size out of bounds"); + arma_debug_check_bounds( (N > n_rows), "SpMat::tail_rows(): size out of bounds" ); const uword start_row = n_rows - N; @@ -2704,7 +2726,7 @@ SpMat::head_cols(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "SpMat::head_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" ); return SpSubview(*this, 0, 0, n_rows, N); } @@ -2718,7 +2740,7 @@ SpMat::head_cols(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "SpMat::head_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "SpMat::head_cols(): size out of bounds" ); return SpSubview(*this, 0, 0, n_rows, N); } @@ -2732,7 +2754,7 @@ SpMat::tail_cols(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "SpMat::tail_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" ); const uword start_col = n_cols - N; @@ -2748,7 +2770,7 @@ SpMat::tail_cols(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > n_cols), "SpMat::tail_cols(): size out of bounds"); + arma_debug_check_bounds( (N > n_cols), "SpMat::tail_cols(): size out of bounds" ); const uword start_col = n_cols - N; @@ -2757,6 +2779,32 @@ SpMat::tail_cols(const uword N) const +template +template +arma_inline +SpSubview_col_list +SpMat::cols(const Base& indices) + { + arma_extra_debug_sigprint(); + + return SpSubview_col_list(*this, indices); + } + + + +template +template +arma_inline +const SpSubview_col_list +SpMat::cols(const Base& indices) const + { + arma_extra_debug_sigprint(); + + return SpSubview_col_list(*this, indices); + } + + + //! creation of spdiagview (diagonal) template inline @@ -2768,7 +2816,7 @@ SpMat::diag(const sword in_id) const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; const uword col_offset = (in_id > 0) ? uword( in_id) : 0; - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "SpMat::diag(): requested diagonal out of bounds" @@ -2792,7 +2840,7 @@ SpMat::diag(const sword in_id) const const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "SpMat::diag(): requested diagonal out of bounds" @@ -2812,7 +2860,7 @@ SpMat::swap_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check( ((in_row1 >= n_rows) || (in_row2 >= n_rows)), "SpMat::swap_rows(): out of bounds" ); + arma_debug_check_bounds( ((in_row1 >= n_rows) || (in_row2 >= n_rows)), "SpMat::swap_rows(): out of bounds" ); if(in_row1 == in_row2) { return; } @@ -2824,85 +2872,85 @@ SpMat::swap_rows(const uword in_row1, const uword in_row2) // We will try to avoid using the at() call since it is expensive, instead preferring to use an iterator to track our position. uword col1 = (in_row1 < in_row2) ? in_row1 : in_row2; uword col2 = (in_row1 < in_row2) ? in_row2 : in_row1; - - for (uword lcol = 0; lcol < n_cols; lcol++) + + for(uword lcol = 0; lcol < n_cols; lcol++) { // If there is nothing in this column we can ignore it. - if (col_ptrs[lcol] == col_ptrs[lcol + 1]) + if(col_ptrs[lcol] == col_ptrs[lcol + 1]) { continue; } - + // These will represent the positions of the items themselves. uword loc1 = n_nonzero + 1; uword loc2 = n_nonzero + 1; - - for (uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++) + + for(uword search_pos = col_ptrs[lcol]; search_pos < col_ptrs[lcol + 1]; search_pos++) { - if (row_indices[search_pos] == col1) + if(row_indices[search_pos] == col1) { loc1 = search_pos; } - - if (row_indices[search_pos] == col2) + + if(row_indices[search_pos] == col2) { loc2 = search_pos; break; // No need to look any further. } } - + // There are four cases: we found both elements; we found one element (loc1); we found one element (loc2); we found zero elements. // If we found zero elements no work needs to be done and we can continue to the next column. - if ((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1))) + if((loc1 != (n_nonzero + 1)) && (loc2 != (n_nonzero + 1))) { // This is an easy case: just swap the values. No index modifying necessary. eT tmp = values[loc1]; access::rw(values[loc1]) = values[loc2]; access::rw(values[loc2]) = tmp; } - else if (loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2. + else if(loc1 != (n_nonzero + 1)) // We only found loc1 and not loc2. { // We need to find the correct place to move our value to. It will be forward (not backwards) because in_row2 > in_row1. // Each iteration of the loop swaps the current value (loc1) with (loc1 + 1); in this manner we move our value down to where it should be. - while (((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2)) + while(((loc1 + 1) < col_ptrs[lcol + 1]) && (row_indices[loc1 + 1] < in_row2)) { // Swap both the values and the indices. The column should not change. eT tmp = values[loc1]; access::rw(values[loc1]) = values[loc1 + 1]; access::rw(values[loc1 + 1]) = tmp; - + uword tmp_index = row_indices[loc1]; access::rw(row_indices[loc1]) = row_indices[loc1 + 1]; access::rw(row_indices[loc1 + 1]) = tmp_index; - + loc1++; // And increment the counter. } - + // Now set the row index correctly. access::rw(row_indices[loc1]) = in_row2; - + } - else if (loc2 != (n_nonzero + 1)) + else if(loc2 != (n_nonzero + 1)) { // We need to find the correct place to move our value to. It will be backwards (not forwards) because in_row1 < in_row2. // Each iteration of the loop swaps the current value (loc2) with (loc2 - 1); in this manner we move our value up to where it should be. - while (((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1)) + while(((loc2 - 1) >= col_ptrs[lcol]) && (row_indices[loc2 - 1] > in_row1)) { // Swap both the values and the indices. The column should not change. eT tmp = values[loc2]; access::rw(values[loc2]) = values[loc2 - 1]; access::rw(values[loc2 - 1]) = tmp; - + uword tmp_index = row_indices[loc2]; access::rw(row_indices[loc2]) = row_indices[loc2 - 1]; access::rw(row_indices[loc2 - 1]) = tmp_index; - + loc2--; // And decrement the counter. } - + // Now set the row index correctly. access::rw(row_indices[loc2]) = in_row1; - + } /* else: no need to swap anything; both values are zero */ } @@ -2917,18 +2965,17 @@ SpMat::swap_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check( ((in_col1 >= n_cols) || (in_col2 >= n_cols)), "SpMat::swap_cols(): out of bounds" ); + arma_debug_check_bounds( ((in_col1 >= n_cols) || (in_col2 >= n_cols)), "SpMat::swap_cols(): out of bounds" ); if(in_col1 == in_col2) { return; } // TODO: this is a rudimentary implementation - SpMat tmp = (*this); - - tmp.col(in_col1) = (*this).col(in_col2); - tmp.col(in_col2) = (*this).col(in_col1); + const SpMat tmp1 = (*this).col(in_col1); + const SpMat tmp2 = (*this).col(in_col2); - steal_mem(tmp); + (*this).col(in_col2) = tmp1; + (*this).col(in_col1) = tmp2; // for(uword lrow = 0; lrow < n_rows; ++lrow) // { @@ -2947,8 +2994,8 @@ SpMat::shed_row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check (row_num >= n_rows, "SpMat::shed_row(): out of bounds"); - + arma_debug_check_bounds(row_num >= n_rows, "SpMat::shed_row(): out of bounds"); + shed_rows (row_num, row_num); } @@ -2961,8 +3008,8 @@ SpMat::shed_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check (col_num >= n_cols, "SpMat::shed_col(): out of bounds"); - + arma_debug_check_bounds(col_num >= n_cols, "SpMat::shed_col(): out of bounds"); + shed_cols(col_num, col_num); } @@ -2975,7 +3022,7 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "SpMat::shed_rows(): indices out of bounds or incorectly used" @@ -2987,10 +3034,10 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) // First, count the number of elements we will be removing. uword removing = 0; - for (uword i = 0; i < n_nonzero; ++i) + for(uword i = 0; i < n_nonzero; ++i) { const uword lrow = row_indices[i]; - if (lrow >= in_row1 && lrow <= in_row2) + if(lrow >= in_row1 && lrow <= in_row2) { ++removing; } @@ -2998,7 +3045,7 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) // Obtain counts of the number of points in each column and store them as the // (invalid) column pointers of the new matrix. - for (uword i = 1; i < n_cols + 1; ++i) + for(uword i = 1; i < n_cols + 1; ++i) { access::rw(newmat.col_ptrs[i]) = col_ptrs[i] - col_ptrs[i - 1]; } @@ -3008,16 +3055,16 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) // Now, copy over the elements. // i is the index in the old matrix; j is the index in the new matrix. - const_iterator it = begin(); - const_iterator it_end = end(); + const_iterator it = cbegin(); + const_iterator it_end = cend(); uword j = 0; // The index in the new matrix. - while (it != it_end) + while(it != it_end) { const uword lrow = it.row(); const uword lcol = it.col(); - if (lrow >= in_row1 && lrow <= in_row2) + if(lrow >= in_row1 && lrow <= in_row2) { // This element is being removed. Subtract it from the column counts. --access::rw(newmat.col_ptrs[lcol + 1]); @@ -3026,7 +3073,7 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) { // This element is being kept. We may need to map the row index, // if it is past the section of rows we are removing. - if (lrow > in_row2) + if(lrow > in_row2) { access::rw(newmat.row_indices[j]) = lrow - (in_row2 - in_row1 + 1); } @@ -3043,7 +3090,7 @@ SpMat::shed_rows(const uword in_row1, const uword in_row2) } // Finally, sum the column counts so they are correct column pointers. - for (uword i = 1; i < n_cols + 1; ++i) + for(uword i = 1; i < n_cols + 1; ++i) { access::rw(newmat.col_ptrs[i]) += newmat.col_ptrs[i - 1]; } @@ -3061,7 +3108,7 @@ SpMat::shed_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "SpMat::shed_cols(): indices out of bounds or incorrectly used" @@ -3073,35 +3120,39 @@ SpMat::shed_cols(const uword in_col1, const uword in_col2) // First we find the locations in values and row_indices for the column entries. uword col_beg = col_ptrs[in_col1]; uword col_end = col_ptrs[in_col2 + 1]; - + // Then we find the number of entries in the column. uword diff = col_end - col_beg; - - if (diff > 0) + + if(diff > 0) { - eT* new_values = memory::acquire (n_nonzero - diff); - uword* new_row_indices = memory::acquire(n_nonzero - diff); - + eT* new_values = memory::acquire (n_nonzero + 1 - diff); + uword* new_row_indices = memory::acquire(n_nonzero + 1 - diff); + // Copy first part. - if (col_beg != 0) + if(col_beg != 0) { arrayops::copy(new_values, values, col_beg); arrayops::copy(new_row_indices, row_indices, col_beg); } - + // Copy second part. - if (col_end != n_nonzero) + if(col_end != n_nonzero) { arrayops::copy(new_values + col_beg, values + col_end, n_nonzero - col_end); arrayops::copy(new_row_indices + col_beg, row_indices + col_end, n_nonzero - col_end); } + // Copy sentry element. + new_values[n_nonzero - diff] = values[n_nonzero]; + new_row_indices[n_nonzero - diff] = row_indices[n_nonzero]; + if(values) { memory::release(access::rw(values)); } if(row_indices) { memory::release(access::rw(row_indices)); } - + access::rw(values) = new_values; access::rw(row_indices) = new_row_indices; - + // Update counts and such. access::rw(n_nonzero) -= diff; } @@ -3113,14 +3164,14 @@ SpMat::shed_cols(const uword in_col1, const uword in_col2) new_col_ptrs[new_n_cols + 1] = std::numeric_limits::max(); // Copy first set of columns (no manipulation required). - if (in_col1 != 0) + if(in_col1 != 0) { arrayops::copy(new_col_ptrs, col_ptrs, in_col1); } // Copy second set of columns (manipulation required). uword cur_col = in_col1; - for (uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col) + for(uword i = in_col2 + 1; i <= n_cols; ++i, ++cur_col) { new_col_ptrs[cur_col] = col_ptrs[i] - diff; } @@ -3142,7 +3193,6 @@ SpMat::shed_cols(const uword in_col1, const uword in_col2) template arma_inline -arma_warn_unused SpMat_MapMat_val SpMat::operator[](const uword i) { @@ -3156,7 +3206,6 @@ SpMat::operator[](const uword i) template arma_inline -arma_warn_unused eT SpMat::operator[](const uword i) const { @@ -3167,7 +3216,6 @@ SpMat::operator[](const uword i) const template arma_inline -arma_warn_unused SpMat_MapMat_val SpMat::at(const uword i) { @@ -3181,7 +3229,6 @@ SpMat::at(const uword i) template arma_inline -arma_warn_unused eT SpMat::at(const uword i) const { @@ -3192,11 +3239,10 @@ SpMat::at(const uword i) const template arma_inline -arma_warn_unused SpMat_MapMat_val SpMat::operator()(const uword i) { - arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" ); const uword in_col = i / n_rows; const uword in_row = i % n_rows; @@ -3208,11 +3254,10 @@ SpMat::operator()(const uword i) template arma_inline -arma_warn_unused eT SpMat::operator()(const uword i) const { - arma_debug_check( (i >= n_elem), "SpMat::operator(): out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "SpMat::operator(): out of bounds" ); return get_value(i); } @@ -3224,9 +3269,32 @@ SpMat::operator()(const uword i) const * If there is nothing at that position, 0 is returned. */ +#if defined(__cpp_multidimensional_subscript) + + template + arma_inline + SpMat_MapMat_val + SpMat::operator[] (const uword in_row, const uword in_col) + { + return SpMat_MapMat_val((*this), cache, in_row, in_col); + } + + + + template + arma_inline + eT + SpMat::operator[] (const uword in_row, const uword in_col) const + { + return get_value(in_row, in_col); + } + +#endif + + + template arma_inline -arma_warn_unused SpMat_MapMat_val SpMat::at(const uword in_row, const uword in_col) { @@ -3237,7 +3305,6 @@ SpMat::at(const uword in_row, const uword in_col) template arma_inline -arma_warn_unused eT SpMat::at(const uword in_row, const uword in_col) const { @@ -3248,11 +3315,10 @@ SpMat::at(const uword in_row, const uword in_col) const template arma_inline -arma_warn_unused SpMat_MapMat_val SpMat::operator()(const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" ); return SpMat_MapMat_val((*this), cache, in_row, in_col); } @@ -3261,11 +3327,10 @@ SpMat::operator()(const uword in_row, const uword in_col) template arma_inline -arma_warn_unused eT SpMat::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "SpMat::operator(): out of bounds" ); return get_value(in_row, in_col); } @@ -3277,7 +3342,6 @@ SpMat::operator()(const uword in_row, const uword in_col) const */ template arma_inline -arma_warn_unused bool SpMat::is_empty() const { @@ -3289,7 +3353,6 @@ SpMat::is_empty() const //! returns true if the object can be interpreted as a column or row vector template arma_inline -arma_warn_unused bool SpMat::is_vec() const { @@ -3301,7 +3364,6 @@ SpMat::is_vec() const //! returns true if the object can be interpreted as a row vector template arma_inline -arma_warn_unused bool SpMat::is_rowvec() const { @@ -3313,7 +3375,6 @@ SpMat::is_rowvec() const //! returns true if the object can be interpreted as a column vector template arma_inline -arma_warn_unused bool SpMat::is_colvec() const { @@ -3325,7 +3386,6 @@ SpMat::is_colvec() const //! returns true if the object has the same number of non-zero rows and columnns template arma_inline -arma_warn_unused bool SpMat::is_square() const { @@ -3334,31 +3394,14 @@ SpMat::is_square() const -//! returns true if all of the elements are finite template inline -arma_warn_unused bool -SpMat::is_finite() const +SpMat::is_symmetric() const { arma_extra_debug_sigprint(); - sync_csc(); - - return arrayops::is_finite(values, n_nonzero); - } - - - -template -inline -arma_warn_unused -bool -SpMat::is_symmetric() const - { - arma_extra_debug_sigprint(); - - const SpMat& A = (*this); + const SpMat& A = (*this); if(A.n_rows != A.n_cols) { return false; } @@ -3371,7 +3414,6 @@ SpMat::is_symmetric() const template inline -arma_warn_unused bool SpMat::is_symmetric(const typename get_pod_type::result tol) const { @@ -3400,7 +3442,6 @@ SpMat::is_symmetric(const typename get_pod_type::result tol) cons template inline -arma_warn_unused bool SpMat::is_hermitian() const { @@ -3419,7 +3460,6 @@ SpMat::is_hermitian() const template inline -arma_warn_unused bool SpMat::is_hermitian(const typename get_pod_type::result tol) const { @@ -3448,9 +3488,22 @@ SpMat::is_hermitian(const typename get_pod_type::result tol) cons template inline -arma_warn_unused bool -SpMat::has_inf() const +SpMat::internal_is_finite() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return arrayops::is_finite(values, n_nonzero); + } + + + +template +inline +bool +SpMat::internal_has_inf() const { arma_extra_debug_sigprint(); @@ -3463,9 +3516,8 @@ SpMat::has_inf() const template inline -arma_warn_unused bool -SpMat::has_nan() const +SpMat::internal_has_nan() const { arma_extra_debug_sigprint(); @@ -3476,10 +3528,23 @@ SpMat::has_nan() const +template +inline +bool +SpMat::internal_has_nonfinite() const + { + arma_extra_debug_sigprint(); + + sync_csc(); + + return (arrayops::is_finite(values, n_nonzero) == false); + } + + + //! returns true if the given index is currently in range template arma_inline -arma_warn_unused bool SpMat::in_range(const uword i) const { @@ -3490,13 +3555,12 @@ SpMat::in_range(const uword i) const //! returns true if the given start and end indices are currently in range template arma_inline -arma_warn_unused bool SpMat::in_range(const span& x) const { arma_extra_debug_sigprint(); - - if(x.whole == true) + + if(x.whole) { return true; } @@ -3504,7 +3568,7 @@ SpMat::in_range(const span& x) const { const uword a = x.a; const uword b = x.b; - + return ( (a <= b) && (b < n_elem) ); } } @@ -3514,7 +3578,6 @@ SpMat::in_range(const span& x) const //! returns true if the given location is currently in range template arma_inline -arma_warn_unused bool SpMat::in_range(const uword in_row, const uword in_col) const { @@ -3525,13 +3588,12 @@ SpMat::in_range(const uword in_row, const uword in_col) const template arma_inline -arma_warn_unused bool SpMat::in_range(const span& row_span, const uword in_col) const { arma_extra_debug_sigprint(); - - if(row_span.whole == true) + + if(row_span.whole) { return (in_col < n_cols); } @@ -3539,7 +3601,7 @@ SpMat::in_range(const span& row_span, const uword in_col) const { const uword in_row1 = row_span.a; const uword in_row2 = row_span.b; - + return ( (in_row1 <= in_row2) && (in_row2 < n_rows) && (in_col < n_cols) ); } } @@ -3548,13 +3610,12 @@ SpMat::in_range(const span& row_span, const uword in_col) const template arma_inline -arma_warn_unused bool SpMat::in_range(const uword in_row, const span& col_span) const { arma_extra_debug_sigprint(); - - if(col_span.whole == true) + + if(col_span.whole) { return (in_row < n_rows); } @@ -3562,7 +3623,7 @@ SpMat::in_range(const uword in_row, const span& col_span) const { const uword in_col1 = col_span.a; const uword in_col2 = col_span.b; - + return ( (in_row < n_rows) && (in_col1 <= in_col2) && (in_col2 < n_cols) ); } } @@ -3571,29 +3632,27 @@ SpMat::in_range(const uword in_row, const span& col_span) const template arma_inline -arma_warn_unused bool SpMat::in_range(const span& row_span, const span& col_span) const { arma_extra_debug_sigprint(); - + const uword in_row1 = row_span.a; const uword in_row2 = row_span.b; - + const uword in_col1 = col_span.a; const uword in_col2 = col_span.b; - + const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); - - return ( (rows_ok == true) && (cols_ok == true) ); + + return ( rows_ok && cols_ok ); } template arma_inline -arma_warn_unused bool SpMat::in_range(const uword in_row, const uword in_col, const SizeMat& s) const { @@ -3612,211 +3671,16 @@ SpMat::in_range(const uword in_row, const uword in_col, const SizeMat& s) co -template -arma_cold -inline -void -SpMat::impl_print(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print(get_cout_stream(), *this, true); - } - - - -template -arma_cold -inline -void -SpMat::impl_print(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print(user_stream, *this, true); - } - - - -template -arma_cold -inline -void -SpMat::impl_raw_print(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print(get_cout_stream(), *this, false); - } - - -template -arma_cold -inline -void -SpMat::impl_raw_print(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print(user_stream, *this, false); - } - - - -/** - * Matrix printing, prepends supplied text. - * Prints 0 wherever no element exists. - */ -template -arma_cold -inline -void -SpMat::impl_print_dense(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print_dense(get_cout_stream(), *this, true); - } - - - -template -arma_cold -inline -void -SpMat::impl_print_dense(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print_dense(user_stream, *this, true); - } - - - -template -arma_cold -inline -void -SpMat::impl_raw_print_dense(const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = get_cout_stream().width(); - - get_cout_stream() << extra_text << '\n'; - - get_cout_stream().width(orig_width); - } - - arma_ostream::print_dense(get_cout_stream(), *this, false); - } - - - -template -arma_cold -inline -void -SpMat::impl_raw_print_dense(std::ostream& user_stream, const std::string& extra_text) const - { - arma_extra_debug_sigprint(); - - sync_csc(); - - if(extra_text.length() != 0) - { - const std::streamsize orig_width = user_stream.width(); - - user_stream << extra_text << '\n'; - - user_stream.width(orig_width); - } - - arma_ostream::print_dense(user_stream, *this, false); - } - - - //! Set the size to the size of another matrix. template template inline -void +SpMat& SpMat::copy_size(const SpMat& m) { arma_extra_debug_sigprint(); - set_size(m.n_rows, m.n_cols); + return set_size(m.n_rows, m.n_cols); } @@ -3824,19 +3688,19 @@ SpMat::copy_size(const SpMat& m) template template inline -void +SpMat& SpMat::copy_size(const Mat& m) { arma_extra_debug_sigprint(); - - set_size(m.n_rows, m.n_cols); + + return set_size(m.n_rows, m.n_cols); } template inline -void +SpMat& SpMat::set_size(const uword in_elem) { arma_extra_debug_sigprint(); @@ -3850,60 +3714,52 @@ SpMat::set_size(const uword in_elem) { set_size(in_elem, 1); } + + return *this; } template inline -void +SpMat& SpMat::set_size(const uword in_rows, const uword in_cols) { arma_extra_debug_sigprint(); invalidate_cache(); // placed here, as set_size() is used during matrix modification - if( (n_rows == in_rows) && (n_cols == in_cols) ) - { - return; - } - else - { - init(in_rows, in_cols); - } + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } + + init(in_rows, in_cols); + + return *this; } template inline -void +SpMat& SpMat::set_size(const SizeMat& s) { arma_extra_debug_sigprint(); - (*this).set_size(s.n_rows, s.n_cols); + return (*this).set_size(s.n_rows, s.n_cols); } template inline -void +SpMat& SpMat::resize(const uword in_rows, const uword in_cols) { arma_extra_debug_sigprint(); - if( (n_rows == in_rows) && (n_cols == in_cols) ) - { - return; - } + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } - if( (n_elem == 0) || (n_nonzero == 0) ) - { - set_size(in_rows, in_cols); - return; - } + if( (n_elem == 0) || (n_nonzero == 0) ) { return set_size(in_rows, in_cols); } SpMat tmp(in_rows, in_cols); @@ -3918,41 +3774,39 @@ SpMat::resize(const uword in_rows, const uword in_cols) } steal_mem(tmp); + + return *this; } template inline -void +SpMat& SpMat::resize(const SizeMat& s) { arma_extra_debug_sigprint(); - (*this).resize(s.n_rows, s.n_cols); + return (*this).resize(s.n_rows, s.n_cols); } template inline -void +SpMat& SpMat::reshape(const uword in_rows, const uword in_cols) { arma_extra_debug_sigprint(); arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" ); - if( (n_rows == in_rows) && (n_cols == in_cols) ) { return; } + if( (n_rows == in_rows) && (n_cols == in_cols) ) { return *this; } if(vec_state == 1) { arma_debug_check( (in_cols != 1), "SpMat::reshape(): object is a column vector; requested size is not compatible" ); } if(vec_state == 2) { arma_debug_check( (in_rows != 1), "SpMat::reshape(): object is a row vector; requested size is not compatible" ); } - if(n_nonzero == 0) - { - (*this).zeros(in_rows, in_cols); - return; - } + if(n_nonzero == 0) { return (*this).zeros(in_rows, in_cols); } if(in_cols == 1) { @@ -3962,18 +3816,20 @@ SpMat::reshape(const uword in_rows, const uword in_cols) { (*this).reshape_helper_generic(in_rows, in_cols); } + + return *this; } template inline -void +SpMat& SpMat::reshape(const SizeMat& s) { arma_extra_debug_sigprint(); - (*this).reshape(s.n_rows, s.n_cols); + return (*this).reshape(s.n_rows, s.n_cols); } @@ -3999,8 +3855,8 @@ SpMat::reshape_helper_generic(const uword in_rows, const uword in_cols) arrayops::fill_zeros(new_col_ptrs, in_cols + 1); - const_iterator it = begin(); - const_iterator it_end = end(); + const_iterator it = cbegin(); + const_iterator it_end = cend(); for(; it != it_end; ++it) { @@ -4039,7 +3895,7 @@ SpMat::reshape_helper_intovec() sync_csc(); invalidate_cache(); - const_iterator it = begin(); + const_iterator it = cbegin(); const uword t_n_rows = n_rows; const uword t_n_nonzero = n_nonzero; @@ -4067,49 +3923,11 @@ SpMat::reshape_helper_intovec() -//! NOTE: don't use this form; it's deprecated and will be removed -template -arma_deprecated -inline -void -SpMat::reshape(const uword in_rows, const uword in_cols, const uword dim) - { - arma_extra_debug_sigprint(); - - arma_debug_check( (dim > 1), "SpMat::reshape(): parameter 'dim' must be 0 or 1" ); - - if(dim == 0) - { - (*this).reshape(in_rows, in_cols); - } - else - if(dim == 1) - { - arma_check( ((in_rows*in_cols) != n_elem), "SpMat::reshape(): changing the number of elements in a sparse matrix is currently not supported" ); - - sync_csc(); - - // Row-wise reshaping. This is more tedious and we will use a separate sparse matrix to do it. - SpMat tmp(in_rows, in_cols); - - for(const_row_iterator it = begin_row(); it.pos() < n_nonzero; ++it) - { - uword vector_position = (it.row() * n_cols) + it.col(); - - tmp((vector_position / in_cols), (vector_position % in_cols)) = (*it); - } - - steal_mem(tmp); - } - } - - - //! apply a functor to each non-zero element template template inline -const SpMat& +SpMat& SpMat::for_each(functor F) { arma_extra_debug_sigprint(); @@ -4154,10 +3972,7 @@ SpMat::for_each(functor F) const const uword N = (*this).n_nonzero; - for(uword i=0; i < N; ++i) - { - F(values[i]); - } + for(uword i=0; i < N; ++i) { F(values[i]); } return *this; } @@ -4168,7 +3983,7 @@ SpMat::for_each(functor F) const template template inline -const SpMat& +SpMat& SpMat::transform(functor F) { arma_extra_debug_sigprint(); @@ -4192,7 +4007,7 @@ SpMat::transform(functor F) } if(has_zero) { remove_zeros(); } - + return *this; } @@ -4200,14 +4015,14 @@ SpMat::transform(functor F) template inline -const SpMat& +SpMat& SpMat::replace(const eT old_val, const eT new_val) { arma_extra_debug_sigprint(); if(old_val == eT(0)) { - arma_debug_warn("SpMat::replace(): replacement not done, as old_val = 0"); + arma_debug_warn_level(1, "SpMat::replace(): replacement not done, as old_val = 0"); } else { @@ -4226,7 +4041,7 @@ SpMat::replace(const eT old_val, const eT new_val) template inline -const SpMat& +SpMat& SpMat::clean(const typename get_pod_type::result threshold) { arma_extra_debug_sigprint(); @@ -4236,9 +4051,40 @@ SpMat::clean(const typename get_pod_type::result threshold) sync_csc(); invalidate_cache(); - arrayops::clean(access::rwp(values), n_nonzero, threshold); + arrayops::clean(access::rwp(values), n_nonzero, threshold); + + remove_zeros(); + + return *this; + } + + + +template +inline +SpMat& +SpMat::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpMat::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "SpMat::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + if(n_nonzero == 0) { return *this; } + + sync_csc(); + invalidate_cache(); + + arrayops::clamp(access::rwp(values), n_nonzero, min_val, max_val); - remove_zeros(); + if( (min_val == eT(0)) || (max_val == eT(0)) ) { remove_zeros(); } return *this; } @@ -4247,14 +4093,16 @@ SpMat::clean(const typename get_pod_type::result threshold) template inline -const SpMat& +SpMat& SpMat::zeros() { arma_extra_debug_sigprint(); - const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) ); - - if(already_done == false) + if((n_nonzero == 0) && (values != nullptr)) + { + invalidate_cache(); + } + else { init(n_rows, n_cols); } @@ -4266,11 +4114,11 @@ SpMat::zeros() template inline -const SpMat& +SpMat& SpMat::zeros(const uword in_elem) { arma_extra_debug_sigprint(); - + if(vec_state == 2) { zeros(1, in_elem); // Row vector @@ -4279,7 +4127,7 @@ SpMat::zeros(const uword in_elem) { zeros(in_elem, 1); } - + return *this; } @@ -4287,14 +4135,16 @@ SpMat::zeros(const uword in_elem) template inline -const SpMat& +SpMat& SpMat::zeros(const uword in_rows, const uword in_cols) { arma_extra_debug_sigprint(); - const bool already_done = ( (sync_state != 1) && (n_nonzero == 0) && (n_rows == in_rows) && (n_cols == in_cols) ); - - if(already_done == false) + if((n_nonzero == 0) && (n_rows == in_rows) && (n_cols == in_cols) && (values != nullptr)) + { + invalidate_cache(); + } + else { init(in_rows, in_cols); } @@ -4306,7 +4156,7 @@ SpMat::zeros(const uword in_rows, const uword in_cols) template inline -const SpMat& +SpMat& SpMat::zeros(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -4318,11 +4168,11 @@ SpMat::zeros(const SizeMat& s) template inline -const SpMat& +SpMat& SpMat::eye() { arma_extra_debug_sigprint(); - + return (*this).eye(n_rows, n_cols); } @@ -4330,7 +4180,7 @@ SpMat::eye() template inline -const SpMat& +SpMat& SpMat::eye(const uword in_rows, const uword in_cols) { arma_extra_debug_sigprint(); @@ -4357,7 +4207,7 @@ SpMat::eye(const uword in_rows, const uword in_cols) template inline -const SpMat& +SpMat& SpMat::eye(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -4369,11 +4219,11 @@ SpMat::eye(const SizeMat& s) template inline -const SpMat& +SpMat& SpMat::speye() { arma_extra_debug_sigprint(); - + return (*this).eye(n_rows, n_cols); } @@ -4381,7 +4231,7 @@ SpMat::speye() template inline -const SpMat& +SpMat& SpMat::speye(const uword in_n_rows, const uword in_n_cols) { arma_extra_debug_sigprint(); @@ -4393,7 +4243,7 @@ SpMat::speye(const uword in_n_rows, const uword in_n_cols) template inline -const SpMat& +SpMat& SpMat::speye(const SizeMat& s) { arma_extra_debug_sigprint(); @@ -4405,7 +4255,7 @@ SpMat::speye(const SizeMat& s) template inline -const SpMat& +SpMat& SpMat::sprandu(const uword in_rows, const uword in_cols, const double density) { arma_extra_debug_sigprint(); @@ -4479,7 +4329,7 @@ SpMat::sprandu(const uword in_rows, const uword in_cols, const double densit template inline -const SpMat& +SpMat& SpMat::sprandu(const SizeMat& s, const double density) { arma_extra_debug_sigprint(); @@ -4491,7 +4341,7 @@ SpMat::sprandu(const SizeMat& s, const double density) template inline -const SpMat& +SpMat& SpMat::sprandn(const uword in_rows, const uword in_cols, const double density) { arma_extra_debug_sigprint(); @@ -4565,7 +4415,7 @@ SpMat::sprandn(const uword in_rows, const uword in_cols, const double densit template inline -const SpMat& +SpMat& SpMat::sprandn(const SizeMat& s, const double density) { arma_extra_debug_sigprint(); @@ -4581,21 +4431,50 @@ void SpMat::reset() { arma_extra_debug_sigprint(); - + switch(vec_state) { - default: - init(0, 0); - break; + default: init(0, 0); break; + case 1: init(0, 1); break; + case 2: init(1, 0); break; + } + } + + + +template +inline +void +SpMat::reset_cache() + { + arma_extra_debug_sigprint(); + + sync_csc(); + + #if defined(ARMA_USE_OPENMP) + { + #pragma omp critical (arma_SpMat_cache) + { + cache.reset(); - case 1: - init(0, 1); - break; + sync_state = 0; + } + } + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) + { + const std::lock_guard lock(cache_mutex); - case 2: - init(1, 0); - break; + cache.reset(); + + sync_state = 0; + } + #else + { + cache.reset(); + + sync_state = 0; } + #endif } @@ -4641,9 +4520,8 @@ SpMat::set_imag(const SpBase::pod_type,T1>& X) //! save the matrix to a file template inline -arma_cold bool -SpMat::save(const std::string name, const file_type type, const bool print_status) const +SpMat::save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); @@ -4654,7 +4532,11 @@ SpMat::save(const std::string name, const file_type type, const bool print_s switch(type) { case csv_ascii: - save_okay = diskio::save_csv_ascii(*this, name); + return (*this).save(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).save(csv_name(name), type); break; case arma_binary: @@ -4666,11 +4548,86 @@ SpMat::save(const std::string name, const file_type type, const bool print_s break; default: - if(print_status) { arma_debug_warn("SpMat::save(): unsupported file type"); } + arma_debug_warn_level(1, "SpMat::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("SpMat::save(): couldn't write to ", name); } + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): write failed; file: ", name); } + + return save_okay; + } + + + +template +inline +bool +SpMat::save(const csv_name& spec, const file_type type) const + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("SpMat::save(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + + arma_extra_debug_print("SpMat::save(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + if(with_header) + { + if( (spec.header_ro.n_cols != 1) && (spec.header_ro.n_rows != 1) ) + { + arma_debug_warn_level(1, "SpMat::save(): given header must have a vector layout"); + return false; + } + + for(uword i=0; i < spec.header_ro.n_elem; ++i) + { + const std::string& token = spec.header_ro.at(i); + + if(token.find(separator) != std::string::npos) + { + arma_debug_warn_level(1, "SpMat::save(): token within the header contains the separator character: '", token, "'"); + return false; + } + } + + const uword save_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(spec.header_ro.n_elem != save_n_cols) + { + arma_debug_warn_level(1, "SpMat::save(): size mismatch between header and matrix"); + return false; + } + } + + bool save_okay = false; + + if(do_trans) + { + const SpMat tmp = (*this).st(); + + save_okay = diskio::save_csv_ascii(tmp, spec.filename, spec.header_ro, with_header, separator); + } + else + { + save_okay = diskio::save_csv_ascii(*this, spec.filename, spec.header_ro, with_header, separator); + } + + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): write failed; file: ", spec.filename); } return save_okay; } @@ -4680,9 +4637,8 @@ SpMat::save(const std::string name, const file_type type, const bool print_s //! save the matrix to a stream template inline -arma_cold bool -SpMat::save(std::ostream& os, const file_type type, const bool print_status) const +SpMat::save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); @@ -4693,7 +4649,11 @@ SpMat::save(std::ostream& os, const file_type type, const bool print_status) switch(type) { case csv_ascii: - save_okay = diskio::save_csv_ascii(*this, os); + save_okay = diskio::save_csv_ascii(*this, os, char(',')); + break; + + case ssv_ascii: + save_okay = diskio::save_csv_ascii(*this, os, char(';')); break; case arma_binary: @@ -4705,11 +4665,11 @@ SpMat::save(std::ostream& os, const file_type type, const bool print_status) break; default: - if(print_status) { arma_debug_warn("SpMat::save(): unsupported file type"); } + arma_debug_warn_level(1, "SpMat::save(): unsupported file type"); save_okay = false; } - if(print_status && (save_okay == false)) { arma_debug_warn("SpMat::save(): couldn't write to the given stream"); } + if(save_okay == false) { arma_debug_warn_level(3, "SpMat::save(): stream write failed"); } return save_okay; } @@ -4719,9 +4679,8 @@ SpMat::save(std::ostream& os, const file_type type, const bool print_status) //! load a matrix from a file template inline -arma_cold bool -SpMat::load(const std::string name, const file_type type, const bool print_status) +SpMat::load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); @@ -4737,7 +4696,11 @@ SpMat::load(const std::string name, const file_type type, const bool print_s // break; case csv_ascii: - load_okay = diskio::load_csv_ascii(*this, name, err_msg); + return (*this).load(csv_name(name), type); + break; + + case ssv_ascii: + return (*this).load(csv_name(name), type); break; case arma_binary: @@ -4749,27 +4712,113 @@ SpMat::load(const std::string name, const file_type type, const bool print_s break; default: - if(print_status) { arma_debug_warn("SpMat::load(): unsupported file type"); } + arma_debug_warn_level(1, "SpMat::load(): unsupported file type"); load_okay = false; } - if(print_status && (load_okay == false)) + if(load_okay == false) + { + if(err_msg.length() > 0) + { + arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", name); + } + else + { + arma_debug_warn_level(3, "SpMat::load(): read failed; file: ", name); + } + } + + if(load_okay == false) { (*this).reset(); } + + return load_okay; + } + + + +template +inline +bool +SpMat::load(const csv_name& spec, const file_type type) + { + arma_extra_debug_sigprint(); + + if( (type != csv_ascii) && (type != ssv_ascii) ) + { + arma_stop_runtime_error("SpMat::load(): unsupported file type for csv_name()"); + return false; + } + + const bool do_trans = bool(spec.opts.flags & csv_opts::flag_trans ); + const bool no_header = bool(spec.opts.flags & csv_opts::flag_no_header ); + const bool with_header = bool(spec.opts.flags & csv_opts::flag_with_header) && (no_header == false); + const bool use_semicolon = bool(spec.opts.flags & csv_opts::flag_semicolon ) || (type == ssv_ascii); + const bool strict = bool(spec.opts.flags & csv_opts::flag_strict ); + + arma_extra_debug_print("SpMat::load(csv_name): enabled flags:"); + + if(do_trans ) { arma_extra_debug_print("trans"); } + if(no_header ) { arma_extra_debug_print("no_header"); } + if(with_header ) { arma_extra_debug_print("with_header"); } + if(use_semicolon) { arma_extra_debug_print("semicolon"); } + if(strict ) { arma_extra_debug_print("strict"); } + + if(strict) { arma_debug_warn_level(1, "SpMat::load(): option 'strict' not implemented for sparse matrices"); } + + const char separator = (use_semicolon) ? char(';') : char(','); + + bool load_okay = false; + std::string err_msg; + + if(do_trans) + { + SpMat tmp_mat; + + load_okay = diskio::load_csv_ascii(tmp_mat, spec.filename, err_msg, spec.header_rw, with_header, separator); + + if(load_okay) + { + (*this) = tmp_mat.st(); + + if(with_header) + { + // field::set_size() preserves data if the number of elements hasn't changed + spec.header_rw.set_size(spec.header_rw.n_elem, 1); + } + } + } + else + { + load_okay = diskio::load_csv_ascii(*this, spec.filename, err_msg, spec.header_rw, with_header, separator); + } + + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("SpMat::load(): ", err_msg, name); + arma_debug_warn_level(3, "SpMat::load(): ", err_msg, "; file: ", spec.filename); } else { - arma_debug_warn("SpMat::load(): couldn't read ", name); + arma_debug_warn_level(3, "SpMat::load(): read failed; file: ", spec.filename); + } + } + else + { + const uword load_n_cols = (do_trans) ? (*this).n_rows : (*this).n_cols; + + if(with_header && (spec.header_rw.n_elem != load_n_cols)) + { + arma_debug_warn_level(3, "SpMat::load(): size mismatch between header and matrix"); } } if(load_okay == false) { (*this).reset(); - } + if(with_header) { spec.header_rw.reset(); } + } + return load_okay; } @@ -4778,9 +4827,8 @@ SpMat::load(const std::string name, const file_type type, const bool print_s //! load a matrix from a stream template inline -arma_cold bool -SpMat::load(std::istream& is, const file_type type, const bool print_status) +SpMat::load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); @@ -4796,7 +4844,11 @@ SpMat::load(std::istream& is, const file_type type, const bool print_status) // break; case csv_ascii: - load_okay = diskio::load_csv_ascii(*this, is, err_msg); + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(',')); + break; + + case ssv_ascii: + load_okay = diskio::load_csv_ascii(*this, is, err_msg, char(';')); break; case arma_binary: @@ -4808,84 +4860,73 @@ SpMat::load(std::istream& is, const file_type type, const bool print_status) break; default: - if(print_status) { arma_debug_warn("SpMat::load(): unsupported file type"); } + arma_debug_warn_level(1, "SpMat::load(): unsupported file type"); load_okay = false; } - if(print_status && (load_okay == false)) + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("SpMat::load(): ", err_msg, "the given stream"); + arma_debug_warn_level(3, "SpMat::load(): ", err_msg); } else { - arma_debug_warn("SpMat::load(): couldn't load from the given stream"); + arma_debug_warn_level(3, "SpMat::load(): stream read failed"); } } - if(load_okay == false) - { - (*this).reset(); - } - + if(load_okay == false) { (*this).reset(); } + return load_okay; } -//! save the matrix to a file, without printing any error messages template inline -arma_cold bool SpMat::quiet_save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(name, type, false); + return (*this).save(name, type); } -//! save the matrix to a stream, without printing any error messages template inline -arma_cold bool SpMat::quiet_save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(os, type, false); + return (*this).save(os, type); } -//! load a matrix from a file, without printing any error messages template inline -arma_cold bool SpMat::quiet_load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(name, type, false); + return (*this).load(name, type); } -//! load a matrix from a stream, without printing any error messages template inline -arma_cold bool SpMat::quiet_load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(is, type, false); + return (*this).load(is, type); } @@ -4907,6 +4948,15 @@ SpMat::init(uword in_rows, uword in_cols, const uword new_n_nonzero) if(row_indices) { memory::release(access::rw(row_indices)); } if(col_ptrs ) { memory::release(access::rw(col_ptrs)); } + // in case init_cold() throws an exception + access::rw(n_rows) = 0; + access::rw(n_cols) = 0; + access::rw(n_elem) = 0; + access::rw(n_nonzero) = 0; + access::rw(values) = nullptr; + access::rw(row_indices) = nullptr; + access::rw(col_ptrs) = nullptr; + init_cold(in_rows, in_cols, new_n_nonzero); } @@ -4915,7 +4965,6 @@ SpMat::init(uword in_rows, uword in_cols, const uword new_n_nonzero) template inline void -arma_cold SpMat::init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero) { arma_extra_debug_sigprint(); @@ -4935,10 +4984,10 @@ SpMat::init_cold(uword in_rows, uword in_cols, const uword new_n_nonzero) } } - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "SpMat::init(): requested size is too large"; #else - const char* error_message = "SpMat::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "SpMat::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif // Ensure that n_elem can hold the result of (n_rows * n_cols) @@ -5028,16 +5077,16 @@ SpMat::init(const SpMat& x) init_done = true; } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) if(x.sync_state == 1) { - x.cache_mutex.lock(); + const std::lock_guard lock(x.cache_mutex); + if(x.sync_state == 1) { (*this).init(x.cache); init_done = true; } - x.cache_mutex.unlock(); } #else if(x.sync_state == 1) @@ -5157,11 +5206,21 @@ SpMat::init_simple(const SpMat& x) if(this == &x) { return; } - init(x.n_rows, x.n_cols, x.n_nonzero); + if((x.n_nonzero == 0) && (n_nonzero == 0) && (n_rows == x.n_rows) && (n_cols == x.n_cols) && (values != nullptr)) + { + invalidate_cache(); + } + else + { + init(x.n_rows, x.n_cols, x.n_nonzero); + } - if(x.values ) { arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); } - if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); } - if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); } + if(x.n_nonzero != 0) + { + if(x.values ) { arrayops::copy(access::rwp(values), x.values, x.n_nonzero + 1); } + if(x.row_indices) { arrayops::copy(access::rwp(row_indices), x.row_indices, x.n_nonzero + 1); } + if(x.col_ptrs ) { arrayops::copy(access::rwp(col_ptrs), x.col_ptrs, x.n_cols + 1); } + } } @@ -5181,13 +5240,13 @@ SpMat::init_batch_std(const Mat& locs, const Mat& vals, const boo bool actually_sorted = true; - if(sort_locations == true) + if(sort_locations) { // check if we really need a time consuming sort const uword locs_n_cols = locs.n_cols; - for (uword i = 1; i < locs_n_cols; ++i) + for(uword i = 1; i < locs_n_cols; ++i) { const uword* locs_i = locs.colptr(i ); const uword* locs_im1 = locs.colptr(i-1); @@ -5213,7 +5272,7 @@ SpMat::init_batch_std(const Mat& locs, const Mat& vals, const boo const uword* locs_mem = locs.memptr(); - for (uword i = 0; i < locs_n_cols; ++i) + for(uword i = 0; i < locs_n_cols; ++i) { const uword row = (*locs_mem); locs_mem++; const uword col = (*locs_mem); locs_mem++; @@ -5227,7 +5286,7 @@ SpMat::init_batch_std(const Mat& locs, const Mat& vals, const boo std::sort( packet_vec.begin(), packet_vec.end(), comparator ); // insert the elements in the sorted order - for (uword i = 0; i < locs_n_cols; ++i) + for(uword i = 0; i < locs_n_cols; ++i) { const uword index = packet_vec[i].index; @@ -5298,7 +5357,7 @@ SpMat::init_batch_std(const Mat& locs, const Mat& vals, const boo } // Now fix the column pointers. - for (uword i = 0; i < n_cols; ++i) + for(uword i = 0; i < n_cols; ++i) { access::rw(col_ptrs[i + 1]) += col_ptrs[i]; } @@ -5324,11 +5383,11 @@ SpMat::init_batch_add(const Mat& locs, const Mat& vals, const boo bool actually_sorted = true; - if(sort_locations == true) + if(sort_locations) { // sort_index() uses std::sort() which may use quicksort... so we better // make sure it's not already sorted before taking an O(N^2) sort penalty. - for (uword i = 1; i < locs.n_cols; ++i) + for(uword i = 1; i < locs.n_cols; ++i) { const uword* locs_i = locs.colptr(i ); const uword* locs_im1 = locs.colptr(i-1); @@ -5343,9 +5402,9 @@ SpMat::init_batch_add(const Mat& locs, const Mat& vals, const boo if(actually_sorted == false) { // This may not be the fastest possible implementation but it maximizes code reuse. - Col abslocs(locs.n_cols); + Col abslocs(locs.n_cols, arma_nozeros_indicator()); - for (uword i = 0; i < locs.n_cols; ++i) + for(uword i = 0; i < locs.n_cols; ++i) { const uword* locs_i = locs.colptr(i); @@ -5471,7 +5530,7 @@ SpMat::init_batch_add(const Mat& locs, const Mat& vals, const boo } // Now fix the column pointers. - for (uword i = 0; i < n_cols; ++i) + for(uword i = 0; i < n_cols; ++i) { access::rw(col_ptrs[i + 1]) += col_ptrs[i]; } @@ -5488,9 +5547,9 @@ SpMat::SpMat(const arma_vec_indicator&, const uword in_vec_state) , n_elem(0) , n_nonzero(0) , vec_state(in_vec_state) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -5511,9 +5570,9 @@ SpMat::SpMat(const arma_vec_indicator&, const uword in_n_rows, const uword i , n_elem(0) , n_nonzero(0) , vec_state(in_vec_state) - , values(NULL) - , row_indices(NULL) - , col_ptrs(NULL) + , values(nullptr) + , row_indices(nullptr) + , col_ptrs(nullptr) { arma_extra_debug_sigprint_this(this); @@ -5589,9 +5648,11 @@ SpMat::remove_zeros() const eT* old_values = values; + constexpr eT zero = eT(0); + for(uword i=0; i < old_n_nonzero; ++i) { - new_n_nonzero += (old_values[i] != eT(0)) ? uword(1) : uword(0); + new_n_nonzero += (old_values[i] != zero) ? uword(1) : uword(0); } if(new_n_nonzero != old_n_nonzero) @@ -5602,18 +5663,21 @@ SpMat::remove_zeros() uword new_index = 0; - const_iterator it = begin(); - const_iterator it_end = end(); + const_iterator it = cbegin(); + const_iterator it_end = cend(); for(; it != it_end; ++it) { const eT val = eT(*it); - if(val != eT(0)) + if(val != zero) { + const uword it_row = it.row(); + const uword it_col = it.col(); + access::rw(tmp.values[new_index]) = val; - access::rw(tmp.row_indices[new_index]) = it.row(); - access::rw(tmp.col_ptrs[it.col() + 1])++; + access::rw(tmp.row_indices[new_index]) = it_row; + access::rw(tmp.col_ptrs[it_col + 1])++; ++new_index; } } @@ -5653,6 +5717,8 @@ SpMat::steal_mem(SpMat& x) if(layout_ok) { + arma_extra_debug_print("SpMat::steal_mem(): stealing memory"); + x.sync_csc(); steal_mem_simple(x); @@ -5663,6 +5729,8 @@ SpMat::steal_mem(SpMat& x) } else { + arma_extra_debug_print("SpMat::steal_mem(): copying memory"); + (*this).operator=(x); } } @@ -5697,16 +5765,15 @@ SpMat::steal_mem_simple(SpMat& x) access::rw(x.n_elem) = 0; access::rw(x.n_nonzero) = 0; - access::rw(x.values) = NULL; - access::rw(x.row_indices) = NULL; - access::rw(x.col_ptrs) = NULL; + access::rw(x.values) = nullptr; + access::rw(x.row_indices) = nullptr; + access::rw(x.col_ptrs) = nullptr; } template template -arma_hot inline void SpMat::init_xform(const SpBase& A, const Functor& func) @@ -5745,7 +5812,6 @@ SpMat::init_xform(const SpBase& A, const Functor& func) template template -arma_hot inline void SpMat::init_xform_mt(const SpBase& A, const Functor& func) @@ -5754,7 +5820,7 @@ SpMat::init_xform_mt(const SpBase& A, const Functor& func) const SpProxy P(A.get_ref()); - if( (P.is_alias(*this) == true) || (is_SpMat::stored_type>::value == true) ) + if( P.is_alias(*this) || (is_SpMat::stored_type>::value) ) { // NOTE: unwrap_spmat will convert a submatrix to a matrix, which in effect takes care of aliasing with submatrices; // NOTE: however, when more delayed ops are implemented, more elaborate handling of aliasing will be necessary @@ -5806,8 +5872,10 @@ SpMat::init_xform_mt(const SpBase& A, const Functor& func) if(val == eT(0)) { has_zero = true; } - access::rw(row_indices[it.pos()]) = it.row(); - access::rw(values[it.pos()]) = val; + const uword it_pos = it.pos(); + + access::rw(row_indices[it_pos]) = it.row(); + access::rw(values[it_pos]) = val; ++access::rw(col_ptrs[it.col() + 1]); ++it; } @@ -6149,10 +6217,56 @@ SpMat::size() const +template +arma_inline +SpMat_MapMat_val +SpMat::front() + { + arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" ); + + return SpMat_MapMat_val((*this), cache, 0, 0); + } + + + +template +arma_inline +eT +SpMat::front() const + { + arma_debug_check( (n_elem == 0), "SpMat::front(): matrix is empty" ); + + return get_value(0,0); + } + + + +template +arma_inline +SpMat_MapMat_val +SpMat::back() + { + arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" ); + + return SpMat_MapMat_val((*this), cache, n_rows-1, n_cols-1); + } + + + +template +arma_inline +eT +SpMat::back() const + { + arma_debug_check( (n_elem == 0), "SpMat::back(): matrix is empty" ); + + return get_value(n_rows-1, n_cols-1); + } + + + template inline -arma_hot -arma_warn_unused eT SpMat::get_value(const uword i) const { @@ -6167,8 +6281,6 @@ SpMat::get_value(const uword i) const template inline -arma_hot -arma_warn_unused eT SpMat::get_value(const uword in_row, const uword in_col) const { @@ -6183,8 +6295,6 @@ SpMat::get_value(const uword in_row, const uword in_col) const template inline -arma_hot -arma_warn_unused eT SpMat::get_value_csc(const uword i) const { @@ -6199,8 +6309,6 @@ SpMat::get_value_csc(const uword i) const template inline -arma_hot -arma_warn_unused const eT* SpMat::find_value_csc(const uword in_row, const uword in_col) const { @@ -6220,36 +6328,32 @@ SpMat::find_value_csc(const uword in_row, const uword in_col) const return &(values[index]); } - return NULL; + return nullptr; } template inline -arma_hot -arma_warn_unused eT SpMat::get_value_csc(const uword in_row, const uword in_col) const { const eT* val_ptr = find_value_csc(in_row, in_col); - return (val_ptr != NULL) ? eT(*val_ptr) : eT(0); + return (val_ptr != nullptr) ? eT(*val_ptr) : eT(0); } template inline -arma_hot -arma_warn_unused bool SpMat::try_set_value_csc(const uword in_row, const uword in_col, const eT in_val) { const eT* val_ptr = find_value_csc(in_row, in_col); // element not found, ie. it's zero; fail if trying to set it to non-zero value - if(val_ptr == NULL) { return (in_val == eT(0)); } + if(val_ptr == nullptr) { return (in_val == eT(0)); } // fail if trying to erase an existing element if(in_val == eT(0)) { return false; } @@ -6265,15 +6369,13 @@ SpMat::try_set_value_csc(const uword in_row, const uword in_col, const eT in template inline -arma_hot -arma_warn_unused bool SpMat::try_add_value_csc(const uword in_row, const uword in_col, const eT in_val) { const eT* val_ptr = find_value_csc(in_row, in_col); // element not found, ie. it's zero; fail if trying to add a non-zero value - if(val_ptr == NULL) { return (in_val == eT(0)); } + if(val_ptr == nullptr) { return (in_val == eT(0)); } const eT new_val = eT(*val_ptr) + in_val; @@ -6291,15 +6393,13 @@ SpMat::try_add_value_csc(const uword in_row, const uword in_col, const eT in template inline -arma_hot -arma_warn_unused bool SpMat::try_sub_value_csc(const uword in_row, const uword in_col, const eT in_val) { const eT* val_ptr = find_value_csc(in_row, in_col); // element not found, ie. it's zero; fail if trying to subtract a non-zero value - if(val_ptr == NULL) { return (in_val == eT(0)); } + if(val_ptr == nullptr) { return (in_val == eT(0)); } const eT new_val = eT(*val_ptr) - in_val; @@ -6317,15 +6417,13 @@ SpMat::try_sub_value_csc(const uword in_row, const uword in_col, const eT in template inline -arma_hot -arma_warn_unused bool SpMat::try_mul_value_csc(const uword in_row, const uword in_col, const eT in_val) { const eT* val_ptr = find_value_csc(in_row, in_col); // element not found, ie. it's zero; succeed if given value is finite; zero multiplied by anything is zero, except for nan and inf - if(val_ptr == NULL) { return arma_isfinite(in_val); } + if(val_ptr == nullptr) { return arma_isfinite(in_val); } const eT new_val = eT(*val_ptr) * in_val; @@ -6343,15 +6441,13 @@ SpMat::try_mul_value_csc(const uword in_row, const uword in_col, const eT in template inline -arma_hot -arma_warn_unused bool SpMat::try_div_value_csc(const uword in_row, const uword in_col, const eT in_val) { const eT* val_ptr = find_value_csc(in_row, in_col); // element not found, ie. it's zero; succeed if given value is not zero and not nan; zero divided by anything is zero, except for zero and nan - if(val_ptr == NULL) { return ((in_val != eT(0)) && (arma_isnan(in_val) == false)); } + if(val_ptr == nullptr) { return ((in_val != eT(0)) && (arma_isnan(in_val) == false)); } const eT new_val = eT(*val_ptr) / in_val; @@ -6371,14 +6467,9 @@ SpMat::try_div_value_csc(const uword in_row, const uword in_col, const eT in * Insert an element at the given position, and return a reference to it. * The element will be set to 0, unless otherwise specified. * If the element already exists, its value will be overwritten. - * - * @param in_row Row of new element. - * @param in_col Column of new element. - * @param in_val Value to set new element to (default 0). */ template inline -arma_warn_unused eT& SpMat::insert_element(const uword in_row, const uword in_col, const eT val) { @@ -6395,18 +6486,18 @@ SpMat::insert_element(const uword in_row, const uword in_col, const eT val) uword pos = colptr; // The position in the matrix of this value. - if (colptr != next_colptr) + if(colptr != next_colptr) { // There are other elements in this column, so we must find where this // element will fit as compared to those. - while (pos < next_colptr && in_row > row_indices[pos]) + while(pos < next_colptr && in_row > row_indices[pos]) { pos++; } // We aren't inserting into the last position, so it is still possible // that the element may exist. - if (pos != next_colptr && row_indices[pos] == in_row) + if(pos != next_colptr && row_indices[pos] == in_row) { // It already exists. Then, just overwrite it. access::rw(values[pos]) = val; @@ -6421,7 +6512,7 @@ SpMat::insert_element(const uword in_row, const uword in_col, const eT val) // // We have to update the rest of the column pointers. - for (uword i = in_col + 1; i < n_cols + 1; i++) + for(uword i = in_col + 1; i < n_cols + 1; i++) { access::rw(col_ptrs[i])++; // We are only inserting one new element. } @@ -6435,7 +6526,7 @@ SpMat::insert_element(const uword in_row, const uword in_col, const eT val) uword* new_row_indices = memory::acquire(n_nonzero + 1); // Copy things over, before the new element. - if (pos > 0) + if(pos > 0) { arrayops::copy(new_values, values, pos); arrayops::copy(new_row_indices, row_indices, pos); @@ -6463,9 +6554,6 @@ SpMat::insert_element(const uword in_row, const uword in_col, const eT val) /** * Delete an element at the given position. - * - * @param in_row Row of element to be deleted. - * @param in_col Column of element to be deleted. */ template inline @@ -6482,13 +6570,13 @@ SpMat::delete_element(const uword in_row, const uword in_col) uword colptr = col_ptrs[in_col]; uword next_colptr = col_ptrs[in_col + 1]; - if (colptr != next_colptr) + if(colptr != next_colptr) { // There's at least one element in this column. // Let's see if we are one of them. - for (uword pos = colptr; pos < next_colptr; pos++) + for(uword pos = colptr; pos < next_colptr; pos++) { - if (in_row == row_indices[pos]) + if(in_row == row_indices[pos]) { --access::rw(n_nonzero); // Remove one from the count of nonzero elements. @@ -6498,7 +6586,7 @@ SpMat::delete_element(const uword in_row, const uword in_col) eT* new_values = memory::acquire (n_nonzero + 1); uword* new_row_indices = memory::acquire(n_nonzero + 1); - if (pos > 0) + if(pos > 0) { arrayops::copy(new_values, values, pos); arrayops::copy(new_row_indices, row_indices, pos); @@ -6514,7 +6602,7 @@ SpMat::delete_element(const uword in_row, const uword in_col) access::rw(row_indices) = new_row_indices; // And lastly, update all the column pointers (decrement by one). - for (uword i = in_col + 1; i < n_cols + 1; i++) + for(uword i = in_col + 1; i < n_cols + 1; i++) { --access::rw(col_ptrs[i]); // We only removed one element. } @@ -6586,15 +6674,13 @@ SpMat::sync_cache() const } } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) { if(sync_state == 0) { - cache_mutex.lock(); + const std::lock_guard lock(cache_mutex); sync_cache_simple(); - - cache_mutex.unlock(); } } #else @@ -6639,14 +6725,12 @@ SpMat::sync_csc() const sync_csc_simple(); } } - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) if(sync_state == 1) { - cache_mutex.lock(); + const std::lock_guard lock(cache_mutex); sync_csc_simple(); - - cache_mutex.unlock(); } #else { @@ -6762,7 +6846,7 @@ SpMat_aux::set_imag(SpMat< std::complex >& out, const SpBase& X) -#ifdef ARMA_EXTRA_SPMAT_MEAT +#if defined(ARMA_EXTRA_SPMAT_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPMAT_MEAT) #endif diff --git a/src/armadillo_bits/SpOp_bones.hpp b/src/armadillo_bits/SpOp_bones.hpp index 4360371f..af8a229b 100644 --- a/src/armadillo_bits/SpOp_bones.hpp +++ b/src/armadillo_bits/SpOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class SpOp : public SpBase > +class SpOp : public SpBase< typename T1::elem_type, SpOp > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = op_type::template traits::is_row; - static const bool is_col = op_type::template traits::is_col; - static const bool is_xvec = op_type::template traits::is_xvec; + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; inline explicit SpOp(const T1& in_m); inline SpOp(const T1& in_m, const elem_type in_aux); diff --git a/src/armadillo_bits/SpOp_meat.hpp b/src/armadillo_bits/SpOp_meat.hpp index f30b0a52..2a6f1f51 100644 --- a/src/armadillo_bits/SpOp_meat.hpp +++ b/src/armadillo_bits/SpOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/SpProxy.hpp b/src/armadillo_bits/SpProxy.hpp index 204b7b16..50adcc80 100644 --- a/src/armadillo_bits/SpProxy.hpp +++ b/src/armadillo_bits/SpProxy.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -64,10 +66,8 @@ template -class SpProxy< SpMat > +struct SpProxy< SpMat > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpMat stored_type; @@ -75,12 +75,12 @@ class SpProxy< SpMat > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = false; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const SpMat& Q; @@ -118,10 +118,8 @@ class SpProxy< SpMat > template -class SpProxy< SpCol > +struct SpProxy< SpCol > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpCol stored_type; @@ -129,12 +127,12 @@ class SpProxy< SpCol > typedef typename SpCol::const_iterator const_iterator_type; typedef typename SpCol::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = false; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const SpCol& Q; @@ -146,7 +144,7 @@ class SpProxy< SpCol > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } @@ -172,10 +170,8 @@ class SpProxy< SpCol > template -class SpProxy< SpRow > +struct SpProxy< SpRow > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpRow stored_type; @@ -183,12 +179,12 @@ class SpProxy< SpRow > typedef typename SpRow::const_iterator const_iterator_type; typedef typename SpRow::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = false; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = false; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const SpRow& Q; @@ -199,7 +195,7 @@ class SpProxy< SpRow > Q.sync(); } - arma_inline uword get_n_rows() const { return 1; } + constexpr uword get_n_rows() const { return 1; } arma_inline uword get_n_cols() const { return Q.n_cols; } arma_inline uword get_n_elem() const { return Q.n_elem; } arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } @@ -226,10 +222,8 @@ class SpProxy< SpRow > template -class SpProxy< SpSubview > +struct SpProxy< SpSubview > { - public: - typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpSubview stored_type; @@ -237,12 +231,12 @@ class SpProxy< SpSubview > typedef typename SpSubview::const_iterator const_iterator_type; typedef typename SpSubview::const_row_iterator const_row_iterator_type; - static const bool use_iterator = true; - static const bool Q_is_generated = false; + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const SpSubview& Q; @@ -280,10 +274,163 @@ class SpProxy< SpSubview > template -class SpProxy< spdiagview > +struct SpProxy< SpSubview_col > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpSubview_col stored_type; + + typedef typename SpSubview::const_iterator const_iterator_type; + typedef typename SpSubview::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + + arma_aligned const SpSubview_col& Q; + + inline explicit SpProxy(const SpSubview_col& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.m.sync(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + constexpr uword get_n_cols() const { return 1; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q.at(i, 0); } + arma_inline elem_type at (const uword row, const uword) const { return Q.at(row, 0); } + + arma_inline const eT* get_values() const { return Q.m.values; } + arma_inline const uword* get_row_indices() const { return Q.m.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.m.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q.m) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< SpSubview_col_list > + { + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpMat stored_type; + + typedef typename SpMat::const_iterator const_iterator_type; + typedef typename SpMat::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + arma_aligned const SpMat Q; + + inline explicit SpProxy(const SpSubview_col_list& A) + : Q(A) + { + arma_extra_debug_sigprint(); + } + + arma_inline uword get_n_rows() const { return Q.n_rows; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q[i]; } + arma_inline elem_type at (const uword row, const uword col) const { return Q.at(row, col); } + + arma_inline const eT* get_values() const { return Q.values; } + arma_inline const uword* get_row_indices() const { return Q.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + constexpr bool is_alias(const SpMat&) const { return false; } + }; + + + +template +struct SpProxy< SpSubview_row > { - public: + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + typedef SpSubview_row stored_type; + + typedef typename SpSubview::const_iterator const_iterator_type; + typedef typename SpSubview::const_row_iterator const_row_iterator_type; + + static constexpr bool use_iterator = true; + static constexpr bool Q_is_generated = false; + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + arma_aligned const SpSubview_row& Q; + + inline explicit SpProxy(const SpSubview_row& A) + : Q(A) + { + arma_extra_debug_sigprint(); + Q.m.sync(); + } + + constexpr uword get_n_rows() const { return 1; } + arma_inline uword get_n_cols() const { return Q.n_cols; } + arma_inline uword get_n_elem() const { return Q.n_elem; } + arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } + + arma_inline elem_type operator[](const uword i) const { return Q.at(0, i ); } + arma_inline elem_type at (const uword, const uword col) const { return Q.at(0, col); } + + arma_inline const eT* get_values() const { return Q.m.values; } + arma_inline const uword* get_row_indices() const { return Q.m.row_indices; } + arma_inline const uword* get_col_ptrs() const { return Q.m.col_ptrs; } + + arma_inline const_iterator_type begin() const { return Q.begin(); } + arma_inline const_iterator_type begin_col(const uword col_num) const { return Q.begin_col(col_num); } + arma_inline const_row_iterator_type begin_row(const uword row_num = 0) const { return Q.begin_row(row_num); } + + arma_inline const_iterator_type end() const { return Q.end(); } + arma_inline const_row_iterator_type end_row() const { return Q.end_row(); } + arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } + + template + arma_inline bool is_alias(const SpMat& X) const { return (void_ptr(&Q.m) == void_ptr(&X)); } + }; + + + +template +struct SpProxy< spdiagview > + { typedef eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpMat stored_type; @@ -291,12 +438,12 @@ class SpProxy< spdiagview > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = true; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const SpMat Q; @@ -307,7 +454,7 @@ class SpProxy< spdiagview > } arma_inline uword get_n_rows() const { return Q.n_rows; } - arma_inline uword get_n_cols() const { return 1; } + constexpr uword get_n_cols() const { return 1; } arma_inline uword get_n_elem() const { return Q.n_elem; } arma_inline uword get_n_nonzero() const { return Q.n_nonzero; } @@ -327,16 +474,14 @@ class SpProxy< spdiagview > arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; template -class SpProxy< SpOp > +struct SpProxy< SpOp > { - public: - typedef typename T1::elem_type elem_type; typedef typename T1::elem_type eT; typedef typename get_pod_type::result pod_type; @@ -345,12 +490,12 @@ class SpProxy< SpOp > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = true; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; - static const bool is_row = SpOp::is_row; - static const bool is_col = SpOp::is_col; - static const bool is_xvec = SpOp::is_xvec; + static constexpr bool is_row = SpOp::is_row; + static constexpr bool is_col = SpOp::is_col; + static constexpr bool is_xvec = SpOp::is_xvec; arma_aligned const SpMat Q; @@ -381,16 +526,14 @@ class SpProxy< SpOp > arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; template -class SpProxy< SpGlue > +struct SpProxy< SpGlue > { - public: - typedef typename T1::elem_type elem_type; typedef typename T1::elem_type eT; typedef typename get_pod_type::result pod_type; @@ -399,12 +542,12 @@ class SpProxy< SpGlue > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = true; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; - static const bool is_row = SpGlue::is_row; - static const bool is_col = SpGlue::is_col; - static const bool is_xvec = SpGlue::is_xvec; + static constexpr bool is_row = SpGlue::is_row; + static constexpr bool is_col = SpGlue::is_col; + static constexpr bool is_xvec = SpGlue::is_xvec; arma_aligned const SpMat Q; @@ -435,16 +578,14 @@ class SpProxy< SpGlue > arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; template -class SpProxy< mtSpOp > +struct SpProxy< mtSpOp > { - public: - typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpMat stored_type; @@ -452,12 +593,12 @@ class SpProxy< mtSpOp > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = true; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; - static const bool is_row = mtSpOp::is_row; - static const bool is_col = mtSpOp::is_col; - static const bool is_xvec = mtSpOp::is_xvec; + static constexpr bool is_row = mtSpOp::is_row; + static constexpr bool is_col = mtSpOp::is_col; + static constexpr bool is_xvec = mtSpOp::is_xvec; arma_aligned const SpMat Q; @@ -488,16 +629,14 @@ class SpProxy< mtSpOp > arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; template -class SpProxy< mtSpGlue > +struct SpProxy< mtSpGlue > { - public: - typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; typedef SpMat stored_type; @@ -505,12 +644,12 @@ class SpProxy< mtSpGlue > typedef typename SpMat::const_iterator const_iterator_type; typedef typename SpMat::const_row_iterator const_row_iterator_type; - static const bool use_iterator = false; - static const bool Q_is_generated = true; + static constexpr bool use_iterator = false; + static constexpr bool Q_is_generated = true; - static const bool is_row = mtSpGlue::is_row; - static const bool is_col = mtSpGlue::is_col; - static const bool is_xvec = mtSpGlue::is_xvec; + static constexpr bool is_row = mtSpGlue::is_row; + static constexpr bool is_col = mtSpGlue::is_col; + static constexpr bool is_xvec = mtSpGlue::is_xvec; arma_aligned const SpMat Q; @@ -541,7 +680,7 @@ class SpProxy< mtSpGlue > arma_inline const_row_iterator_type end_row(const uword row_num) const { return Q.end_row(row_num); } template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; diff --git a/src/armadillo_bits/SpRow_bones.hpp b/src/armadillo_bits/SpRow_bones.hpp index b15efb52..c13de1b4 100644 --- a/src/armadillo_bits/SpRow_bones.hpp +++ b/src/armadillo_bits/SpRow_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class SpRow : public SpMat typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; inline SpRow(); @@ -54,6 +56,10 @@ class SpRow : public SpMat template inline explicit SpRow(const SpBase& A, const SpBase& B); + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + inline void shed_col (const uword col_num); inline void shed_cols(const uword in_col1, const uword in_col2); @@ -69,7 +75,7 @@ class SpRow : public SpMat inline row_iterator end_row(const uword row_num = 0); inline const_row_iterator end_row(const uword row_num = 0) const; - #ifdef ARMA_EXTRA_SPROW_PROTO + #if defined(ARMA_EXTRA_SPROW_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPROW_PROTO) #endif }; diff --git a/src/armadillo_bits/SpRow_meat.hpp b/src/armadillo_bits/SpRow_meat.hpp index c6b93c12..10f052f7 100644 --- a/src/armadillo_bits/SpRow_meat.hpp +++ b/src/armadillo_bits/SpRow_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -200,6 +202,36 @@ SpRow::SpRow +template +inline +const SpOp,spop_htrans> +SpRow::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpRow::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpRow::st() const + { + return SpOp,spop_strans>(*this); + } + + + //! remove specified columns template inline @@ -208,7 +240,7 @@ SpRow::shed_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= SpMat::n_cols, "SpRow::shed_col(): out of bounds"); + arma_debug_check_bounds( col_num >= SpMat::n_cols, "SpRow::shed_col(): out of bounds" ); shed_cols(col_num, col_num); } @@ -223,7 +255,7 @@ SpRow::shed_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= SpMat::n_cols), "SpRow::shed_cols(): indices out of bounds or incorrectly used" @@ -237,7 +269,7 @@ SpRow::shed_cols(const uword in_col1, const uword in_col2) const uword start = SpMat::col_ptrs[in_col1]; const uword end = SpMat::col_ptrs[in_col2 + 1]; - if (start != end) + if(start != end) { const uword elem_diff = end - start; @@ -245,14 +277,14 @@ SpRow::shed_cols(const uword in_col1, const uword in_col2) uword* new_row_indices = memory::acquire(SpMat::n_nonzero - elem_diff); // Copy first set of elements, if necessary. - if (start > 0) + if(start > 0) { arrayops::copy(new_values, SpMat::values, start); arrayops::copy(new_row_indices, SpMat::row_indices, start); } // Copy last set of elements, if necessary. - if (end != SpMat::n_nonzero) + if(end != SpMat::n_nonzero) { arrayops::copy(new_values + start, SpMat::values + end, (SpMat::n_nonzero - end)); arrayops::copy(new_row_indices + start, SpMat::row_indices + end, (SpMat::n_nonzero - end)); @@ -271,13 +303,13 @@ SpRow::shed_cols(const uword in_col1, const uword in_col2) uword* new_col_ptrs = memory::acquire(SpMat::n_cols - diff + 1); // Copy first part of column pointers. - if (in_col1 > 0) + if(in_col1 > 0) { arrayops::copy(new_col_ptrs, SpMat::col_ptrs, in_col1); } // Copy last part of column pointers (and adjust their values as necessary). - if (in_col2 < SpMat::n_cols - 1) + if(in_col2 < SpMat::n_cols - 1) { arrayops::copy(new_col_ptrs + in_col1, SpMat::col_ptrs + in_col2 + 1, SpMat::n_cols - in_col2); // Modify their values. @@ -306,9 +338,9 @@ SpRow::shed_cols(const uword in_col1, const uword in_col2) // arma_extra_debug_sigprint(); // // // insertion at col_num == n_cols is in effect an append operation -// arma_debug_check( (col_num > SpMat::n_cols), "SpRow::insert_cols(): out of bounds"); +// arma_debug_check_bounds( (col_num > SpMat::n_cols), "SpRow::insert_cols(): out of bounds" ); // -// arma_debug_check( (set_to_zero == false), "SpRow::insert_cols(): cannot set elements to nonzero values"); +// arma_debug_check( (set_to_zero == false), "SpRow::insert_cols(): cannot set elements to nonzero values" ); // // uword newVal = (col_num == 0) ? 0 : SpMat::col_ptrs[col_num]; // SpMat::col_ptrs.insert(col_num, N, newVal); @@ -336,7 +368,7 @@ SpRow::begin_row(const uword row_num) // Since this is a row, row_num can only be 0. But the option is provided for // compatibility. - arma_debug_check((row_num >= 1), "SpRow::begin_row(): index out of bounds"); + arma_debug_check_bounds((row_num >= 1), "SpRow::begin_row(): index out of bounds"); return SpMat::begin(); } @@ -352,7 +384,7 @@ SpRow::begin_row(const uword row_num) const // Since this is a row, row_num can only be 0. But the option is provided for // compatibility. - arma_debug_check((row_num >= 1), "SpRow::begin_row(): index out of bounds"); + arma_debug_check_bounds((row_num >= 1), "SpRow::begin_row(): index out of bounds"); return SpMat::begin(); } @@ -368,7 +400,7 @@ SpRow::end_row(const uword row_num) // Since this is a row, row_num can only be 0. But the option is provided for // compatibility. - arma_debug_check((row_num >= 1), "SpRow::end_row(): index out of bounds"); + arma_debug_check_bounds((row_num >= 1), "SpRow::end_row(): index out of bounds"); return SpMat::end(); } @@ -384,7 +416,7 @@ SpRow::end_row(const uword row_num) const // Since this is a row, row_num can only be 0. But the option is provided for // compatibility. - arma_debug_check((row_num >= 1), "SpRow::end_row(): index out of bounds"); + arma_debug_check_bounds((row_num >= 1), "SpRow::end_row(): index out of bounds"); return SpMat::end(); } @@ -392,7 +424,7 @@ SpRow::end_row(const uword row_num) const -#ifdef ARMA_EXTRA_SPROW_MEAT +#if defined(ARMA_EXTRA_SPROW_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_SPROW_MEAT) #endif diff --git a/src/armadillo_bits/SpSubview_bones.hpp b/src/armadillo_bits/SpSubview_bones.hpp index 5e549a20..6be50b3a 100644 --- a/src/armadillo_bits/SpSubview_bones.hpp +++ b/src/armadillo_bits/SpSubview_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,7 @@ template -class SpSubview : public SpBase > +class SpSubview : public SpBase< eT, SpSubview > { public: @@ -28,9 +30,9 @@ class SpSubview : public SpBase > typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const uword aux_row1; const uword aux_col1; @@ -39,31 +41,32 @@ class SpSubview : public SpBase > const uword n_elem; const uword n_nonzero; - friend class SpValProxy< SpSubview >; // allow SpValProxy to call insert_element() and delete_element() - protected: - - arma_inline SpSubview(const SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); - arma_inline SpSubview( SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); - + + inline SpSubview(const SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols); + public: - + inline ~SpSubview(); - + inline SpSubview() = delete; + + inline SpSubview(const SpSubview& in); + inline SpSubview( SpSubview&& in); + inline const SpSubview& operator+= (const eT val); inline const SpSubview& operator-= (const eT val); inline const SpSubview& operator*= (const eT val); inline const SpSubview& operator/= (const eT val); - + inline const SpSubview& operator=(const SpSubview& x); - + template inline const SpSubview& operator= (const Base& x); template inline const SpSubview& operator+=(const Base& x); template inline const SpSubview& operator-=(const Base& x); template inline const SpSubview& operator*=(const Base& x); template inline const SpSubview& operator%=(const Base& x); template inline const SpSubview& operator/=(const Base& x); - + template inline const SpSubview& operator_equ_common(const SpBase& x); template inline const SpSubview& operator= (const SpBase& x); @@ -75,7 +78,7 @@ class SpSubview : public SpBase > /* inline static void extract(SpMat& out, const SpSubview& in); - + inline static void plus_inplace(Mat& out, const subview& in); inline static void minus_inplace(Mat& out, const subview& in); inline static void schur_inplace(Mat& out, const subview& in); @@ -91,205 +94,210 @@ class SpSubview : public SpBase > inline void clean(const pod_type threshold); + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); inline void eye(); - - arma_hot inline SpSubview_MapMat_val operator[](const uword i); - arma_hot inline eT operator[](const uword i) const; - - arma_hot inline SpSubview_MapMat_val operator()(const uword i); - arma_hot inline eT operator()(const uword i) const; - - arma_hot inline SpSubview_MapMat_val operator()(const uword in_row, const uword in_col); - arma_hot inline eT operator()(const uword in_row, const uword in_col) const; - - arma_hot inline SpSubview_MapMat_val at(const uword i); - arma_hot inline eT at(const uword i) const; - - arma_hot inline SpSubview_MapMat_val at(const uword in_row, const uword in_col); - arma_hot inline eT at(const uword in_row, const uword in_col) const; - + inline void randu(); + inline void randn(); + + + arma_warn_unused inline SpSubview_MapMat_val operator[](const uword i); + arma_warn_unused inline eT operator[](const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val operator()(const uword i); + arma_warn_unused inline eT operator()(const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline SpSubview_MapMat_val at(const uword i); + arma_warn_unused inline eT at(const uword i) const; + + arma_warn_unused inline SpSubview_MapMat_val at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + inline bool check_overlap(const SpSubview& x) const; - - inline bool is_vec() const; - - inline SpSubview row(const uword row_num); - inline const SpSubview row(const uword row_num) const; - - inline SpSubview col(const uword col_num); - inline const SpSubview col(const uword col_num) const; - + + arma_warn_unused inline bool is_vec() const; + + inline SpSubview_row row(const uword row_num); + inline const SpSubview_row row(const uword row_num) const; + + inline SpSubview_col col(const uword col_num); + inline const SpSubview_col col(const uword col_num) const; + inline SpSubview rows(const uword in_row1, const uword in_row2); inline const SpSubview rows(const uword in_row1, const uword in_row2) const; - + inline SpSubview cols(const uword in_col1, const uword in_col2); inline const SpSubview cols(const uword in_col1, const uword in_col2) const; - + inline SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2); inline const SpSubview submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const; - + inline SpSubview submat(const span& row_span, const span& col_span); inline const SpSubview submat(const span& row_span, const span& col_span) const; - + inline SpSubview operator()(const uword row_num, const span& col_span); inline const SpSubview operator()(const uword row_num, const span& col_span) const; - + inline SpSubview operator()(const span& row_span, const uword col_num); inline const SpSubview operator()(const span& row_span, const uword col_num) const; - + inline SpSubview operator()(const span& row_span, const span& col_span); inline const SpSubview operator()(const span& row_span, const span& col_span) const; inline void swap_rows(const uword in_row1, const uword in_row2); inline void swap_cols(const uword in_col1, const uword in_col2); - + // Forward declarations. class iterator_base; class const_iterator; class iterator; class const_row_iterator; class row_iterator; - + // Similar to SpMat iterators but automatically iterates past and ignores values not in the subview. class iterator_base { public: - + inline iterator_base(const SpSubview& in_M); inline iterator_base(const SpSubview& in_M, const uword col, const uword pos); - + arma_inline uword col() const { return internal_col; } arma_inline uword pos() const { return internal_pos; } - + arma_aligned const SpSubview* M; arma_aligned uword internal_col; arma_aligned uword internal_pos; - + typedef std::bidirectional_iterator_tag iterator_category; typedef eT value_type; typedef std::ptrdiff_t difference_type; // TODO: not certain on this one typedef const eT* pointer; typedef const eT& reference; }; - + class const_iterator : public iterator_base { public: - + inline const_iterator(const SpSubview& in_M, uword initial_pos = 0); inline const_iterator(const SpSubview& in_M, uword in_row, uword in_col); inline const_iterator(const SpSubview& in_M, uword in_row, uword in_col, uword in_pos, uword skip_pos); inline const_iterator(const const_iterator& other); - + arma_inline eT operator*() const; - + // Don't hold location internally; call "dummy" methods to get that information. arma_inline uword row() const { return iterator_base::M->m.row_indices[iterator_base::internal_pos + skip_pos] - iterator_base::M->aux_row1; } - - inline arma_hot const_iterator& operator++(); - inline arma_warn_unused const_iterator operator++(int); - - inline arma_hot const_iterator& operator--(); - inline arma_warn_unused const_iterator operator--(int); - - inline arma_hot bool operator!=(const const_iterator& rhs) const; - inline arma_hot bool operator==(const const_iterator& rhs) const; - - inline arma_hot bool operator!=(const typename SpMat::const_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpMat::const_iterator& rhs) const; - - inline arma_hot bool operator!=(const const_row_iterator& rhs) const; - inline arma_hot bool operator==(const const_row_iterator& rhs) const; - - inline arma_hot bool operator!=(const typename SpMat::const_row_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpMat::const_row_iterator& rhs) const; - + + arma_hot inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); + + arma_hot inline const_iterator& operator--(); + arma_warn_unused inline const_iterator operator--(int); + + arma_hot inline bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_iterator& rhs) const; + + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_row_iterator& rhs) const; + arma_aligned uword skip_pos; // not used in row_iterator or const_row_iterator }; - + class iterator : public const_iterator { public: - + inline iterator(SpSubview& in_M, const uword initial_pos = 0) : const_iterator(in_M, initial_pos) { } inline iterator(SpSubview& in_M, const uword in_row, const uword in_col) : const_iterator(in_M, in_row, in_col) { } inline iterator(SpSubview& in_M, const uword in_row, const uword in_col, const uword in_pos, const uword in_skip_pos) : const_iterator(in_M, in_row, in_col, in_pos, in_skip_pos) { } inline iterator(const iterator& other) : const_iterator(other) { } - - inline arma_hot SpValProxy > operator*(); - + + arma_hot inline SpValProxy< SpSubview > operator*(); + // overloads needed for return type correctness - inline arma_hot iterator& operator++(); - inline arma_warn_unused iterator operator++(int); - - inline arma_hot iterator& operator--(); - inline arma_warn_unused iterator operator--(int); - + arma_hot inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); + + arma_hot inline iterator& operator--(); + arma_warn_unused inline iterator operator--(int); + // This has a different value_type than iterator_base. - typedef SpValProxy > value_type; - typedef const SpValProxy >* pointer; - typedef const SpValProxy >& reference; + typedef SpValProxy< SpSubview > value_type; + typedef const SpValProxy< SpSubview >* pointer; + typedef const SpValProxy< SpSubview >& reference; }; - + class const_row_iterator : public iterator_base { public: - + inline const_row_iterator(); inline const_row_iterator(const SpSubview& in_M, uword initial_pos = 0); inline const_row_iterator(const SpSubview& in_M, uword in_row, uword in_col); inline const_row_iterator(const const_row_iterator& other); - - inline arma_hot const_row_iterator& operator++(); - inline arma_warn_unused const_row_iterator operator++(int); - - inline arma_hot const_row_iterator& operator--(); - inline arma_warn_unused const_row_iterator operator--(int); - + + arma_hot inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); + + arma_hot inline const_row_iterator& operator--(); + arma_warn_unused inline const_row_iterator operator--(int); + uword internal_row; // Hold row internally because we use internal_pos differently. uword actual_pos; // Actual position in subview's parent matrix. - + arma_inline eT operator*() const { return iterator_base::M->m.values[actual_pos]; } - + arma_inline uword row() const { return internal_row; } - - inline arma_hot bool operator!=(const const_iterator& rhs) const; - inline arma_hot bool operator==(const const_iterator& rhs) const; - - inline arma_hot bool operator!=(const typename SpMat::const_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpMat::const_iterator& rhs) const; - - inline arma_hot bool operator!=(const const_row_iterator& rhs) const; - inline arma_hot bool operator==(const const_row_iterator& rhs) const; - - inline arma_hot bool operator!=(const typename SpMat::const_row_iterator& rhs) const; - inline arma_hot bool operator==(const typename SpMat::const_row_iterator& rhs) const; + + arma_hot inline bool operator!=(const const_iterator& rhs) const; + arma_hot inline bool operator==(const const_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_iterator& rhs) const; + + arma_hot inline bool operator!=(const const_row_iterator& rhs) const; + arma_hot inline bool operator==(const const_row_iterator& rhs) const; + + arma_hot inline bool operator!=(const typename SpMat::const_row_iterator& rhs) const; + arma_hot inline bool operator==(const typename SpMat::const_row_iterator& rhs) const; }; - + class row_iterator : public const_row_iterator { public: - + inline row_iterator(SpSubview& in_M, uword initial_pos = 0) : const_row_iterator(in_M, initial_pos) { } inline row_iterator(SpSubview& in_M, uword in_row, uword in_col) : const_row_iterator(in_M, in_row, in_col) { } inline row_iterator(const row_iterator& other) : const_row_iterator(other) { } - - inline arma_hot SpValProxy > operator*(); - + + arma_hot inline SpValProxy< SpSubview > operator*(); + // overloads needed for return type correctness - inline arma_hot row_iterator& operator++(); - inline arma_warn_unused row_iterator operator++(int); - - inline arma_hot row_iterator& operator--(); - inline arma_warn_unused row_iterator operator--(int); - + arma_hot inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); + + arma_hot inline row_iterator& operator--(); + arma_warn_unused inline row_iterator operator--(int); + // This has a different value_type than iterator_base. - typedef SpValProxy > value_type; - typedef const SpValProxy >* pointer; - typedef const SpValProxy >& reference; + typedef SpValProxy< SpSubview > value_type; + typedef const SpValProxy< SpSubview >* pointer; + typedef const SpValProxy< SpSubview >& reference; }; inline iterator begin(); @@ -298,112 +306,113 @@ class SpSubview : public SpBase > inline iterator begin_col(const uword col_num); inline const_iterator begin_col(const uword col_num) const; - + inline row_iterator begin_row(const uword row_num = 0); inline const_row_iterator begin_row(const uword row_num = 0) const; - + inline iterator end(); inline const_iterator end() const; inline const_iterator cend() const; - + inline row_iterator end_row(); inline const_row_iterator end_row() const; - + inline row_iterator end_row(const uword row_num); inline const_row_iterator end_row(const uword row_num) const; //! don't use this unless you're writing internal Armadillo code arma_inline bool is_alias(const SpMat& X) const; - - + + private: + friend class SpMat; - SpSubview(); + friend class SpSubview_col; + friend class SpSubview_row; + friend class SpValProxy< SpSubview >; // allow SpValProxy to call insert_element() and delete_element() - inline arma_warn_unused eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); - inline void delete_element(const uword in_row, const uword in_col); + arma_warn_unused inline eT& insert_element(const uword in_row, const uword in_col, const eT in_val = eT(0)); + inline void delete_element(const uword in_row, const uword in_col); inline void invalidate_cache() const; }; -/* + + template class SpSubview_col : public SpSubview { public: - + typedef eT elem_type; typedef typename get_pod_type::result pod_type; - + + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; + inline void operator= (const SpSubview& x); inline void operator= (const SpSubview_col& x); - - template - inline void operator= (const Base& x); - - inline SpSubview_col rows(const uword in_row1, const uword in_row2); - inline const SpSubview_col rows(const uword in_row1, const uword in_row2) const; - - inline SpSubview_col subvec(const uword in_row1, const uword in_row2); - inline const SpSubview_col subvec(const uword in_row1, const uword in_row2) const; - - + + template inline void operator= (const SpBase& x); + template inline void operator= (const Base& x); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + protected: - - inline SpSubview_col(const Mat& in_m, const uword in_col); - inline SpSubview_col( Mat& in_m, const uword in_col); - - inline SpSubview_col(const Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); - inline SpSubview_col( Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); - - + + inline SpSubview_col(const SpMat& in_m, const uword in_col); + inline SpSubview_col(const SpMat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); + inline SpSubview_col() = delete; + + private: - - friend class Mat; - friend class Col; + + friend class SpMat; friend class SpSubview; - - SpSubview_col(); }; + + template class SpSubview_row : public SpSubview { public: - + typedef eT elem_type; typedef typename get_pod_type::result pod_type; - + + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + inline void operator= (const SpSubview& x); inline void operator= (const SpSubview_row& x); - - template - inline void operator= (const Base& x); - - inline SpSubview_row cols(const uword in_col1, const uword in_col2); - inline const SpSubview_row cols(const uword in_col1, const uword in_col2) const; - - inline SpSubview_row subvec(const uword in_col1, const uword in_col2); - inline const SpSubview_row subvec(const uword in_col1, const uword in_col2) const; - - + + template inline void operator= (const SpBase& x); + template inline void operator= (const Base& x); + + arma_warn_unused inline const SpOp,spop_htrans> t() const; + arma_warn_unused inline const SpOp,spop_htrans> ht() const; + arma_warn_unused inline const SpOp,spop_strans> st() const; + + protected: - - inline SpSubview_row(const Mat& in_m, const uword in_row); - inline SpSubview_row( Mat& in_m, const uword in_row); - - inline SpSubview_row(const Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); - inline SpSubview_row( Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); - - + + inline SpSubview_row(const SpMat& in_m, const uword in_row); + inline SpSubview_row(const SpMat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); + inline SpSubview_row() = delete; + + private: - - friend class Mat; - friend class Row; + + friend class SpMat; friend class SpSubview; - - SpSubview_row(); }; -*/ + + //! @} diff --git a/src/armadillo_bits/SpSubview_col_list_bones.hpp b/src/armadillo_bits/SpSubview_col_list_bones.hpp new file mode 100644 index 00000000..85012913 --- /dev/null +++ b/src/armadillo_bits/SpSubview_col_list_bones.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview_col_list +//! @{ + + + +template +class SpSubview_col_list : public SpBase< eT, SpSubview_col_list > + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + const SpMat& m; + const quasi_unwrap U_ci; + + + protected: + + arma_inline SpSubview_col_list(const SpMat& in_m, const Base& in_ci); + + + public: + + inline ~SpSubview_col_list(); + inline SpSubview_col_list() = delete; + + template inline void for_each(functor F); + template inline void for_each(functor F) const; + + template inline void transform(functor F); + + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void fill(const eT val); + inline void zeros(); + inline void ones(); + + inline void operator+= (const eT val); + inline void operator-= (const eT val); + inline void operator*= (const eT val); + inline void operator/= (const eT val); + + template inline void operator= (const Base& x); + template inline void operator+=(const Base& x); + template inline void operator-=(const Base& x); + template inline void operator%=(const Base& x); + template inline void operator/=(const Base& x); + + inline void operator= (const SpSubview_col_list& x); + template inline void operator= (const SpSubview_col_list& x); + + template inline void operator= (const SpBase& x); + template inline void operator+= (const SpBase& x); + template inline void operator-= (const SpBase& x); + template inline void operator%= (const SpBase& x); + template inline void operator/= (const SpBase& x); + + inline static void extract(SpMat& out, const SpSubview_col_list& in); + + inline static void plus_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void minus_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void schur_inplace(SpMat& out, const SpSubview_col_list& in); + inline static void div_inplace(SpMat& out, const SpSubview_col_list& in); + + + friend class SpMat; + }; + + + +//! @} diff --git a/src/armadillo_bits/SpSubview_col_list_meat.hpp b/src/armadillo_bits/SpSubview_col_list_meat.hpp new file mode 100644 index 00000000..46d2d8d5 --- /dev/null +++ b/src/armadillo_bits/SpSubview_col_list_meat.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpSubview_col_list +//! @{ + + + +template +inline +SpSubview_col_list::~SpSubview_col_list() + { + arma_extra_debug_sigprint(); + } + + + +template +arma_inline +SpSubview_col_list::SpSubview_col_list + ( + const SpMat& in_m, + const Base& in_ci + ) + : m (in_m ) + , U_ci(in_ci.get_ref()) + { + arma_extra_debug_sigprint(); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_check + ( + ( (ci.is_vec() == false) && (ci.is_empty() == false) ), + "SpMat::cols(): given object must be a vector" + ); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + arma_debug_check_bounds( (i >= in_m.n_cols), "SpMat::cols(): index out of bounds" ); + } + } + + + +//! apply a functor to each element +template +template +inline +void +SpSubview_col_list::for_each(functor F) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.for_each(F); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::for_each(functor F) const + { + arma_extra_debug_sigprint(); + + const SpMat tmp(*this); + + tmp.for_each(F); + } + + + +//! transform each element using a functor +template +template +inline +void +SpSubview_col_list::transform(functor F) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.transform(F); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::fill(const eT val) + { + arma_extra_debug_sigprint(); + + Mat tmp(m.n_rows, U_ci.M.n_elem, arma_nozeros_indicator()); tmp.fill(val); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::zeros() + { + arma_extra_debug_sigprint(); + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + m_local.sync_csc(); + m_local.invalidate_cache(); + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = m_local.col_ptrs[i+1] - m_local.col_ptrs[i]; + + uword offset = m_local.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + access::rw(m_local.values[offset]) = eT(0); + + ++offset; + } + } + + m_local.remove_zeros(); + } + + + +template +inline +void +SpSubview_col_list::ones() + { + arma_extra_debug_sigprint(); + + const Mat tmp(m.n_rows, U_ci.M.n_elem, fill::ones); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::operator+= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const Mat tmp3 = tmp1 + tmp2; + + (*this).operator=(tmp3); + } + + + +template +inline +void +SpSubview_col_list::operator-= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const Mat tmp3 = tmp1 - tmp2; + + (*this).operator=(tmp3); + } + + + +template +inline +void +SpSubview_col_list::operator*= (const eT val) + { + arma_extra_debug_sigprint(); + + if(val == eT(0)) { (*this).zeros(); return; } + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + m_local.sync_csc(); + m_local.invalidate_cache(); + + bool has_zero = false; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = m_local.col_ptrs[i+1] - m_local.col_ptrs[i]; + + uword offset = m_local.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + eT& m_local_val = access::rw(m_local.values[offset]); + + m_local_val *= val; + + if(m_local_val == eT(0)) { has_zero = true; } + + ++offset; + } + } + + if(has_zero) { m_local.remove_zeros(); } + } + + + +template +inline +void +SpSubview_col_list::operator/= (const eT val) + { + arma_extra_debug_sigprint(); + + const SpMat tmp1(*this); + + Mat tmp2(tmp1.n_rows, tmp1.n_cols, arma_nozeros_indicator()); tmp2.fill(val); + + const SpMat tmp3 = tmp1 / tmp2; + + (*this).operator=(tmp3); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const Base& x) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap U(x.get_ref()); + const Mat& X = U.M; + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( m_local.n_rows, ci_n_elem, X.n_rows, X.n_cols, "SpMat::cols()" ); + + const uword X_n_elem = X.n_elem; + const eT* X_mem = X.memptr(); + + uword X_n_nonzero = 0; + + for(uword i=0; i < X_n_elem; ++i) { X_n_nonzero += (X_mem[i] != eT(0)) ? uword(1) : uword(0); } + + SpMat Y(arma_reserve_indicator(), X.n_rows, m_local.n_cols, X_n_nonzero); + + uword count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + for(uword row=0; row < X.n_rows; ++row) + { + const eT X_val = (*X_mem); ++X_mem; + + if(X_val != eT(0)) + { + access::rw(Y.row_indices[count]) = row; + access::rw(Y.values [count]) = X_val; + ++count; + ++access::rw(Y.col_ptrs[i + 1]); + } + } + } + + // fix the column pointers + for(uword i = 0; i < Y.n_cols; ++i) + { + access::rw(Y.col_ptrs[i+1]) += Y.col_ptrs[i]; + } + + (*this).zeros(); + + SpMat tmp = m_local + Y; + + m_local.steal_mem(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator+= (const Base& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp = SpMat(*this) + x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator-= (const Base& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp = SpMat(*this) - x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator%= (const Base& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) % x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator/= (const Base& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) / x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview_col_list::operator= (const SpSubview_col_list& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(x); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const SpSubview_col_list& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(x); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(x.get_ref()); + const SpMat& X = U.M; + + if(U.is_alias(m)) + { + const SpMat tmp(X); + + (*this).operator=(tmp); + + return; + } + + SpMat& m_local = const_cast< SpMat& >(m); + + const umat& ci = U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + arma_debug_assert_same_size( m_local.n_rows, ci_n_elem, X.n_rows, X.n_cols, "SpMat::cols()" ); + + SpMat Y(arma_reserve_indicator(), X.n_rows, m_local.n_cols, X.n_nonzero); + + uword count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + typename SpMat::const_col_iterator X_col_it = X.begin_col(ci_count); + typename SpMat::const_col_iterator X_col_it_end = X.end_col(ci_count); + + while(X_col_it != X_col_it_end) + { + access::rw(Y.row_indices[count]) = X_col_it.row(); + access::rw(Y.values [count]) = (*X_col_it); + ++count; + ++access::rw(Y.col_ptrs[i + 1]); + ++X_col_it; + } + } + + // fix the column pointers + for(uword i = 0; i < Y.n_cols; ++i) + { + access::rw(Y.col_ptrs[i+1]) += Y.col_ptrs[i]; + } + + (*this).zeros(); + + SpMat tmp = m_local + Y; + + m_local.steal_mem(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator+= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) + x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator-= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) - x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator%= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const SpMat tmp = SpMat(*this) % x.get_ref(); + + (*this).operator=(tmp); + } + + + +template +template +inline +void +SpSubview_col_list::operator/= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp /= x.get_ref(); + + (*this).operator=(tmp); + } + + + +// +// + + + +template +inline +void +SpSubview_col_list::extract(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + // NOTE: aliasing is handled by SpMat::operator=(const SpSubview_col_list& in) + + const umat& ci = in.U_ci.M; + const uword* ci_mem = ci.memptr(); + const uword ci_n_elem = ci.n_elem; + + const SpMat& in_m = in.m; + + in_m.sync_csc(); + + uword total_n_nonzero = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = in_m.col_ptrs[i+1] - in_m.col_ptrs[i]; + + total_n_nonzero += col_n_nonzero; + } + + out.reserve(in.m.n_rows, ci_n_elem, total_n_nonzero); + + uword out_n_nonzero = 0; + uword out_col_count = 0; + + for(uword ci_count=0; ci_count < ci_n_elem; ++ci_count) + { + const uword i = ci_mem[ci_count]; + + const uword col_n_nonzero = in_m.col_ptrs[i+1] - in_m.col_ptrs[i]; + + uword offset = in_m.col_ptrs[i]; + + for(uword j=0; j < col_n_nonzero; ++j) + { + const eT val = in_m.values [ offset ]; + const uword row = in_m.row_indices[ offset ]; + + ++offset; + + access::rw(out.values [out_n_nonzero]) = val; + access::rw(out.row_indices[out_n_nonzero]) = row; + + access::rw(out.col_ptrs[out_col_count+1])++; + + ++out_n_nonzero; + } + + ++out_col_count; + } + + // fix the column pointers + for(uword i = 0; i < out.n_cols; ++i) + { + access::rw(out.col_ptrs[i+1]) += out.col_ptrs[i]; + } + } + + + +template +inline +void +SpSubview_col_list::plus_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out += tmp; + } + + + +template +inline +void +SpSubview_col_list::minus_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out -= tmp; + } + + + +template +inline +void +SpSubview_col_list::schur_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out %= tmp; + } + + + +template +inline +void +SpSubview_col_list::div_inplace(SpMat& out, const SpSubview_col_list& in) + { + arma_extra_debug_sigprint(); + + const SpMat tmp(in); + + out /= tmp; + } + + + +//! @} diff --git a/src/armadillo_bits/SpSubview_iterators_meat.hpp b/src/armadillo_bits/SpSubview_iterators_meat.hpp index cd583c25..d97d7c63 100644 --- a/src/armadillo_bits/SpSubview_iterators_meat.hpp +++ b/src/armadillo_bits/SpSubview_iterators_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -278,7 +280,6 @@ SpSubview::const_iterator::operator*() const template inline -arma_hot typename SpSubview::const_iterator& SpSubview::const_iterator::operator++() { @@ -337,7 +338,6 @@ SpSubview::const_iterator::operator++() template inline -arma_warn_unused typename SpSubview::const_iterator SpSubview::const_iterator::operator++(int) { @@ -352,7 +352,6 @@ SpSubview::const_iterator::operator++(int) template inline -arma_hot typename SpSubview::const_iterator& SpSubview::const_iterator::operator--() { @@ -408,7 +407,6 @@ SpSubview::const_iterator::operator--() template inline -arma_warn_unused typename SpSubview::const_iterator SpSubview::const_iterator::operator--(int) { @@ -423,7 +421,6 @@ SpSubview::const_iterator::operator--(int) template inline -arma_hot bool SpSubview::const_iterator::operator==(const const_iterator& rhs) const { @@ -434,7 +431,6 @@ SpSubview::const_iterator::operator==(const const_iterator& rhs) const template inline -arma_hot bool SpSubview::const_iterator::operator!=(const const_iterator& rhs) const { @@ -445,7 +441,6 @@ SpSubview::const_iterator::operator!=(const const_iterator& rhs) const template inline -arma_hot bool SpSubview::const_iterator::operator==(const typename SpMat::const_iterator& rhs) const { @@ -456,7 +451,6 @@ SpSubview::const_iterator::operator==(const typename SpMat::const_iterat template inline -arma_hot bool SpSubview::const_iterator::operator!=(const typename SpMat::const_iterator& rhs) const { @@ -467,7 +461,6 @@ SpSubview::const_iterator::operator!=(const typename SpMat::const_iterat template inline -arma_hot bool SpSubview::const_iterator::operator==(const const_row_iterator& rhs) const { @@ -478,7 +471,6 @@ SpSubview::const_iterator::operator==(const const_row_iterator& rhs) const template inline -arma_hot bool SpSubview::const_iterator::operator!=(const const_row_iterator& rhs) const { @@ -489,7 +481,6 @@ SpSubview::const_iterator::operator!=(const const_row_iterator& rhs) const template inline -arma_hot bool SpSubview::const_iterator::operator==(const typename SpMat::const_row_iterator& rhs) const { @@ -500,7 +491,6 @@ SpSubview::const_iterator::operator==(const typename SpMat::const_row_it template inline -arma_hot bool SpSubview::const_iterator::operator!=(const typename SpMat::const_row_iterator& rhs) const { @@ -515,11 +505,10 @@ SpSubview::const_iterator::operator!=(const typename SpMat::const_row_it template inline -arma_hot -SpValProxy > +SpValProxy< SpSubview > SpSubview::iterator::operator*() { - return SpValProxy >( + return SpValProxy< SpSubview >( const_iterator::row(), iterator_base::col(), access::rw(*iterator_base::M), @@ -530,7 +519,6 @@ SpSubview::iterator::operator*() template inline -arma_hot typename SpSubview::iterator& SpSubview::iterator::operator++() { @@ -542,7 +530,6 @@ SpSubview::iterator::operator++() template inline -arma_warn_unused typename SpSubview::iterator SpSubview::iterator::operator++(int) { @@ -557,7 +544,6 @@ SpSubview::iterator::operator++(int) template inline -arma_hot typename SpSubview::iterator& SpSubview::iterator::operator--() { @@ -569,7 +555,6 @@ SpSubview::iterator::operator--() template inline -arma_warn_unused typename SpSubview::iterator SpSubview::iterator::operator--(int) { @@ -625,9 +610,9 @@ SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, // Since we don't know where the elements are in each row, we have to loop // across all columns looking for elements in row 0 and add to our sum, then // in row 1, and so forth, until we get to the desired position. - for (uword row = 0; row < iterator_base::M->n_rows; ++row) + for(uword row = 0; row < iterator_base::M->n_rows; ++row) { - for (uword col = 0; col < iterator_base::M->n_cols; ++col) + for(uword col = 0; col < iterator_base::M->n_cols; ++col) { // Find the first element with row greater than or equal to row + aux_row. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; @@ -636,24 +621,24 @@ SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, row + aux_row); const uword offset = uword(pos_ptr - start_ptr); - if (iterator_base::M->m.row_indices[col_offset + offset] == row + aux_row) + if(iterator_base::M->m.row_indices[col_offset + offset] == row + aux_row) { cur_actual_pos = col_offset + offset; // Increment position portably. - if (cur_pos == std::numeric_limits::max()) + if(cur_pos == std::numeric_limits::max()) cur_pos = 0; else ++cur_pos; // Do we terminate? - if (cur_pos == initial_pos) + if(cur_pos == initial_pos) { internal_row = row; iterator_base::internal_col = col; @@ -702,7 +687,7 @@ SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, uword cur_min_col = 0; uword cur_actual_pos = 0; - for (uword col = 0; col < iterator_base::M->n_cols; ++col) + for(uword col = 0; col < iterator_base::M->n_cols; ++col) { // Find the first element with row greater than or equal to in_row. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; @@ -711,12 +696,12 @@ SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { // First let us find the first element that is in the subview. const uword* first_subview_ptr = std::lower_bound(start_ptr, end_ptr, aux_row); - if (first_subview_ptr != end_ptr && (*first_subview_ptr) < aux_row + iterator_base::M->n_rows) + if(first_subview_ptr != end_ptr && (*first_subview_ptr) < aux_row + iterator_base::M->n_rows) { // There exists at least one element in the subview. const uword* pos_ptr = std::lower_bound(first_subview_ptr, end_ptr, aux_row + in_row); @@ -725,15 +710,15 @@ SpSubview::const_row_iterator::const_row_iterator(const SpSubview& in_M, // than in_row. cur_pos += uword(pos_ptr - first_subview_ptr); - if (pos_ptr != end_ptr && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + if(pos_ptr != end_ptr && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // This is the row index of the first element in the column with row // index greater than or equal to in_row + aux_row. - if ((*pos_ptr) - aux_row < cur_min_row) + if((*pos_ptr) - aux_row < cur_min_row) { // If we are in the desired row but before the desired column, we // can't take this. - if (col >= in_col) + if(col >= in_col) { cur_min_row = (*pos_ptr) - aux_row; cur_min_col = col; @@ -768,7 +753,6 @@ SpSubview::const_row_iterator::const_row_iterator(const const_row_iterator& template inline -arma_hot typename SpSubview::const_row_iterator& SpSubview::const_row_iterator::operator++() { @@ -795,7 +779,7 @@ SpSubview::const_row_iterator::operator++() uword next_min_col = 0; uword next_actual_pos = 0; - for (uword col = iterator_base::internal_col + 1; col < M_n_cols; ++col) + for(uword col = iterator_base::internal_col + 1; col < M_n_cols; ++col) { // Find the first element with row greater than or equal to row. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; @@ -804,24 +788,24 @@ SpSubview::const_row_iterator::operator++() const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { // Find the first element in the column with row greater than or equal to // the current row. Since this is a subview, it's possible that we may // find rows past the end of the subview. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row); - if (pos_ptr != end_ptr) + if(pos_ptr != end_ptr) { // We found something; is the row index correct? - if ((*pos_ptr) == internal_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + if((*pos_ptr) == internal_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // Exact match---so we are done. iterator_base::internal_col = col; actual_pos = col_offset + (pos_ptr - start_ptr); return *this; } - else if ((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + else if((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // The first element in this column is in a subsequent row, but it's // the minimum row we've seen so far. @@ -829,7 +813,7 @@ SpSubview::const_row_iterator::operator++() next_min_col = col; next_actual_pos = col_offset + (pos_ptr - start_ptr); } - else if ((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + else if((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // The first element in this column is in a subsequent row that we // already have another elemnt for, but the column index is less so @@ -842,7 +826,7 @@ SpSubview::const_row_iterator::operator++() } // Restart the search in the next row. - for (uword col = 0; col <= iterator_base::internal_col; ++col) + for(uword col = 0; col <= iterator_base::internal_col; ++col) { // Find the first element with row greater than or equal to row + 1. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; @@ -851,14 +835,14 @@ SpSubview::const_row_iterator::operator++() const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row + 1); - if (pos_ptr != end_ptr) + if(pos_ptr != end_ptr) { // We found something in the column, but is the row index correct? - if ((*pos_ptr) == internal_row + aux_row + 1 && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + if((*pos_ptr) == internal_row + aux_row + 1 && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // Exact match---so we are done. iterator_base::internal_col = col; @@ -866,7 +850,7 @@ SpSubview::const_row_iterator::operator++() actual_pos = col_offset + (pos_ptr - start_ptr); return *this; } - else if ((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + else if((*pos_ptr) < next_min_row + aux_row && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // The first element in this column is in a subsequent row, but it's // the minimum row we've seen so far. @@ -874,7 +858,7 @@ SpSubview::const_row_iterator::operator++() next_min_col = col; next_actual_pos = col_offset + (pos_ptr - start_ptr); } - else if ((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) + else if((*pos_ptr) == next_min_row + aux_row && col < next_min_col && (*pos_ptr) < aux_row + iterator_base::M->n_rows) { // We've found a better column. next_min_col = col; @@ -895,7 +879,6 @@ SpSubview::const_row_iterator::operator++() template inline -arma_warn_unused typename SpSubview::const_row_iterator SpSubview::const_row_iterator::operator++(int) { @@ -910,11 +893,10 @@ SpSubview::const_row_iterator::operator++(int) template inline -arma_hot typename SpSubview::const_row_iterator& SpSubview::const_row_iterator::operator--() { - if (iterator_base::internal_pos == 0) + if(iterator_base::internal_pos == 0) { // We are already at the beginning. return *this; @@ -930,7 +912,7 @@ SpSubview::const_row_iterator::operator--() uword max_col = 0; uword next_actual_pos = 0; - for (uword col = iterator_base::internal_col; col >= 1; --col) + for(uword col = iterator_base::internal_col; col >= 1; --col) { // Find the first element with row greater than or equal to in_row + 1. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col - 1]; @@ -939,21 +921,21 @@ SpSubview::const_row_iterator::operator--() const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { // There are elements in this column. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row + 1); - if (pos_ptr != start_ptr) + if(pos_ptr != start_ptr) { - if (*(pos_ptr - 1) > max_row + aux_row) + if(*(pos_ptr - 1) > max_row + aux_row) { // There are elements in this column with row index < internal_row. max_row = *(pos_ptr - 1) - aux_row; max_col = col - 1; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); } - else if (*(pos_ptr - 1) == max_row + aux_row && (col - 1) >= max_col) + else if(*(pos_ptr - 1) == max_row + aux_row && (col - 1) >= max_col) { max_col = col - 1; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); @@ -962,7 +944,7 @@ SpSubview::const_row_iterator::operator--() } } - for (uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) + for(uword col = iterator_base::M->n_cols - 1; col >= iterator_base::internal_col; --col) { // Find the first element with row greater than or equal to row + 1. const uword col_offset = iterator_base::M->m.col_ptrs[col + aux_col ]; @@ -971,21 +953,21 @@ SpSubview::const_row_iterator::operator--() const uword* start_ptr = &iterator_base::M->m.row_indices[ col_offset]; const uword* end_ptr = &iterator_base::M->m.row_indices[next_col_offset]; - if (start_ptr != end_ptr) + if(start_ptr != end_ptr) { // There are elements in this column. const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, internal_row + aux_row); - if (pos_ptr != start_ptr) + if(pos_ptr != start_ptr) { // There are elements in this column with row index < internal_row. - if (*(pos_ptr - 1) > max_row + aux_row) + if(*(pos_ptr - 1) > max_row + aux_row) { max_row = *(pos_ptr - 1) - aux_row; max_col = col; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); } - else if (*(pos_ptr - 1) == max_row + aux_row && col >= max_col) + else if(*(pos_ptr - 1) == max_row + aux_row && col >= max_col) { max_col = col; next_actual_pos = col_offset + (pos_ptr - 1 - start_ptr); @@ -993,7 +975,7 @@ SpSubview::const_row_iterator::operator--() } } - if (col == 0) // Catch edge case that the loop termination condition won't. + if(col == 0) // Catch edge case that the loop termination condition won't. { break; } @@ -1010,7 +992,6 @@ SpSubview::const_row_iterator::operator--() template inline -arma_warn_unused typename SpSubview::const_row_iterator SpSubview::const_row_iterator::operator--(int) { @@ -1025,7 +1006,6 @@ SpSubview::const_row_iterator::operator--(int) template inline -arma_hot bool SpSubview::const_row_iterator::operator==(const const_iterator& rhs) const { @@ -1036,7 +1016,6 @@ SpSubview::const_row_iterator::operator==(const const_iterator& rhs) const template inline -arma_hot bool SpSubview::const_row_iterator::operator!=(const const_iterator& rhs) const { @@ -1047,7 +1026,6 @@ SpSubview::const_row_iterator::operator!=(const const_iterator& rhs) const template inline -arma_hot bool SpSubview::const_row_iterator::operator==(const typename SpMat::const_iterator& rhs) const { @@ -1058,7 +1036,6 @@ SpSubview::const_row_iterator::operator==(const typename SpMat::const_it template inline -arma_hot bool SpSubview::const_row_iterator::operator!=(const typename SpMat::const_iterator& rhs) const { @@ -1069,7 +1046,6 @@ SpSubview::const_row_iterator::operator!=(const typename SpMat::const_it template inline -arma_hot bool SpSubview::const_row_iterator::operator==(const const_row_iterator& rhs) const { @@ -1080,7 +1056,6 @@ SpSubview::const_row_iterator::operator==(const const_row_iterator& rhs) con template inline -arma_hot bool SpSubview::const_row_iterator::operator!=(const const_row_iterator& rhs) const { @@ -1091,7 +1066,6 @@ SpSubview::const_row_iterator::operator!=(const const_row_iterator& rhs) con template inline -arma_hot bool SpSubview::const_row_iterator::operator==(const typename SpMat::const_row_iterator& rhs) const { @@ -1102,7 +1076,6 @@ SpSubview::const_row_iterator::operator==(const typename SpMat::const_ro template inline -arma_hot bool SpSubview::const_row_iterator::operator!=(const typename SpMat::const_row_iterator& rhs) const { @@ -1117,11 +1090,10 @@ SpSubview::const_row_iterator::operator!=(const typename SpMat::const_ro template inline -arma_hot -SpValProxy > +SpValProxy< SpSubview > SpSubview::row_iterator::operator*() { - return SpValProxy >( + return SpValProxy< SpSubview >( const_row_iterator::internal_row, iterator_base::internal_col, access::rw(*iterator_base::M), @@ -1132,7 +1104,6 @@ SpSubview::row_iterator::operator*() template inline -arma_hot typename SpSubview::row_iterator& SpSubview::row_iterator::operator++() { @@ -1144,7 +1115,6 @@ SpSubview::row_iterator::operator++() template inline -arma_warn_unused typename SpSubview::row_iterator SpSubview::row_iterator::operator++(int) { @@ -1159,7 +1129,6 @@ SpSubview::row_iterator::operator++(int) template inline -arma_hot typename SpSubview::row_iterator& SpSubview::row_iterator::operator--() { @@ -1171,7 +1140,6 @@ SpSubview::row_iterator::operator--() template inline -arma_warn_unused typename SpSubview::row_iterator SpSubview::row_iterator::operator--(int) { @@ -1182,4 +1150,5 @@ SpSubview::row_iterator::operator--(int) return tmp; } + //! @} diff --git a/src/armadillo_bits/SpSubview_meat.hpp b/src/armadillo_bits/SpSubview_meat.hpp index d20dc975..481359be 100644 --- a/src/armadillo_bits/SpSubview_meat.hpp +++ b/src/armadillo_bits/SpSubview_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,16 @@ template -arma_inline +inline +SpSubview::~SpSubview() + { + arma_extra_debug_sigprint_this(this); + } + + + +template +inline SpSubview::SpSubview(const SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols) : m(in_m) , aux_row1(in_row1) @@ -29,7 +40,7 @@ SpSubview::SpSubview(const SpMat& in_m, const uword in_row1, const uword , n_elem(in_n_rows * in_n_cols) , n_nonzero(0) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); m.sync_csc(); @@ -53,44 +64,42 @@ SpSubview::SpSubview(const SpMat& in_m, const uword in_row1, const uword template -arma_inline -SpSubview::SpSubview(SpMat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols) - : m(in_m) - , aux_row1(in_row1) - , aux_col1(in_col1) - , n_rows(in_n_rows) - , n_cols(in_n_cols) - , n_elem(in_n_rows * in_n_cols) - , n_nonzero(0) +inline +SpSubview::SpSubview(const SpSubview& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + , n_nonzero(in.n_nonzero) { - arma_extra_debug_sigprint(); - - m.sync_csc(); - - // There must be a O(1) way to do this - uword lend = m.col_ptrs[in_col1 + in_n_cols]; - uword lend_row = in_row1 + in_n_rows; - uword count = 0; - - for(uword i = m.col_ptrs[in_col1]; i < lend; ++i) - { - const uword m_row_indices_i = m.row_indices[i]; - - const bool condition = (m_row_indices_i >= in_row1) && (m_row_indices_i < lend_row); - - count += condition ? uword(1) : uword(0); - } - - access::rw(n_nonzero) = count; + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); } template inline -SpSubview::~SpSubview() +SpSubview::SpSubview(SpSubview&& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + , n_nonzero(in.n_nonzero) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.aux_row1 ) = 0; + access::rw(in.aux_col1 ) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem ) = 0; + access::rw(in.n_nonzero) = 0; } @@ -102,12 +111,9 @@ SpSubview::operator+=(const eT val) { arma_extra_debug_sigprint(); - if(val == eT(0)) - { - return *this; - } + if(val == eT(0)) { return *this; } - Mat tmp( (*this).n_rows, (*this).n_cols ); + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); tmp.fill(val); @@ -122,13 +128,10 @@ const SpSubview& SpSubview::operator-=(const eT val) { arma_extra_debug_sigprint(); - - if(val == eT(0)) - { - return *this; - } - - Mat tmp( (*this).n_rows, (*this).n_cols ); + + if(val == eT(0)) { return *this; } + + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); tmp.fill(val); @@ -320,7 +323,7 @@ const SpSubview& SpSubview::operator-=(const Base& x) { arma_extra_debug_sigprint(); - + return (*this).operator=( (*this) - x.get_ref() ); } @@ -350,8 +353,61 @@ const SpSubview& SpSubview::operator%=(const Base& x) { arma_extra_debug_sigprint(); - - return (*this).operator=( (*this) % x.get_ref() ); + + SpSubview& sv = (*this); + + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; + + arma_debug_assert_same_size(sv.n_rows, sv.n_cols, B.n_rows, B.n_cols, "element-wise multiplication"); + + SpMat& sv_m = access::rw(sv.m); + + sv_m.sync_csc(); + sv_m.invalidate_cache(); + + const uword m_row_start = sv.aux_row1; + const uword m_row_end = sv.aux_row1 + sv.n_rows - 1; + + const uword m_col_start = sv.aux_col1; + const uword m_col_end = sv.aux_col1 + sv.n_cols - 1; + + constexpr eT zero = eT(0); + + bool has_zero = false; + uword count = 0; + + for(uword m_col = m_col_start; m_col <= m_col_end; ++m_col) + { + const uword sv_col = m_col - m_col_start; + + const uword index_start = sv_m.col_ptrs[m_col ]; + const uword index_end = sv_m.col_ptrs[m_col + 1]; + + for(uword i=index_start; i < index_end; ++i) + { + const uword m_row = sv_m.row_indices[i]; + + if(m_row < m_row_start) { continue; } + if(m_row > m_row_end ) { break; } + + const uword sv_row = m_row - m_row_start; + + eT& m_val = access::rw(sv_m.values[i]); + + const eT result = m_val * B.at(sv_row, sv_col); + + m_val = result; + + if(result == zero) { has_zero = true; } else { ++count; } + } + } + + if(has_zero) { sv_m.remove_zeros(); } + + access::rw(sv.n_nonzero) = count; + + return (*this); } @@ -363,8 +419,69 @@ const SpSubview& SpSubview::operator/=(const Base& x) { arma_extra_debug_sigprint(); - - return (*this).operator=( (*this) / x.get_ref() ); + + const SpSubview& A = (*this); + + const quasi_unwrap U(x.get_ref()); + const Mat& B = U.M; + + arma_debug_assert_same_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "element-wise division"); + + bool result_ok = true; + + constexpr eT zero = eT(0); + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + for(uword c=0; c < B_n_cols; ++c) + { + for(uword r=0; r < B_n_rows; ++r) + { + // a zero in B and A at the same location implies the division result is NaN; + // hence a zero in A (not stored) needs to be changed into a non-zero + + // for efficiency, an element in B is checked before checking the corresponding element in A + + if((B.at(r,c) == zero) && (A.at(r,c) == zero)) { result_ok = false; break; } + } + + if(result_ok == false) { break; } + } + + if(result_ok) + { + const_iterator cit = A.begin(); + const_iterator cit_end = A.end(); + + while(cit != cit_end) + { + const eT tmp = (*cit) / B.at(cit.row(), cit.col()); + + if(tmp == zero) { result_ok = false; break; } + + ++cit; + } + } + + if(result_ok) + { + iterator it = (*this).begin(); + iterator it_end = (*this).end(); + + while(it != it_end) + { + (*it) /= B.at(it.row(), it.col()); + + ++it; + } + } + else + { + (*this).operator=( (*this) / B ); + } + + return (*this); } @@ -404,16 +521,18 @@ SpSubview::operator_equ_common(const SpBase& in) const unwrap_spmat U(in.get_ref()); + arma_debug_assert_same_size(n_rows, n_cols, U.M.n_rows, U.M.n_cols, "insertion into sparse submatrix"); + if(U.is_alias(m)) { const SpMat tmp(U.M); - return (*this).operator_equ_common(tmp); + spglue_merge::subview_merge(*this, tmp); + } + else + { + spglue_merge::subview_merge(*this, U.M); } - - arma_debug_assert_same_size(n_rows, n_cols, U.M.n_rows, U.M.n_cols, "insertion into sparse submatrix"); - - spglue_merge::subview_merge(*this, U.M); return *this; } @@ -455,7 +574,7 @@ const SpSubview& SpSubview::operator*=(const SpBase& x) { arma_extra_debug_sigprint(); - + return (*this).operator=( (*this) * x.get_ref() ); } @@ -475,7 +594,6 @@ SpSubview::operator%=(const SpBase& x) -//! If you are using this function, you are probably misguided. template template inline @@ -484,6 +602,8 @@ SpSubview::operator/=(const SpBase& x) { arma_extra_debug_sigprint(); + // NOTE: use of this function is not advised; it is implemented only for completeness + SpProxy p(x.get_ref()); arma_debug_assert_same_size(n_rows, n_cols, p.get_n_rows(), p.get_n_cols(), "element-wise division"); @@ -768,6 +888,36 @@ SpSubview::clean(const typename get_pod_type::result threshold) +template +inline +void +SpSubview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpSubview::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "SpSubview::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "SpSubview::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + if((n_elem == 0) || (n_nonzero == 0)) { return; } + + // TODO: replace with a more efficient implementation + + SpMat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + template inline void @@ -777,7 +927,7 @@ SpSubview::fill(const eT val) if(val != eT(0)) { - Mat tmp( (*this).n_rows, (*this).n_cols ); + Mat tmp( (*this).n_rows, (*this).n_cols, arma_nozeros_indicator() ); tmp.fill(val); @@ -854,7 +1004,7 @@ void SpSubview::ones() { arma_extra_debug_sigprint(); - + (*this).fill(eT(1)); } @@ -877,121 +1027,140 @@ SpSubview::eye() template -arma_hot +inline +void +SpSubview::randu() + { + arma_extra_debug_sigprint(); + + Mat tmp( (*this).n_rows, (*this).n_cols, fill::randu ); + + (*this).operator=(tmp); + } + + + +template +inline +void +SpSubview::randn() + { + arma_extra_debug_sigprint(); + + Mat tmp( (*this).n_rows, (*this).n_cols, fill::randn ); + + (*this).operator=(tmp); + } + + + +template inline SpSubview_MapMat_val SpSubview::operator[](const uword i) { const uword lrow = i % n_rows; const uword lcol = i / n_rows; - + return (*this).at(lrow, lcol); } template -arma_hot inline eT SpSubview::operator[](const uword i) const { const uword lrow = i % n_rows; const uword lcol = i / n_rows; - + return (*this).at(lrow, lcol); } template -arma_hot inline SpSubview_MapMat_val SpSubview::operator()(const uword i) { - arma_debug_check( (i >= n_elem), "SpSubview::operator(): index out of bounds"); - + arma_debug_check_bounds( (i >= n_elem), "SpSubview::operator(): index out of bounds" ); + const uword lrow = i % n_rows; const uword lcol = i / n_rows; - + return (*this).at(lrow, lcol); } template -arma_hot inline eT SpSubview::operator()(const uword i) const { - arma_debug_check( (i >= n_elem), "SpSubview::operator(): index out of bounds"); - + arma_debug_check_bounds( (i >= n_elem), "SpSubview::operator(): index out of bounds" ); + const uword lrow = i % n_rows; const uword lcol = i / n_rows; - + return (*this).at(lrow, lcol); } template -arma_hot inline SpSubview_MapMat_val SpSubview::operator()(const uword in_row, const uword in_col) { - arma_debug_check( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds"); - + arma_debug_check_bounds( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds" ); + return (*this).at(in_row, in_col); } template -arma_hot inline eT SpSubview::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds"); - + arma_debug_check_bounds( (in_row >= n_rows) || (in_col >= n_cols), "SpSubview::operator(): index out of bounds" ); + return (*this).at(in_row, in_col); } template -arma_hot inline SpSubview_MapMat_val SpSubview::at(const uword i) { const uword lrow = i % n_rows; const uword lcol = i / n_cols; - + return (*this).at(lrow, lcol); } template -arma_hot inline eT SpSubview::at(const uword i) const { const uword lrow = i % n_rows; const uword lcol = i / n_cols; - + return (*this).at(lrow, lcol); } template -arma_hot inline SpSubview_MapMat_val SpSubview::at(const uword in_row, const uword in_col) @@ -1002,7 +1171,6 @@ SpSubview::at(const uword in_row, const uword in_col) template -arma_hot inline eT SpSubview::at(const uword in_row, const uword in_col) const @@ -1017,8 +1185,8 @@ inline bool SpSubview::check_overlap(const SpSubview& x) const { - const subview& t = *this; - + const SpSubview& t = *this; + if(&t.m != &x.m) { return false; @@ -1033,19 +1201,19 @@ SpSubview::check_overlap(const SpSubview& x) const { const uword t_row_start = t.aux_row1; const uword t_row_end_p1 = t_row_start + t.n_rows; - + const uword t_col_start = t.aux_col1; const uword t_col_end_p1 = t_col_start + t.n_cols; - + const uword x_row_start = x.aux_row1; const uword x_row_end_p1 = x_row_start + x.n_rows; - + const uword x_col_start = x.aux_col1; const uword x_col_end_p1 = x_col_start + x.n_cols; - + const bool outside_rows = ( (x_row_start >= t_row_end_p1) || (t_row_start >= x_row_end_p1) ); const bool outside_cols = ( (x_col_start >= t_col_end_p1) || (t_col_start >= x_col_end_p1) ); - + return ( (outside_rows == false) && (outside_cols == false) ); } } @@ -1065,56 +1233,56 @@ SpSubview::is_vec() const template inline -SpSubview +SpSubview_row SpSubview::row(const uword row_num) { arma_extra_debug_sigprint(); - - arma_debug_check(row_num >= n_rows, "SpSubview::row(): out of bounds"); - - return submat(row_num, 0, row_num, n_cols - 1); + + arma_debug_check_bounds(row_num >= n_rows, "SpSubview::row(): out of bounds"); + + return SpSubview_row(const_cast< SpMat& >(m), row_num + aux_row1, aux_col1, n_cols); } template inline -const SpSubview +const SpSubview_row SpSubview::row(const uword row_num) const { arma_extra_debug_sigprint(); - - arma_debug_check(row_num >= n_rows, "SpSubview::row(): out of bounds"); - - return submat(row_num, 0, row_num, n_cols - 1); + + arma_debug_check_bounds(row_num >= n_rows, "SpSubview::row(): out of bounds"); + + return SpSubview_row(m, row_num + aux_row1, aux_col1, n_cols); } template inline -SpSubview +SpSubview_col SpSubview::col(const uword col_num) { arma_extra_debug_sigprint(); - - arma_debug_check(col_num >= n_cols, "SpSubview::col(): out of bounds"); - - return submat(0, col_num, n_rows - 1, col_num); + + arma_debug_check_bounds(col_num >= n_cols, "SpSubview::col(): out of bounds"); + + return SpSubview_col(const_cast< SpMat& >(m), col_num + aux_col1, aux_row1, n_rows); } template inline -const SpSubview +const SpSubview_col SpSubview::col(const uword col_num) const { arma_extra_debug_sigprint(); - - arma_debug_check(col_num >= n_cols, "SpSubview::col(): out of bounds"); - - return submat(0, col_num, n_rows - 1, col_num); + + arma_debug_check_bounds(col_num >= n_cols, "SpSubview::col(): out of bounds"); + + return SpSubview_col(m, col_num + aux_col1, aux_row1, n_rows); } @@ -1125,13 +1293,13 @@ SpSubview SpSubview::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "SpSubview::rows(): indices out of bounds or incorrectly used" ); - + return submat(in_row1, 0, in_row2, n_cols - 1); } @@ -1143,8 +1311,8 @@ const SpSubview SpSubview::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "SpSubview::rows(): indices out of bounds or incorrectly used" @@ -1161,13 +1329,13 @@ SpSubview SpSubview::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "SpSubview::cols(): indices out of bounds or incorrectly used" ); - + return submat(0, in_col1, n_rows - 1, in_col2); } @@ -1179,13 +1347,13 @@ const SpSubview SpSubview::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "SpSubview::cols(): indices out of bounds or incorrectly used" ); - + return submat(0, in_col1, n_rows - 1, in_col2); } @@ -1197,13 +1365,13 @@ SpSubview SpSubview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "SpSubview::submat(): indices out of bounds or incorrectly used" ); - + return access::rw(m).submat(in_row1 + aux_row1, in_col1 + aux_col1, in_row2 + aux_row1, in_col2 + aux_col1); } @@ -1215,13 +1383,13 @@ const SpSubview SpSubview::submat(const uword in_row1, const uword in_col1, const uword in_row2, const uword in_col2) const { arma_extra_debug_sigprint(); - - arma_debug_check + + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "SpSubview::submat(): indices out of bounds or incorrectly used" ); - + return m.submat(in_row1 + aux_row1, in_col1 + aux_col1, in_row2 + aux_row1, in_col2 + aux_col1); } @@ -1233,24 +1401,24 @@ SpSubview SpSubview::submat(const span& row_span, const span& col_span) { arma_extra_debug_sigprint(); - + const bool row_all = row_span.whole; const bool col_all = row_span.whole; - + const uword in_row1 = row_all ? 0 : row_span.a; const uword in_row2 = row_all ? n_rows : row_span.b; - + const uword in_col1 = col_all ? 0 : col_span.a; const uword in_col2 = col_all ? n_cols : col_span.b; - - arma_debug_check + + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= n_rows))) || ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= n_cols))), "SpSubview::submat(): indices out of bounds or incorrectly used" ); - + return submat(in_row1, in_col1, in_row2, in_col2); } @@ -1262,24 +1430,24 @@ const SpSubview SpSubview::submat(const span& row_span, const span& col_span) const { arma_extra_debug_sigprint(); - + const bool row_all = row_span.whole; const bool col_all = row_span.whole; - + const uword in_row1 = row_all ? 0 : row_span.a; const uword in_row2 = row_all ? n_rows - 1 : row_span.b; - + const uword in_col1 = col_all ? 0 : col_span.a; const uword in_col2 = col_all ? n_cols - 1 : col_span.b; - - arma_debug_check + + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= n_rows))) || ( col_all ? false : ((in_col1 > in_col2) || (in_col2 >= n_cols))), "SpSubview::submat(): indices out of bounds or incorrectly used" ); - + return submat(in_row1, in_col1, in_row2, in_col2); } @@ -1291,7 +1459,7 @@ SpSubview SpSubview::operator()(const uword row_num, const span& col_span) { arma_extra_debug_sigprint(); - + return submat(span(row_num, row_num), col_span); } @@ -1303,7 +1471,7 @@ const SpSubview SpSubview::operator()(const uword row_num, const span& col_span) const { arma_extra_debug_sigprint(); - + return submat(span(row_num, row_num), col_span); } @@ -1315,7 +1483,7 @@ SpSubview SpSubview::operator()(const span& row_span, const uword col_num) { arma_extra_debug_sigprint(); - + return submat(row_span, span(col_num, col_num)); } @@ -1327,7 +1495,7 @@ const SpSubview SpSubview::operator()(const span& row_span, const uword col_num) const { arma_extra_debug_sigprint(); - + return submat(row_span, span(col_num, col_num)); } @@ -1339,7 +1507,7 @@ SpSubview SpSubview::operator()(const span& row_span, const span& col_span) { arma_extra_debug_sigprint(); - + return submat(row_span, col_span); } @@ -1351,7 +1519,7 @@ const SpSubview SpSubview::operator()(const span& row_span, const span& col_span) const { arma_extra_debug_sigprint(); - + return submat(row_span, col_span); } @@ -1363,12 +1531,12 @@ void SpSubview::swap_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - + arma_debug_check((in_row1 >= n_rows) || (in_row2 >= n_rows), "SpSubview::swap_rows(): invalid row index"); - + const uword lstart_col = aux_col1; const uword lend_col = aux_col1 + n_cols; - + for(uword c = lstart_col; c < lend_col; ++c) { const eT val = access::rw(m).at(in_row1 + aux_row1, c); @@ -1385,12 +1553,12 @@ void SpSubview::swap_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - + arma_debug_check((in_col1 >= n_cols) || (in_col2 >= n_cols), "SpSubview::swap_cols(): invalid column index"); - + const uword lstart_row = aux_row1; const uword lend_row = aux_row1 + n_rows; - + for(uword r = lstart_row; r < lend_row; ++r) { const eT val = access::rw(m).at(r, in_col1 + aux_col1); @@ -1406,6 +1574,8 @@ inline typename SpSubview::iterator SpSubview::begin() { + m.sync_csc(); + return iterator(*this); } @@ -1578,18 +1748,17 @@ SpSubview::is_alias(const SpMat& X) const template inline -arma_warn_unused eT& SpSubview::insert_element(const uword in_row, const uword in_col, const eT in_val) { arma_extra_debug_sigprint(); - + // This may not actually insert an element. const uword old_n_nonzero = m.n_nonzero; eT& retval = access::rw(m).insert_element(in_row + aux_row1, in_col + aux_col1, in_val); // Update n_nonzero (if necessary). access::rw(n_nonzero) += (m.n_nonzero - old_n_nonzero); - + return retval; } @@ -1601,7 +1770,7 @@ void SpSubview::delete_element(const uword in_row, const uword in_col) { arma_extra_debug_sigprint(); - + // This may not actually delete an element. const uword old_n_nonzero = m.n_nonzero; access::rw(m).delete_element(in_row + aux_row1, in_col + aux_col1); @@ -1622,69 +1791,216 @@ SpSubview::invalidate_cache() const -/** - * Sparse subview col - * +// +// +// + + + +template +inline +SpSubview_col::SpSubview_col(const SpMat& in_m, const uword in_col) + : SpSubview(in_m, 0, in_col, in_m.n_rows, 1) + { + arma_extra_debug_sigprint(); + } + + + template inline -SpSubview_col::SpSubview_col(const Mat& in_m, const uword in_col) +SpSubview_col::SpSubview_col(const SpMat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows) + : SpSubview(in_m, in_row1, in_col, in_n_rows, 1) { arma_extra_debug_sigprint(); } + + template inline -SpSubview_col::SpSubview_col(Mat& in_m, const uword in_col) +void +SpSubview_col::operator=(const SpSubview& x) { arma_extra_debug_sigprint(); + + SpSubview::operator=(x); } + + template inline -SpSubview_col::SpSubview_col(const Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows) +void +SpSubview_col::operator=(const SpSubview_col& x) { arma_extra_debug_sigprint(); + + SpSubview::operator=(x); // interprets 'SpSubview_col' as 'SpSubview' } + + template +template inline -SpSubview_col::SpSubview_col(Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows) +void +SpSubview_col::operator=(const SpBase& x) { arma_extra_debug_sigprint(); + + SpSubview::operator=(x); } -*/ -/** - * Sparse subview row - * + + template +template inline -SpSubview_row::SpSubview_row(const Mat& in_m, const uword in_row) +void +SpSubview_col::operator=(const Base& x) { arma_extra_debug_sigprint(); + + SpSubview::operator=(x); } + + template inline -SpSubview_row::SpSubview_row(Mat& in_m, const uword in_row) +const SpOp,spop_htrans> +SpSubview_col::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_col::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpSubview_col::st() const + { + return SpOp,spop_strans>(*this); + } + + + +// +// +// + + + +template +inline +SpSubview_row::SpSubview_row(const SpMat& in_m, const uword in_row) + : SpSubview(in_m, in_row, 0, 1, in_m.n_cols) { arma_extra_debug_sigprint(); } + + template inline -SpSubview_row::SpSubview_row(const Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols) +SpSubview_row::SpSubview_row(const SpMat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols) + : SpSubview(in_m, in_row, in_col1, 1, in_n_cols) { arma_extra_debug_sigprint(); } + + template inline -SpSubview_row::SpSubview_row(Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols) +void +SpSubview_row::operator=(const SpSubview& x) { arma_extra_debug_sigprint(); + + SpSubview::operator=(x); } -*/ + + + +template +inline +void +SpSubview_row::operator=(const SpSubview_row& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); // interprets 'SpSubview_row' as 'SpSubview' + } + + + +template +template +inline +void +SpSubview_row::operator=(const SpBase& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +template +inline +void +SpSubview_row::operator=(const Base& x) + { + arma_extra_debug_sigprint(); + + SpSubview::operator=(x); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_row::t() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_htrans> +SpSubview_row::ht() const + { + return SpOp,spop_htrans>(*this); + } + + + +template +inline +const SpOp,spop_strans> +SpSubview_row::st() const + { + return SpOp,spop_strans>(*this); + } + //! @} diff --git a/src/armadillo_bits/SpToDGlue_bones.hpp b/src/armadillo_bits/SpToDGlue_bones.hpp new file mode 100644 index 00000000..7d4ce9d3 --- /dev/null +++ b/src/armadillo_bits/SpToDGlue_bones.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup SpToDGlue +//! @{ + + + +template +class SpToDGlue : public Base< typename T1::elem_type, SpToDGlue > + { + public: + + typedef typename T1::elem_type elem_type; + typedef typename get_pod_type::result pod_type; + + inline explicit SpToDGlue(const T1& in_A, const T2& in_B); + inline ~SpToDGlue(); + + const T1& A; //!< first operand; must be derived from Base or SpBase + const T2& B; //!< second operand; must be derived from Base or SpBase + + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; + }; + + + +//! @} diff --git a/src/armadillo_bits/compiler_extra.hpp b/src/armadillo_bits/SpToDGlue_meat.hpp similarity index 63% rename from src/armadillo_bits/compiler_extra.hpp rename to src/armadillo_bits/SpToDGlue_meat.hpp index ecc6bf9e..1d3d095f 100644 --- a/src/armadillo_bits/compiler_extra.hpp +++ b/src/armadillo_bits/SpToDGlue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,23 +16,29 @@ // ------------------------------------------------------------------------ +//! \addtogroup SpToDGlue +//! @{ + + + +template +inline +SpToDGlue::SpToDGlue(const T1& in_A, const T2& in_B) + : A(in_A) + , B(in_B) + { + arma_extra_debug_sigprint(); + } + -#if (__cplusplus >= 201103L) - #undef ARMA_USE_CXX11 - #define ARMA_USE_CXX11 -#endif +template +inline +SpToDGlue::~SpToDGlue() + { + arma_extra_debug_sigprint(); + } -// MS really can't get its proverbial shit together -#if (defined(_MSVC_LANG) && (_MSVC_LANG >= 201402L)) - #undef ARMA_USE_CXX11 - #define ARMA_USE_CXX11 - #undef ARMA_DONT_PRINT_CXX11_WARNING - #define ARMA_DONT_PRINT_CXX11_WARNING -#endif -#if (defined(_OPENMP) && (_OPENMP >= 201107)) - #undef ARMA_USE_OPENMP - #define ARMA_USE_OPENMP -#endif +//! @} diff --git a/src/armadillo_bits/SpToDOp_bones.hpp b/src/armadillo_bits/SpToDOp_bones.hpp index 57977d17..b8ae6ccb 100644 --- a/src/armadillo_bits/SpToDOp_bones.hpp +++ b/src/armadillo_bits/SpToDOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,7 @@ //! Class for storing data required for delayed unary operations on a sparse //! matrix that produce a dense matrix; the data for storage may include -//! the operand (e.g. the matrix to which the operation is to be applied) and the unary operator (e.g. inverse). +//! the operand (eg. the matrix to which the operation is to be applied) and the unary operator (eg. inverse). //! The operand is stored as a reference (which can be optimised away), //! while the operator is "stored" through the template definition (op_type). //! The operands can be 'SpMat', 'SpRow', 'SpCol', 'SpOp', and 'SpGlue'. @@ -31,24 +33,23 @@ //! SpToDOp< SpGlue< SpMat, SpMat, sp_glue_times >, op_sp_plus > template -class SpToDOp : public Base > +class SpToDOp : public Base< typename T1::elem_type, SpToDOp > { public: - + typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - + inline explicit SpToDOp(const T1& in_m); inline SpToDOp(const T1& in_m, const elem_type in_aux); inline ~SpToDOp(); - + arma_aligned const T1& m; //!< the operand; must be derived from SpBase arma_aligned elem_type aux; //!< auxiliary data, using the element type as used by T1 - - static const bool is_row = op_type::template traits::is_row; - static const bool is_col = op_type::template traits::is_col; - static const bool is_xvec = op_type::template traits::is_xvec; - + + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; }; diff --git a/src/armadillo_bits/SpToDOp_meat.hpp b/src/armadillo_bits/SpToDOp_meat.hpp index 2a743a0c..66ab6405 100644 --- a/src/armadillo_bits/SpToDOp_meat.hpp +++ b/src/armadillo_bits/SpToDOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/SpValProxy_bones.hpp b/src/armadillo_bits/SpValProxy_bones.hpp index d6e706a1..af9a52d4 100644 --- a/src/armadillo_bits/SpValProxy_bones.hpp +++ b/src/armadillo_bits/SpValProxy_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ template class SpValProxy { public: - + typedef typename T1::elem_type eT; // Convenience typedef - + friend class SpMat; friend class SpSubview; @@ -36,7 +38,8 @@ class SpValProxy * Create the sparse value proxy. * Otherwise, pass a pointer to a reference of the value. */ - arma_inline SpValProxy(uword row, uword col, T1& in_parent, eT* in_val_ptr = NULL); + arma_inline SpValProxy(uword row, uword col, T1& in_parent, eT* in_val_ptr = nullptr); + inline SpValProxy() = delete; //! For swapping operations. arma_inline SpValProxy& operator=(const SpValProxy& rhs); @@ -46,16 +49,17 @@ class SpValProxy //! Overload all of the potential operators. //! First, the ones that could modify a value. - arma_inline SpValProxy& operator=(const eT rhs); - arma_inline SpValProxy& operator+=(const eT rhs); - arma_inline SpValProxy& operator-=(const eT rhs); - arma_inline SpValProxy& operator*=(const eT rhs); - arma_inline SpValProxy& operator/=(const eT rhs); + inline SpValProxy& operator= (const eT rhs); + inline SpValProxy& operator+=(const eT rhs); + inline SpValProxy& operator-=(const eT rhs); + inline SpValProxy& operator*=(const eT rhs); + inline SpValProxy& operator/=(const eT rhs); + + inline SpValProxy& operator++(); + inline SpValProxy& operator--(); - arma_inline SpValProxy& operator++(); - arma_inline SpValProxy& operator--(); - arma_inline eT operator++(const int); - arma_inline eT operator--(const int); + inline eT operator++(const int); + inline eT operator--(const int); //! This will work for any other operations that do not modify a value. arma_inline operator eT() const; @@ -66,7 +70,7 @@ class SpValProxy private: - // Deletes the element if it is zero. Does not check if val_ptr == NULL! + // Deletes the element if it is zero; NOTE: does not check if val_ptr == nullptr arma_inline void check_zero(); arma_aligned const uword row; diff --git a/src/armadillo_bits/SpValProxy_meat.hpp b/src/armadillo_bits/SpValProxy_meat.hpp index 657bfbc7..242ec07e 100644 --- a/src/armadillo_bits/SpValProxy_meat.hpp +++ b/src/armadillo_bits/SpValProxy_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -54,52 +56,47 @@ SpValProxy::operator=(const SpValProxy& rhs) template -arma_inline +inline SpValProxy& SpValProxy::operator=(const eT rhs) { - if (rhs != eT(0)) // A nonzero element is being assigned. + if(rhs != eT(0)) // A nonzero element is being assigned. { - - if (val_ptr) + if(val_ptr) { // The value exists and merely needs to be updated. *val_ptr = rhs; parent.invalidate_cache(); } - else { // The value is nonzero and must be inserted. val_ptr = &parent.insert_element(row, col, rhs); } - } else // A zero is being assigned.~ { - - if (val_ptr) + if(val_ptr) { // The element exists, but we need to remove it, because it is being set to 0. parent.delete_element(row, col); - val_ptr = NULL; + val_ptr = nullptr; } - + // If the element does not exist, we do not need to do anything at all. - } - + return *this; } template -arma_inline +inline SpValProxy& SpValProxy::operator+=(const eT rhs) { - if (val_ptr) + if(val_ptr) { // The value already exists and merely needs to be updated. *val_ptr += rhs; @@ -108,7 +105,7 @@ SpValProxy::operator+=(const eT rhs) } else { - if (rhs != eT(0)) + if(rhs != eT(0)) { // The value does not exist and must be inserted. val_ptr = &parent.insert_element(row, col, rhs); @@ -121,11 +118,11 @@ SpValProxy::operator+=(const eT rhs) template -arma_inline +inline SpValProxy& SpValProxy::operator-=(const eT rhs) { - if (val_ptr) + if(val_ptr) { // The value already exists and merely needs to be updated. *val_ptr -= rhs; @@ -134,162 +131,150 @@ SpValProxy::operator-=(const eT rhs) } else { - if (rhs != eT(0)) + if(rhs != eT(0)) { // The value does not exist and must be inserted. val_ptr = &parent.insert_element(row, col, -rhs); } } - + return *this; } template -arma_inline +inline SpValProxy& SpValProxy::operator*=(const eT rhs) { - if (rhs != eT(0)) + if(rhs != eT(0)) { - - if (val_ptr) + if(val_ptr) { // The value already exists and merely needs to be updated. *val_ptr *= rhs; parent.invalidate_cache(); check_zero(); } - } else { - - if (val_ptr) + if(val_ptr) { // Since we are multiplying by zero, the value can be deleted. parent.delete_element(row, col); - val_ptr = NULL; + val_ptr = nullptr; } - } - + return *this; } template -arma_inline +inline SpValProxy& SpValProxy::operator/=(const eT rhs) { - if (rhs != eT(0)) // I hope this is true! + if(rhs != eT(0)) // I hope this is true! { - - if (val_ptr) + if(val_ptr) { *val_ptr /= rhs; parent.invalidate_cache(); check_zero(); } - } else { - - if (val_ptr) + if(val_ptr) { *val_ptr /= rhs; // That is where it gets ugly. // Now check if it's 0. - if (*val_ptr == eT(0)) + if(*val_ptr == eT(0)) { parent.delete_element(row, col); - val_ptr = NULL; + val_ptr = nullptr; } } - else { eT val = eT(0) / rhs; // This may vary depending on type and implementation. - - if (val != eT(0)) + + if(val != eT(0)) { // Ok, now we have to insert it. val_ptr = &parent.insert_element(row, col, val); } - } } - + return *this; } template -arma_inline +inline SpValProxy& SpValProxy::operator++() { - if (val_ptr) + if(val_ptr) { (*val_ptr) += eT(1); parent.invalidate_cache(); check_zero(); } - else { val_ptr = &parent.insert_element(row, col, eT(1)); } - + return *this; } template -arma_inline +inline SpValProxy& SpValProxy::operator--() { - if (val_ptr) + if(val_ptr) { (*val_ptr) -= eT(1); parent.invalidate_cache(); check_zero(); } - else { val_ptr = &parent.insert_element(row, col, eT(-1)); } - + return *this; } template -arma_inline +inline typename T1::elem_type SpValProxy::operator++(const int) { - if (val_ptr) + if(val_ptr) { (*val_ptr) += eT(1); parent.invalidate_cache(); check_zero(); } - else { val_ptr = &parent.insert_element(row, col, eT(1)); } - - if (val_ptr) // It may have changed to now be 0. + + if(val_ptr) // It may have changed to now be 0. { return *(val_ptr) - eT(1); } @@ -302,23 +287,22 @@ SpValProxy::operator++(const int) template -arma_inline +inline typename T1::elem_type SpValProxy::operator--(const int) { - if (val_ptr) + if(val_ptr) { (*val_ptr) -= eT(1); parent.invalidate_cache(); check_zero(); } - else { val_ptr = &parent.insert_element(row, col, eT(-1)); } - - if (val_ptr) // It may have changed to now be 0. + + if(val_ptr) // It may have changed to now be 0. { return *(val_ptr) + eT(1); } @@ -368,10 +352,10 @@ arma_inline void SpValProxy::check_zero() { - if (*val_ptr == eT(0)) + if(*val_ptr == eT(0)) { parent.delete_element(row, col); - val_ptr = NULL; + val_ptr = nullptr; } } diff --git a/src/armadillo_bits/access.hpp b/src/armadillo_bits/access.hpp index 095e2105..77db8621 100644 --- a/src/armadillo_bits/access.hpp +++ b/src/armadillo_bits/access.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,19 +25,19 @@ class access public: //! internal function to allow modification of data declared as read-only (use with caution) - template arma_inline static T1& rw (const T1& x) { return const_cast(x); } - template arma_inline static T1*& rwp(const T1* const& x) { return const_cast(x); } + template constexpr static T1& rw (const T1& x) { return const_cast(x); } + template constexpr static T1*& rwp(const T1* const& x) { return const_cast(x); } //! internal function to obtain the real part of either a plain number or a complex number - template arma_inline static const eT& tmp_real(const eT& X) { return X; } - template arma_inline static const T tmp_real(const std::complex& X) { return X.real(); } + template constexpr static const eT& tmp_real(const eT& X) { return X; } + template constexpr static const T tmp_real(const std::complex& X) { return X.real(); } //! internal function to obtain the imag part of either a plain number or a complex number - template arma_inline static const eT tmp_imag(const eT ) { return eT(0); } - template arma_inline static const T tmp_imag(const std::complex& X) { return X.imag(); } + template constexpr static const eT tmp_imag(const eT ) { return eT(0); } + template constexpr static const T tmp_imag(const std::complex& X) { return X.imag(); } //! internal function to work around braindead compilers - template arma_inline static const typename enable_if2::no, const eT&>::result alt_conj(const eT& X) { return X; } + template constexpr static const typename enable_if2::no, const eT&>::result alt_conj(const eT& X) { return X; } template arma_inline static const typename enable_if2::yes, const eT >::result alt_conj(const eT& X) { return std::conj(X); } }; diff --git a/src/armadillo_bits/arma_cmath.hpp b/src/armadillo_bits/arma_cmath.hpp index 73091a87..22df4bf0 100644 --- a/src/armadillo_bits/arma_cmath.hpp +++ b/src/armadillo_bits/arma_cmath.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,90 +27,41 @@ template -arma_inline +inline bool -arma_isfinite(eT val) +arma_isfinite(eT) { - arma_ignore(val); - return true; } template<> -arma_inline +inline bool arma_isfinite(float x) { - #if defined(ARMA_USE_CXX11) - { - return std::isfinite(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::isfinite(x); - } - #elif defined(ARMA_HAVE_ISFINITE) - { - return (std::isfinite(x) != 0); - } - #else - { - const float y = (std::numeric_limits::max)(); - - const volatile float xx = x; - - return (xx == xx) && (x >= -y) && (x <= y); - } - #endif + return std::isfinite(x); } template<> -arma_inline +inline bool arma_isfinite(double x) { - #if defined(ARMA_USE_CXX11) - { - return std::isfinite(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::isfinite(x); - } - #elif defined(ARMA_HAVE_ISFINITE) - { - return (std::isfinite(x) != 0); - } - #else - { - const double y = (std::numeric_limits::max)(); - - const volatile double xx = x; - - return (xx == xx) && (x >= -y) && (x <= y); - } - #endif + return std::isfinite(x); } template -arma_inline +inline bool arma_isfinite(const std::complex& x) { - if( (arma_isfinite(x.real()) == false) || (arma_isfinite(x.imag()) == false) ) - { - return false; - } - else - { - return true; - } + return ( arma_isfinite(x.real()) && arma_isfinite(x.imag()) ); } @@ -118,71 +71,37 @@ arma_isfinite(const std::complex& x) template -arma_inline +inline bool -arma_isinf(eT val) +arma_isinf(eT) { - arma_ignore(val); - return false; } template<> -arma_inline +inline bool arma_isinf(float x) { - #if defined(ARMA_USE_CXX11) - { - return std::isinf(x); - } - #elif defined(ARMA_HAVE_ISINF) - { - return (std::isinf(x) != 0); - } - #else - { - const float y = (std::numeric_limits::max)(); - - const volatile float xx = x; - - return (xx == xx) && ((x < -y) || (x > y)); - } - #endif + return std::isinf(x); } template<> -arma_inline +inline bool arma_isinf(double x) { - #if defined(ARMA_USE_CXX11) - { - return std::isinf(x); - } - #elif defined(ARMA_HAVE_ISINF) - { - return (std::isinf(x) != 0); - } - #else - { - const double y = (std::numeric_limits::max)(); - - const volatile double xx = x; - - return (xx == xx) && ((x < -y) || (x > y)); - } - #endif + return std::isinf(x); } template -arma_inline +inline bool arma_isinf(const std::complex& x) { @@ -196,7 +115,7 @@ arma_isinf(const std::complex& x) template -arma_inline +inline bool arma_isnan(eT val) { @@ -208,55 +127,27 @@ arma_isnan(eT val) template<> -arma_inline +inline bool arma_isnan(float x) { - #if defined(ARMA_USE_CXX11) - { - return std::isnan(x); - } - #elif defined(ARMA_HAVE_ISNAN) - { - return (std::isnan(x) != 0); - } - #else - { - const volatile float xx = x; - - return (xx != xx); - } - #endif + return std::isnan(x); } template<> -arma_inline +inline bool arma_isnan(double x) { - #if defined(ARMA_USE_CXX11) - { - return std::isnan(x); - } - #elif defined(ARMA_HAVE_ISNAN) - { - return (std::isnan(x) != 0); - } - #else - { - const volatile double xx = x; - - return (xx != xx); - } - #endif + return std::isnan(x); } template -arma_inline +inline bool arma_isnan(const std::complex& x) { @@ -265,76 +156,12 @@ arma_isnan(const std::complex& x) -// rudimentary wrappers for log1p() - -arma_inline -float -arma_log1p(const float x) - { - #if defined(ARMA_USE_CXX11) - { - return std::log1p(x); - } - #else - { - if((x >= float(0)) && (x < std::numeric_limits::epsilon())) - { - return x; - } - else - if((x < float(0)) && (-x < std::numeric_limits::epsilon())) - { - return x; - } - else - { - return std::log(float(1) + x); - } - } - #endif - } - - - -arma_inline -double -arma_log1p(const double x) - { - #if defined(ARMA_USE_CXX11) - { - return std::log1p(x); - } - #elif defined(ARMA_HAVE_LOG1P) - { - return log1p(x); - } - #else - { - if((x >= double(0)) && (x < std::numeric_limits::epsilon())) - { - return x; - } - else - if((x < double(0)) && (-x < std::numeric_limits::epsilon())) - { - return x; - } - else - { - return std::log(double(1) + x); - } - } - #endif - } - - - // // implementation of arma_sign() template -arma_inline +constexpr typename arma_unsigned_integral_only::result arma_sign(const eT x) { @@ -344,7 +171,7 @@ arma_sign(const eT x) template -arma_inline +constexpr typename arma_signed_integral_only::result arma_sign(const eT x) { @@ -354,17 +181,17 @@ arma_sign(const eT x) template -arma_inline +constexpr typename arma_real_only::result arma_sign(const eT x) { - return (x > eT(0)) ? eT(+1) : ( (x < eT(0)) ? eT(-1) : eT(0) ); + return (x > eT(0)) ? eT(+1) : ( (x < eT(0)) ? eT(-1) : ((x == eT(0)) ? eT(0) : x) ); } template -arma_inline +inline typename arma_cx_only::result arma_sign(const eT& x) { @@ -377,322 +204,10 @@ arma_sign(const eT& x) -// -// wrappers for trigonometric functions -// -// wherever possible, try to use C++11 or TR1 versions of the following functions: -// -// complex acos -// complex asin -// complex atan -// -// real acosh -// real asinh -// real atanh -// -// complex acosh -// complex asinh -// complex atanh -// -// -// if C++11 or TR1 are not available, we have rudimentary versions of: -// -// real acosh -// real asinh -// real atanh - - - -template -arma_inline -std::complex -arma_acos(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::acos(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::acos(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("acos(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - -template -arma_inline -std::complex -arma_asin(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::asin(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::asin(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("asin(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - -template -arma_inline -std::complex -arma_atan(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::atan(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::atan(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("atan(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - -template -arma_inline -eT -arma_acosh(const eT x) - { - #if defined(ARMA_USE_CXX11) - { - return std::acosh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::acosh(x); - } - #else - { - if(x >= eT(1)) - { - // http://functions.wolfram.com/ElementaryFunctions/ArcCosh/02/ - return std::log( x + std::sqrt(x*x - eT(1)) ); - } - else - { - if(std::numeric_limits::has_quiet_NaN) - { - return -(std::numeric_limits::quiet_NaN()); - } - else - { - return eT(0); - } - } - } - #endif - } - - - -template -arma_inline -eT -arma_asinh(const eT x) - { - #if defined(ARMA_USE_CXX11) - { - return std::asinh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::asinh(x); - } - #else - { - // http://functions.wolfram.com/ElementaryFunctions/ArcSinh/02/ - return std::log( x + std::sqrt(x*x + eT(1)) ); - } - #endif - } - - - -template -arma_inline -eT -arma_atanh(const eT x) - { - #if defined(ARMA_USE_CXX11) - { - return std::atanh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::atanh(x); - } - #else - { - if( (x >= eT(-1)) && (x <= eT(+1)) ) - { - // http://functions.wolfram.com/ElementaryFunctions/ArcTanh/02/ - return std::log( ( eT(1)+x ) / ( eT(1)-x ) ) / eT(2); - } - else - { - if(std::numeric_limits::has_quiet_NaN) - { - return -(std::numeric_limits::quiet_NaN()); - } - else - { - return eT(0); - } - } - } - #endif - } - - - -template -arma_inline -std::complex -arma_acosh(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::acosh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::acosh(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("acosh(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - -template -arma_inline -std::complex -arma_asinh(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::asinh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::asinh(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("asinh(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - -template -arma_inline -std::complex -arma_atanh(const std::complex& x) - { - #if defined(ARMA_USE_CXX11) - { - return std::atanh(x); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::atanh(x); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("atanh(): C++11 compiler required"); - - return std::complex(0); - } - #endif - } - - - // // wrappers for hypot(x, y) = sqrt(x^2 + y^2) -template -inline -eT -arma_hypot_generic(const eT x, const eT y) - { - #if defined(ARMA_USE_CXX11) - { - return std::hypot(x, y); - } - #elif defined(ARMA_HAVE_TR1) - { - return std::tr1::hypot(x, y); - } - #else - { - const eT xabs = std::abs(x); - const eT yabs = std::abs(y); - - eT larger; - eT ratio; - - if(xabs > yabs) - { - larger = xabs; - ratio = yabs / xabs; - } - else - { - larger = yabs; - ratio = xabs / yabs; - } - - return (larger == eT(0)) ? eT(0) : (larger * std::sqrt(eT(1) + ratio * ratio)); - } - #endif - } - - - template inline eT @@ -709,21 +224,21 @@ arma_hypot(const eT x, const eT y) template<> -arma_inline +inline float arma_hypot(const float x, const float y) { - return arma_hypot_generic(x,y); + return std::hypot(x, y); } template<> -arma_inline +inline double arma_hypot(const double x, const double y) { - return arma_hypot_generic(x,y); + return std::hypot(x, y); } @@ -733,7 +248,7 @@ arma_hypot(const double x, const double y) template -arma_inline +inline eT arma_sinc_generic(const eT x) { @@ -747,7 +262,7 @@ arma_sinc_generic(const eT x) template -arma_inline +inline eT arma_sinc(const eT x) { @@ -757,7 +272,7 @@ arma_sinc(const eT x) template<> -arma_inline +inline float arma_sinc(const float x) { @@ -767,7 +282,7 @@ arma_sinc(const float x) template<> -arma_inline +inline double arma_sinc(const double x) { @@ -777,7 +292,7 @@ arma_sinc(const double x) template -arma_inline +inline std::complex arma_sinc(const std::complex& x) { @@ -798,18 +313,7 @@ struct arma_arg eT eval(const eT x) { - #if defined(ARMA_USE_CXX11) - { - return eT( std::arg(x) ); - } - #else - { - arma_ignore(x); - arma_stop_logic_error("arg(): C++11 compiler required"); - - return eT(0); - } - #endif + return eT( std::arg(x) ); } }; @@ -819,19 +323,11 @@ template<> struct arma_arg { static - arma_inline + inline float eval(const float x) { - #if defined(ARMA_USE_CXX11) - { - return std::arg(x); - } - #else - { - return std::arg( std::complex( x, float(0) ) ); - } - #endif + return std::arg(x); } }; @@ -841,19 +337,11 @@ template<> struct arma_arg { static - arma_inline + inline double eval(const double x) { - #if defined(ARMA_USE_CXX11) - { - return std::arg(x); - } - #else - { - return std::arg( std::complex( x, double(0) ) ); - } - #endif + return std::arg(x); } }; @@ -863,7 +351,7 @@ template<> struct arma_arg< std::complex > { static - arma_inline + inline float eval(const std::complex& x) { @@ -877,7 +365,7 @@ template<> struct arma_arg< std::complex > { static - arma_inline + inline double eval(const std::complex& x) { diff --git a/src/armadillo_bits/arma_config.hpp b/src/armadillo_bits/arma_config.hpp index fc068112..3670199e 100644 --- a/src/armadillo_bits/arma_config.hpp +++ b/src/armadillo_bits/arma_config.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,93 +24,121 @@ struct arma_config { #if defined(ARMA_MAT_PREALLOC) - static const uword mat_prealloc = (sword(ARMA_MAT_PREALLOC) > 0) ? uword(ARMA_MAT_PREALLOC) : 1; + static constexpr uword mat_prealloc = (sword(ARMA_MAT_PREALLOC) > 0) ? uword(ARMA_MAT_PREALLOC) : 1; #else - static const uword mat_prealloc = 16; + static constexpr uword mat_prealloc = 16; #endif #if defined(ARMA_OPENMP_THRESHOLD) - static const uword mp_threshold = (sword(ARMA_OPENMP_THRESHOLD) > 0) ? uword(ARMA_OPENMP_THRESHOLD) : 240; + static constexpr uword mp_threshold = (sword(ARMA_OPENMP_THRESHOLD) > 0) ? uword(ARMA_OPENMP_THRESHOLD) : 320; #else - static const uword mp_threshold = 240; + static constexpr uword mp_threshold = 320; #endif #if defined(ARMA_OPENMP_THREADS) - static const uword mp_threads = (sword(ARMA_OPENMP_THREADS) > 0) ? uword(ARMA_OPENMP_THREADS) : 10; + static constexpr uword mp_threads = (sword(ARMA_OPENMP_THREADS) > 0) ? uword(ARMA_OPENMP_THREADS) : 8; #else - static const uword mp_threads = 10; + static constexpr uword mp_threads = 8; #endif - #if defined(ARMA_USE_ATLAS) - static const bool atlas = true; + #if defined(ARMA_OPTIMISE_BAND) + static constexpr bool optimise_band = true; #else - static const bool atlas = false; + static constexpr bool optimise_band = false; + #endif + + + #if defined(ARMA_OPTIMISE_SYM) + static constexpr bool optimise_sym = true; + #else + static constexpr bool optimise_sym = false; + #endif + + + #if defined(ARMA_OPTIMISE_INVEXPR) + static constexpr bool optimise_invexpr = true; + #else + static constexpr bool optimise_invexpr = false; + #endif + + + #if defined(ARMA_CHECK_NONFINITE) + static constexpr bool check_nonfinite = true; + #else + static constexpr bool check_nonfinite = false; #endif #if defined(ARMA_USE_LAPACK) - static const bool lapack = true; + static constexpr bool lapack = true; #else - static const bool lapack = false; + static constexpr bool lapack = false; #endif #if defined(ARMA_USE_BLAS) - static const bool blas = true; + static constexpr bool blas = true; + #else + static constexpr bool blas = false; + #endif + + + #if defined(ARMA_USE_ATLAS) + static constexpr bool atlas = true; #else - static const bool blas = false; + static constexpr bool atlas = false; #endif #if defined(ARMA_USE_NEWARP) - static const bool newarp = true; + static constexpr bool newarp = true; #else - static const bool newarp = false; + static constexpr bool newarp = false; #endif #if defined(ARMA_USE_ARPACK) - static const bool arpack = true; + static constexpr bool arpack = true; #else - static const bool arpack = false; + static constexpr bool arpack = false; #endif #if defined(ARMA_USE_SUPERLU) - static const bool superlu = true; + static constexpr bool superlu = true; #else - static const bool superlu = false; + static constexpr bool superlu = false; #endif #if defined(ARMA_USE_HDF5) - static const bool hdf5 = true; + static constexpr bool hdf5 = true; #else - static const bool hdf5 = false; + static constexpr bool hdf5 = false; #endif #if defined(ARMA_NO_DEBUG) - static const bool debug = false; + static constexpr bool debug = false; #else - static const bool debug = true; + static constexpr bool debug = true; #endif #if defined(ARMA_EXTRA_DEBUG) - static const bool extra_debug = true; + static constexpr bool extra_debug = true; #else - static const bool extra_debug = false; + static constexpr bool extra_debug = false; #endif #if defined(ARMA_GOOD_COMPILER) - static const bool good_comp = true; + static constexpr bool good_comp = true; #else - static const bool good_comp = false; + static constexpr bool good_comp = false; #endif @@ -121,46 +151,100 @@ struct arma_config || defined(ARMA_EXTRA_SPMAT_PROTO) || defined(ARMA_EXTRA_SPMAT_MEAT) \ || defined(ARMA_EXTRA_SPCOL_PROTO) || defined(ARMA_EXTRA_SPCOL_MEAT) \ || defined(ARMA_EXTRA_SPROW_PROTO) || defined(ARMA_EXTRA_SPROW_MEAT) \ + || defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) \ + || defined(ARMA_ALIEN_MEM_FREE_FUNCTION) \ ) - static const bool extra_code = true; + static constexpr bool extra_code = true; #else - static const bool extra_code = false; + static constexpr bool extra_code = false; #endif - #if defined(ARMA_USE_CXX11) - static const bool cxx11 = true; + #if defined(ARMA_HAVE_CXX14) + static constexpr bool cxx14 = true; #else - static const bool cxx11 = false; + static constexpr bool cxx14 = false; + #endif + + + #if defined(ARMA_HAVE_CXX17) + static constexpr bool cxx17 = true; + #else + static constexpr bool cxx17 = false; + #endif + + + #if defined(ARMA_HAVE_CXX20) + static constexpr bool cxx20 = true; + #else + static constexpr bool cxx20 = false; + #endif + + + #if (!defined(ARMA_DONT_USE_STD_MUTEX)) + static constexpr bool std_mutex = true; + #else + static constexpr bool std_mutex = false; #endif #if (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - static const bool posix = true; + static constexpr bool posix = true; #else - static const bool posix = false; + static constexpr bool posix = false; #endif #if defined(ARMA_USE_WRAPPER) - static const bool wrapper = true; + static constexpr bool wrapper = true; #else - static const bool wrapper = false; + static constexpr bool wrapper = false; #endif #if defined(ARMA_USE_OPENMP) - static const bool openmp = true; + static constexpr bool openmp = true; #else - static const bool openmp = false; + static constexpr bool openmp = false; #endif #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) - static const bool hidden_args = true; + static constexpr bool hidden_args = true; + #else + static constexpr bool hidden_args = false; + #endif + + + #if defined(ARMA_DONT_ZERO_INIT) + static constexpr bool zero_init = false; + #else + static constexpr bool zero_init = true; + #endif + + + #if defined(ARMA_FAST_MATH) + static constexpr bool fast_math = true; + #else + static constexpr bool fast_math = false; + #endif + + + #if defined(ARMA_FAST_MATH) && !defined(ARMA_DONT_PRINT_FAST_MATH_WARNING) + static constexpr bool fast_math_warn = true; #else - static const bool hidden_args = false; + static constexpr bool fast_math_warn = false; #endif + + + #if (!defined(ARMA_DONT_TREAT_TEXT_AS_BINARY)) + static constexpr bool text_as_binary = true; + #else + static constexpr bool text_as_binary = false; + #endif + + + static constexpr uword warn_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : 0; }; diff --git a/src/armadillo_bits/arma_forward.hpp b/src/armadillo_bits/arma_forward.hpp index 15b1f151..4b2f37f1 100644 --- a/src/armadillo_bits/arma_forward.hpp +++ b/src/armadillo_bits/arma_forward.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,6 +37,7 @@ template class xtrans_mat; template class subview; template class subview_col; +template class subview_cols; template class subview_row; template class subview_row_strans; template class subview_row_htrans; @@ -46,6 +49,8 @@ template class SpMat; template class SpCol; template class SpRow; template class SpSubview; +template class SpSubview_col; +template class SpSubview_row; template class diagview; template class spdiagview; @@ -65,6 +70,8 @@ template class subview_cube_each1; template class subview_cube_each2; template class subview_cube_slices; +template class SpSubview_col_list; + class SizeMat; class SizeCube; @@ -76,12 +83,18 @@ class diskio; class op_strans; class op_htrans; class op_htrans2; -class op_inv; -class op_inv_sympd; +class op_inv_gen_default; +class op_inv_spd_default; +class op_inv_gen_full; +class op_inv_spd_full; class op_diagmat; class op_trimat; class op_vectorise_row; class op_vectorise_col; + +class op_row_as_mat; +class op_col_as_mat; + class glue_times; class glue_times_diag; @@ -108,8 +121,6 @@ class op_rel_noteq; class gen_eye; class gen_ones; class gen_zeros; -class gen_randu; -class gen_randn; @@ -142,9 +153,9 @@ struct traits_op_default template struct traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; }; @@ -154,9 +165,9 @@ struct traits_op_xvec template struct traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = true; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; }; }; @@ -166,9 +177,9 @@ struct traits_op_col template struct traits { - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; }; }; @@ -178,9 +189,9 @@ struct traits_op_row template struct traits { - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; }; @@ -190,9 +201,9 @@ struct traits_op_passthru template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; }; }; @@ -202,9 +213,9 @@ struct traits_glue_default template struct traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; }; @@ -214,9 +225,9 @@ struct traits_glue_or template struct traits { - static const bool is_row = (T1::is_row || T2::is_row ); - static const bool is_col = (T1::is_col || T2::is_col ); - static const bool is_xvec = (T1::is_xvec || T2::is_xvec); + static constexpr bool is_row = (T1::is_row || T2::is_row ); + static constexpr bool is_col = (T1::is_col || T2::is_col ); + static constexpr bool is_xvec = (T1::is_xvec || T2::is_xvec); }; }; @@ -234,9 +245,10 @@ template< typename T1, typename op_type> class SpToDOp; template< typename T1, typename op_type> class CubeToMatOp; template class mtOp; -template< typename T1, typename T2, typename glue_type> class Glue; -template< typename T1, typename T2, typename eglue_type> class eGlue; -template class mtGlue; +template< typename T1, typename T2, typename glue_type> class Glue; +template< typename T1, typename T2, typename eglue_type> class eGlue; +template< typename T1, typename T2, typename glue_type> class SpToDGlue; +template class mtGlue; @@ -251,12 +263,13 @@ template< typename T1, typename T2, typename eglue_type> class template class mtGlueCube; -template class Proxy; -template class ProxyCube; +template struct Proxy; +template struct ProxyCube; template class diagmat_proxy; template struct unwrap; +template struct quasi_unwrap; template struct unwrap_cube; template struct unwrap_spmat; @@ -267,7 +280,7 @@ struct state_type { #if defined(ARMA_USE_OPENMP) int state; - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) std::atomic state; #else int state; @@ -286,7 +299,7 @@ struct state_type #if defined(ARMA_USE_OPENMP) #pragma omp atomic read out = state; - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) out = state.load(); #else out = state; @@ -302,7 +315,7 @@ struct state_type #if defined(ARMA_USE_OPENMP) #pragma omp atomic write state = in_state; - #elif defined(ARMA_USE_CXX11) + #elif (!defined(ARMA_DONT_USE_STD_MUTEX)) state.store(in_state); #else state = in_state; @@ -318,7 +331,7 @@ template< typename T1, typename T2, typename spglue_type> class template class mtSpGlue; -template class SpProxy; +template struct SpProxy; @@ -327,12 +340,18 @@ struct arma_fixed_indicator {}; struct arma_reserve_indicator {}; struct arma_layout_indicator {}; +template struct arma_initmode_indicator {}; + +struct arma_zeros_indicator : public arma_initmode_indicator {}; +struct arma_nozeros_indicator : public arma_initmode_indicator {}; + //! \addtogroup injector //! @{ template struct injector_end_of_row {}; +// DEPRECATED: DO NOT USE IN NEW CODE static const injector_end_of_row<> endr = injector_end_of_row<>(); //!< endr indicates "end of row" when using the << operator; //!< similar conceptual meaning to std::endl @@ -345,123 +364,47 @@ static const injector_end_of_row<> endr = injector_end_of_row<>(); //! @{ -enum file_type +enum struct file_type : unsigned int { file_type_unknown, - auto_detect, //!< Automatically detect the file type - raw_ascii, //!< ASCII format (text), without any other information. - arma_ascii, //!< Armadillo ASCII format (text), with information about matrix type and size - csv_ascii, //!< comma separated values (CSV), without any other information - raw_binary, //!< raw binary format, without any other information. - arma_binary, //!< Armadillo binary format, with information about matrix type and size + auto_detect, //!< attempt to automatically detect the file type + raw_ascii, //!< raw text (ASCII), without a header + arma_ascii, //!< Armadillo text format, with a header specifying matrix type and size + csv_ascii, //!< comma separated values (CSV), without a header + raw_binary, //!< raw binary format (machine dependent), without a header + arma_binary, //!< Armadillo binary format (machine dependent), with a header specifying matrix type and size pgm_binary, //!< Portable Grey Map (greyscale image) ppm_binary, //!< Portable Pixel Map (colour image), used by the field and cube classes - hdf5_binary, //!< Open binary format, not specific to Armadillo, which can store arbitrary data - hdf5_binary_trans, //!< as per hdf5_binary, but save/load the data with columns transposed to rows - coord_ascii //!< simple co-ordinate format for sparse matrices + hdf5_binary, //!< HDF5: open binary format, not specific to Armadillo, which can store arbitrary data + hdf5_binary_trans, //!< [NOTE: DO NOT USE - deprecated] as per hdf5_binary, but save/load the data with columns transposed to rows + coord_ascii, //!< simple co-ordinate format for sparse matrices (indices start at zero) + ssv_ascii, //!< similar to csv_ascii; uses semicolon (;) instead of comma (,) as the separator }; -namespace hdf5_opts - { - typedef unsigned int flag_type; - - struct opts - { - const flag_type flags; - - inline explicit opts(const flag_type in_flags); - - inline const opts operator+(const opts& rhs) const; - }; - - inline - opts::opts(const flag_type in_flags) - : flags(in_flags) - {} - - inline - const opts - opts::operator+(const opts& rhs) const - { - const opts result( flags | rhs.flags ); - - return result; - } - - // The values below (eg. 1u << 0) are for internal Armadillo use only. - // The values can change without notice. - - static const flag_type flag_none = flag_type(0 ); - static const flag_type flag_trans = flag_type(1u << 0); - static const flag_type flag_append = flag_type(1u << 1); - static const flag_type flag_replace = flag_type(1u << 2); - - struct opts_none : public opts { inline opts_none() : opts(flag_none ) {} }; - struct opts_trans : public opts { inline opts_trans() : opts(flag_trans ) {} }; - struct opts_append : public opts { inline opts_append() : opts(flag_append ) {} }; - struct opts_replace : public opts { inline opts_replace() : opts(flag_replace) {} }; - - static const opts_none none; - static const opts_trans trans; - static const opts_append append; - static const opts_replace replace; - } +static constexpr file_type file_type_unknown = file_type::file_type_unknown; +static constexpr file_type auto_detect = file_type::auto_detect; +static constexpr file_type raw_ascii = file_type::raw_ascii; +static constexpr file_type arma_ascii = file_type::arma_ascii; +static constexpr file_type csv_ascii = file_type::csv_ascii; +static constexpr file_type raw_binary = file_type::raw_binary; +static constexpr file_type arma_binary = file_type::arma_binary; +static constexpr file_type pgm_binary = file_type::pgm_binary; +static constexpr file_type ppm_binary = file_type::ppm_binary; +static constexpr file_type hdf5_binary = file_type::hdf5_binary; +static constexpr file_type hdf5_binary_trans = file_type::hdf5_binary_trans; +static constexpr file_type coord_ascii = file_type::coord_ascii; +static constexpr file_type ssv_ascii = file_type::ssv_ascii; -struct hdf5_name - { - const std::string filename; - const std::string dsname; - const hdf5_opts::opts opts; - - inline - hdf5_name(const std::string& in_filename) - : filename(in_filename ) - , dsname (std::string() ) - , opts (hdf5_opts::none) - {} - - inline - hdf5_name(const std::string& in_filename, const std::string& in_dsname, const hdf5_opts::opts& in_opts = hdf5_opts::none) - : filename(in_filename) - , dsname (in_dsname ) - , opts (in_opts ) - {} - }; +struct hdf5_name; +struct csv_name; //! @} -//! \addtogroup fill -//! @{ - -namespace fill - { - struct fill_none {}; - struct fill_zeros {}; - struct fill_ones {}; - struct fill_eye {}; - struct fill_randu {}; - struct fill_randn {}; - - template - struct fill_class { inline fill_class() {} }; - - static const fill_class none; - static const fill_class zeros; - static const fill_class ones; - static const fill_class eye; - static const fill_class randu; - static const fill_class randn; - } - -//! @} - - - //! \addtogroup fn_spsolve //! @{ @@ -507,3 +450,26 @@ struct superlu_opts : public spsolve_opts_base //! @} + + + +//! \ingroup fn_eigs_sym fs_eigs_gen +//! @{ + + +struct eigs_opts + { + double tol; // tolerance + unsigned int maxiter; // max iterations + unsigned int subdim; // subspace dimension + + inline eigs_opts() + { + tol = 0.0; + maxiter = 1000; + subdim = 0; + } + }; + + +//! @} diff --git a/src/armadillo_bits/arma_ostream_bones.hpp b/src/armadillo_bits/arma_ostream_bones.hpp index 51073b31..e59c26fc 100644 --- a/src/armadillo_bits/arma_ostream_bones.hpp +++ b/src/armadillo_bits/arma_ostream_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,15 +24,15 @@ class arma_ostream_state { private: - + const ios::fmtflags orig_flags; const std::streamsize orig_precision; const std::streamsize orig_width; const char orig_fill; - - + + public: - + inline arma_ostream_state(const std::ostream& o); inline void restore(std::ostream& o) const; @@ -44,26 +46,32 @@ class arma_ostream template inline static std::streamsize modify_stream(std::ostream& o, const eT* data, const uword n_elem); template inline static std::streamsize modify_stream(std::ostream& o, const std::complex* data, const uword n_elem); - template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_not_cx::result* junk = 0); - template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk = 0); + template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_not_cx::result* junk = nullptr); + template inline static std::streamsize modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk = nullptr); template inline static void print_elem_zero(std::ostream& o, const bool modify); - template inline static void print_elem(std::ostream& o, const eT& x, const bool modify); - template inline static void print_elem(std::ostream& o, const std::complex& x, const bool modify); - + template inline static void print_elem(std::ostream& o, const eT& x, const bool modify); + template inline static void raw_print_elem(std::ostream& o, const eT& x); + + template inline static void print_elem(std::ostream& o, const std::complex& x, const bool modify); + template inline static void raw_print_elem(std::ostream& o, const std::complex& x); + template arma_cold inline static void print(std::ostream& o, const Mat& m, const bool modify); template arma_cold inline static void print(std::ostream& o, const Cube& m, const bool modify); template arma_cold inline static void print(std::ostream& o, const field& m); template arma_cold inline static void print(std::ostream& o, const subview_field& m); - - + template arma_cold inline static void print_dense(std::ostream& o, const SpMat& m, const bool modify); template arma_cold inline static void print(std::ostream& o, const SpMat& m, const bool modify); arma_cold inline static void print(std::ostream& o, const SizeMat& S); arma_cold inline static void print(std::ostream& o, const SizeCube& S); + + template arma_cold inline static void brief_print(std::ostream& o, const Mat& m, const bool print_size = true); + template arma_cold inline static void brief_print(std::ostream& o, const Cube& m); + template arma_cold inline static void brief_print(std::ostream& o, const SpMat& m); }; diff --git a/src/armadillo_bits/arma_ostream_meat.hpp b/src/armadillo_bits/arma_ostream_meat.hpp index f3968d70..dbd4b6ca 100644 --- a/src/armadillo_bits/arma_ostream_meat.hpp +++ b/src/armadillo_bits/arma_ostream_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -190,18 +192,18 @@ arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator { arma_extra_debug_sigprint(); arma_ignore(junk); - + o.unsetf(ios::showbase); o.unsetf(ios::uppercase); o.unsetf(ios::showpos); - + o.fill(' '); - + std::streamsize cell_width; - + bool use_layout_B = false; bool use_layout_C = false; - + for(typename SpMat::const_iterator it = begin; it.pos() < n_elem; ++it) { const eT val = (*it); @@ -218,7 +220,7 @@ arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator use_layout_C = true; break; } - + if( (val >= eT(+10)) || ( (is_signed::value) && (val <= eT(-10)) ) ) @@ -226,7 +228,7 @@ arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator use_layout_B = true; } } - + if(use_layout_C) { o.setf(ios::scientific); @@ -259,10 +261,10 @@ arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator //! "better than nothing" settings for complex numbers -template +template inline std::streamsize -arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk) +arma_ostream::modify_stream(std::ostream& o, typename SpMat::const_iterator begin, const uword n_elem, const typename arma_cx_only::result* junk) { arma_ignore(begin); arma_ignore(n_elem); @@ -292,6 +294,8 @@ inline void arma_ostream::print_elem_zero(std::ostream& o, const bool modify) { + typedef typename promote_type::result promoted_eT; + if(modify) { const ios::fmtflags save_flags = o.flags(); @@ -301,116 +305,126 @@ arma_ostream::print_elem_zero(std::ostream& o, const bool modify) o.setf(ios::fixed); o.precision(0); - o << eT(0); + o << promoted_eT(0); o.flags(save_flags); o.precision(save_precision); } else { - o << eT(0); + o << promoted_eT(0); } } -//! Print an element to the specified stream template inline void arma_ostream::print_elem(std::ostream& o, const eT& x, const bool modify) + { + if(x == eT(0)) + { + arma_ostream::print_elem_zero(o, modify); + } + else + { + arma_ostream::raw_print_elem(o, x); + } + } + + + +template +inline +void +arma_ostream::raw_print_elem(std::ostream& o, const eT& x) { if(is_signed::value) { typedef typename promote_type::result promoted_eT; - if(x != eT(0)) + if(arma_isfinite(x)) { - if(arma_isfinite(x)) - { - o << promoted_eT(x); - } - else - { - o << ( arma_isinf(x) ? ((x <= eT(0)) ? "-inf" : "inf") : "nan" ); - } + o << promoted_eT(x); } else { - arma_ostream::print_elem_zero(o, modify); + o << ( arma_isinf(x) ? ((x <= eT(0)) ? "-inf" : "inf") : "nan" ); } } else { typedef typename promote_type::result promoted_eT; - if(x != eT(0)) - { - o << promoted_eT(x); - } - else - { - arma_ostream::print_elem_zero(o, modify); - } + o << promoted_eT(x); } } -//! Print a complex element to the specified stream template inline void arma_ostream::print_elem(std::ostream& o, const std::complex& x, const bool modify) { - if( (x.real() != T(0)) || (x.imag() != T(0)) || (modify == false) ) + if( (x.real() == T(0)) && (x.imag() == T(0)) && (modify) ) { - std::ostringstream ss; - ss.flags(o.flags()); - //ss.imbue(o.getloc()); - ss.precision(o.precision()); - - ss << '('; - - const T a = x.real(); - - if(arma_isfinite(a)) - { - ss << a; - } - else - { - ss << ( arma_isinf(a) ? ((a <= T(0)) ? "-inf" : "+inf") : "nan" ); - } - - ss << ','; - - const T b = x.imag(); - - if(arma_isfinite(b)) - { - ss << b; - } - else - { - ss << ( arma_isinf(b) ? ((b <= T(0)) ? "-inf" : "+inf") : "nan" ); - } - - ss << ')'; - - o << ss.str(); + o << "(0,0)"; } else { - o << "(0,0)"; + arma_ostream::raw_print_elem(o, x); } } +template +inline +void +arma_ostream::raw_print_elem(std::ostream& o, const std::complex& x) + { + std::ostringstream ss; + ss.flags(o.flags()); + //ss.imbue(o.getloc()); + ss.precision(o.precision()); + + ss << '('; + + const T a = x.real(); + + if(arma_isfinite(a)) + { + ss << a; + } + else + { + ss << ( arma_isinf(a) ? ((a <= T(0)) ? "-inf" : "+inf") : "nan" ); + } + + ss << ','; + + const T b = x.imag(); + + if(arma_isfinite(b)) + { + ss << b; + } + else + { + ss << ( arma_isinf(b) ? ((b <= T(0)) ? "-inf" : "+inf") : "nan" ); + } + + ss << ')'; + + o << ss.str(); + } + + + //! Print a matrix to the specified stream template -arma_cold inline void arma_ostream::print(std::ostream& o, const Mat& m, const bool modify) @@ -461,6 +475,14 @@ arma_ostream::print(std::ostream& o, const Mat& m, const bool modify) } else { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + o << "[matrix size: " << m_n_rows << 'x' << m_n_cols << "]\n"; } @@ -472,7 +494,6 @@ arma_ostream::print(std::ostream& o, const Mat& m, const bool modify) //! Print a cube to the specified stream template -arma_cold inline void arma_ostream::print(std::ostream& o, const Cube& x, const bool modify) @@ -487,16 +508,25 @@ arma_ostream::print(std::ostream& o, const Cube& x, const bool modify) { const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); - o << "[cube slice " << slice << ']' << '\n'; + o << "[cube slice: " << slice << ']' << '\n'; arma_ostream::print(o, tmp, modify); - o << '\n'; + + if((slice+1) < x.n_slices) { o << '\n'; } } } else { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + o << "[cube size: " << x.n_rows << 'x' << x.n_cols << 'x' << x.n_slices << "]\n"; } - + stream_state.restore(o); } @@ -504,9 +534,8 @@ arma_ostream::print(std::ostream& o, const Cube& x, const bool modify) //! Print a field to the specified stream -//! Assumes type oT can be printed, i.e. oT has std::ostream& operator<< (std::ostream&, const oT&) +//! Assumes type oT can be printed, ie. oT has std::ostream& operator<< (std::ostream&, const oT&) template -arma_cold inline void arma_ostream::print(std::ostream& o, const field& x) @@ -525,11 +554,11 @@ arma_ostream::print(std::ostream& o, const field& x) { if(x_n_slices == 1) { - for(uword col=0; col& x) } else { - for(uword slice=0; slice& x) } else { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + o << "[field size: " << x_n_rows << 'x' << x_n_cols << 'x' << x_n_slices << "]\n"; } @@ -573,9 +607,8 @@ arma_ostream::print(std::ostream& o, const field& x) //! Print a subfield to the specified stream -//! Assumes type oT can be printed, i.e. oT has std::ostream& operator<< (std::ostream&, const oT&) +//! Assumes type oT can be printed, ie. oT has std::ostream& operator<< (std::ostream&, const oT&) template -arma_cold inline void arma_ostream::print(std::ostream& o, const subview_field& x) @@ -594,9 +627,9 @@ arma_ostream::print(std::ostream& o, const subview_field& x) { if(x_n_slices == 1) { - for(uword col=0; col& x) } else { - for(uword slice=0; slice& x) } else { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + o << "[field size: " << x_n_rows << 'x' << x_n_cols << 'x' << x_n_slices << "]\n"; } @@ -641,7 +679,6 @@ arma_ostream::print(std::ostream& o, const subview_field& x) template -arma_cold inline void arma_ostream::print_dense(std::ostream& o, const SpMat& m, const bool modify) @@ -706,6 +743,14 @@ arma_ostream::print_dense(std::ostream& o, const SpMat& m, const bool modify } else { + if(modify) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + } + o << "[matrix size: " << m_n_rows << 'x' << m_n_cols << "]\n"; } @@ -716,7 +761,6 @@ arma_ostream::print_dense(std::ostream& o, const SpMat& m, const bool modify template -arma_cold inline void arma_ostream::print(std::ostream& o, const SpMat& m, const bool modify) @@ -776,37 +820,38 @@ arma_ostream::print(std::ostream& o, const SpMat& m, const bool modify) { const std::streamsize cell_width = modify ? arma_ostream::modify_stream(o, m.begin(), m_n_nonzero) : o.width(); - typename SpMat::const_iterator begin = m.begin(); - typename SpMat::const_iterator m_end = m.end(); + typename SpMat::const_iterator it = m.begin(); + typename SpMat::const_iterator it_end = m.end(); - while(begin != m_end) + while(it != it_end) { - const uword row = begin.row(); + const uword row = it.row(); + const uword col = it.col(); // TODO: change the maximum number of spaces before and after each location to be dependent on n_rows and n_cols - if(row < 10) { o << " "; } - else if(row < 100) { o << " "; } - else if(row < 1000) { o << " "; } - else if(row < 10000) { o << " "; } - else if(row < 100000) { o << ' '; } - - const uword col = begin.col(); + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } o << '(' << row << ", " << col << ") "; - if(col < 10) { o << " "; } - else if(col < 100) { o << " "; } - else if(col < 1000) { o << " "; } - else if(col < 10000) { o << " "; } - else if(col < 100000) { o << ' '; } + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } if(cell_width > 0) { o.width(cell_width); } - arma_ostream::print_elem(o, eT(*begin), modify); + arma_ostream::print_elem(o, eT(*it), modify); o << '\n'; - ++begin; + ++it; } o << '\n'; @@ -818,7 +863,6 @@ arma_ostream::print(std::ostream& o, const SpMat& m, const bool modify) -arma_cold inline void arma_ostream::print(std::ostream& o, const SizeMat& S) @@ -840,7 +884,6 @@ arma_ostream::print(std::ostream& o, const SizeMat& S) -arma_cold inline void arma_ostream::print(std::ostream& o, const SizeCube& S) @@ -862,4 +905,370 @@ arma_ostream::print(std::ostream& o, const SizeCube& S) +template +inline +void +arma_ostream::brief_print(std::ostream& o, const Mat& m, const bool print_size) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + if(print_size) + { + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + + o << "[matrix size: " << m.n_rows << 'x' << m.n_cols << "]\n"; + } + + if(m.n_elem == 0) { o.flush(); stream_state.restore(o); return; } + + if((m.n_rows <= 5) && (m.n_cols <= 5)) { arma_ostream::print(o, m, true); return; } + + const bool print_row_ellipsis = (m.n_rows >= 6); + const bool print_col_ellipsis = (m.n_cols >= 6); + + if( (print_row_ellipsis == true) && (print_col_ellipsis == true) ) + { + Mat X(4, 4, arma_nozeros_indicator()); + + X( span(0,2), span(0,2) ) = m( span(0,2), span(0,2) ); // top left submatrix + X( 3, span(0,2) ) = m( m.n_rows-1, span(0,2) ); // truncated last row + X( span(0,2), 3 ) = m( span(0,2), m.n_cols-1 ); // truncated last column + X( 3, 3 ) = m( m.n_rows-1, m.n_cols-1 ); // bottom right element + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row <= 2; ++row) + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + o << ':'; + } + + o.width(6); + o << "..."; + + o.width(cell_width); + o << ':' << '\n'; + + const uword row = 3; + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + } + + + if( (print_row_ellipsis == true) && (print_col_ellipsis == false) ) + { + Mat X(4, m.n_cols, arma_nozeros_indicator()); + + X( span(0,2), span::all ) = m( span(0,2), span::all ); // top + X( 3, span::all ) = m( m.n_rows-1, span::all ); // bottom + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row <= 2; ++row) // first 3 rows + { + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o << '\n'; + } + + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + o << ':'; + } + + o.width(cell_width); + o << '\n'; + + const uword row = 3; + { + for(uword col=0; col < m.n_cols; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + } + + o << '\n'; + } + + + if( (print_row_ellipsis == false) && (print_col_ellipsis == true) ) + { + Mat X(m.n_rows, 4, arma_nozeros_indicator()); + + X( span::all, span(0,2) ) = m( span::all, span(0,2) ); // left + X( span::all, 3 ) = m( span::all, m.n_cols-1 ); // right + + const std::streamsize cell_width = arma_ostream::modify_stream(o, X.memptr(), X.n_elem); + + for(uword row=0; row < m.n_rows; ++row) + { + for(uword col=0; col <= 2; ++col) + { + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,col), true); + } + + o.width(6); + o << "..."; + + o.width(cell_width); + arma_ostream::print_elem(o, X.at(row,3), true); + o << '\n'; + } + } + + + o.flush(); + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::brief_print(std::ostream& o, const Cube& x) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.setf(ios::fixed); + + o << "[cube size: " << x.n_rows << 'x' << x.n_cols << 'x' << x.n_slices << "]\n"; + + if(x.n_elem == 0) { o.flush(); stream_state.restore(o); return; } + + if(x.n_slices <= 3) + { + for(uword slice=0; slice < x.n_slices; ++slice) + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + + if((slice+1) < x.n_slices) { o << '\n'; } + } + } + else + { + for(uword slice=0; slice <= 1; ++slice) + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + o << '\n'; + } + + o << "[cube slice: ...]\n\n"; + + const uword slice = x.n_slices-1; + { + const Mat tmp(const_cast(x.slice_memptr(slice)), x.n_rows, x.n_cols, false); + + o << "[cube slice: " << slice << ']' << '\n'; + arma_ostream::brief_print(o, tmp, false); + } + } + + stream_state.restore(o); + } + + + +template +inline +void +arma_ostream::brief_print(std::ostream& o, const SpMat& m) + { + arma_extra_debug_sigprint(); + + if(m.n_nonzero <= 10) { arma_ostream::print(o, m, true); return; } + + const arma_ostream_state stream_state(o); + + o.unsetf(ios::showbase); + o.unsetf(ios::uppercase); + o.unsetf(ios::showpos); + o.unsetf(ios::scientific); + o.setf(ios::right); + o.setf(ios::fixed); + + const uword m_n_nonzero = m.n_nonzero; + const double density = (m.n_elem > 0) ? (double(m_n_nonzero) / double(m.n_elem) * double(100)) : double(0); + + o << "[matrix size: " << m.n_rows << 'x' << m.n_cols << "; n_nonzero: " << m_n_nonzero; + + if(density == double(0)) + { + o.precision(0); + } + else + if(density >= (double(10.0)-std::numeric_limits::epsilon())) + { + o.precision(1); + } + else + if(density > (double(0.01)-std::numeric_limits::epsilon())) + { + o.precision(2); + } + else + if(density > (double(0.001)-std::numeric_limits::epsilon())) + { + o.precision(3); + } + else + if(density > (double(0.0001)-std::numeric_limits::epsilon())) + { + o.precision(4); + } + else + { + o.unsetf(ios::fixed); + o.setf(ios::scientific); + o.precision(2); + } + + o << "; density: " << density << "%]\n\n"; + + // get the first 9 elements and the last element + + typename SpMat::const_iterator it = m.begin(); + typename SpMat::const_iterator it_end = m.end(); + + uvec storage_row(10); + uvec storage_col(10); + Col storage_val(10); + + uword count = 0; + + while( (it != it_end) && (count < 9) ) + { + storage_row(count) = it.row(); + storage_col(count) = it.col(); + storage_val(count) = (*it); + + ++it; + ++count; + } + + it = it_end; + --it; + + storage_row(count) = it.row(); + storage_col(count) = it.col(); + storage_val(count) = (*it); + + const std::streamsize cell_width = arma_ostream::modify_stream(o, storage_val.memptr(), 10); + + for(uword i=0; i < 9; ++i) + { + const uword row = storage_row(i); + const uword col = storage_col(i); + + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } + + o << '(' << row << ", " << col << ") "; + + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } + + if(cell_width > 0) { o.width(cell_width); } + + arma_ostream::print_elem(o, storage_val(i), true); + o << '\n'; + } + + o << " (:, :) "; + if(cell_width > 0) { o.width(cell_width); } + o << "...\n"; + + + const uword i = 9; + { + const uword row = storage_row(i); + const uword col = storage_col(i); + + if(row < 10) { o << " "; } + else if(row < 100) { o << " "; } + else if(row < 1000) { o << " "; } + else if(row < 10000) { o << " "; } + else if(row < 100000) { o << " "; } + else if(row < 1000000) { o << ' '; } + + o << '(' << row << ", " << col << ") "; + + if(col < 10) { o << " "; } + else if(col < 100) { o << " "; } + else if(col < 1000) { o << " "; } + else if(col < 10000) { o << " "; } + else if(col < 100000) { o << " "; } + else if(col < 1000000) { o << ' '; } + + if(cell_width > 0) { o.width(cell_width); } + + arma_ostream::print_elem(o, storage_val(i), true); + o << '\n'; + } + + o.flush(); + stream_state.restore(o); + } + + + //! @} diff --git a/src/armadillo_bits/arma_rel_comparators.hpp b/src/armadillo_bits/arma_rel_comparators.hpp index 67f704ae..977617b2 100644 --- a/src/armadillo_bits/arma_rel_comparators.hpp +++ b/src/armadillo_bits/arma_rel_comparators.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -67,6 +69,33 @@ struct arma_lt_comparator< std::complex > // // return ( (abs_a != abs_b) ? (abs_a < abs_b) : (std::arg(a) < std::arg(b)) ); // } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T a_real = a.real(); + // const T a_imag = a.imag(); + // + // const T a_mag_squared = a_real*a_real + a_imag*a_imag; + // + // const T b_real = b.real(); + // const T b_imag = b.imag(); + // + // const T b_mag_squared = b_real*b_real + b_imag*b_imag; + // + // if( (a_mag_squared != T(0)) && (b_mag_squared != T(0)) && std::isfinite(a_mag_squared) && std::isfinite(b_mag_squared) ) + // { + // return ( (a_mag_squared != b_mag_squared) ? (a_mag_squared < b_mag_squared) : (std::arg(a) < std::arg(b)) ); + // } + // else + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a < abs_b) : (std::arg(a) < std::arg(b)) ); + // } + // } }; @@ -87,6 +116,33 @@ struct arma_gt_comparator< std::complex > // // return ( (abs_a != abs_b) ? (abs_a > abs_b) : (std::arg(a) > std::arg(b)) ); // } + + // inline + // bool + // operator() (const eT& a, const eT& b) const + // { + // const T a_real = a.real(); + // const T a_imag = a.imag(); + // + // const T a_mag_squared = a_real*a_real + a_imag*a_imag; + // + // const T b_real = b.real(); + // const T b_imag = b.imag(); + // + // const T b_mag_squared = b_real*b_real + b_imag*b_imag; + // + // if( (a_mag_squared != T(0)) && (b_mag_squared != T(0)) && std::isfinite(a_mag_squared) && std::isfinite(b_mag_squared) ) + // { + // return ( (a_mag_squared != b_mag_squared) ? (a_mag_squared > b_mag_squared) : (std::arg(a) > std::arg(b)) ); + // } + // else + // { + // const T abs_a = std::abs(a); + // const T abs_b = std::abs(b); + // + // return ( (abs_a != abs_b) ? (abs_a > abs_b) : (std::arg(a) > std::arg(b)) ); + // } + // } }; diff --git a/src/armadillo_bits/arma_rng.hpp b/src/armadillo_bits/arma_rng.hpp index 362cc665..da1b4f7a 100644 --- a/src/armadillo_bits/arma_rng.hpp +++ b/src/armadillo_bits/arma_rng.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,19 +20,62 @@ //! @{ -#if defined(ARMA_RNG_ALT) - #undef ARMA_USE_EXTERN_CXX11_RNG +#undef ARMA_USE_CXX11_RNG +#define ARMA_USE_CXX11_RNG + +#undef ARMA_USE_THREAD_LOCAL +#define ARMA_USE_THREAD_LOCAL + +#if (defined(ARMA_RNG_ALT) || defined(ARMA_DONT_USE_CXX11_RNG)) + #undef ARMA_USE_CXX11_RNG +#endif + +#if defined(ARMA_DONT_USE_THREAD_LOCAL) + #undef ARMA_USE_THREAD_LOCAL +#endif + + +// NOTE: ARMA_WARMUP_PRODUCER enables a workaround +// NOTE: for thread_local issue on macOS 11 and/or AppleClang 12.0 +// NOTE: see https://gitlab.com/conradsnicta/armadillo-code/-/issues/173 +// NOTE: if this workaround causes problems, please report it and +// NOTE: disable the workaround by commenting out the code block below: + +#if defined(__APPLE__) || defined(__apple_build_version__) + #undef ARMA_WARMUP_PRODUCER + #define ARMA_WARMUP_PRODUCER #endif +#if defined(ARMA_DONT_WARMUP_PRODUCER) + #undef ARMA_WARMUP_PRODUCER +#endif + +// NOTE: workaround for another thread_local issue on macOS +// NOTE: where GCC (not Clang) may not have support for thread_local -#if !defined(ARMA_USE_CXX11) - #undef ARMA_USE_EXTERN_CXX11_RNG +#if (defined(__APPLE__) && defined(__GNUG__) && !defined(__clang__)) + #undef ARMA_USE_THREAD_LOCAL #endif +// NOTE: disable use of thread_local on MinGW et al; +// NOTE: i don't have the patience to keep looking into these broken platforms -#if defined(ARMA_USE_EXTERN_CXX11_RNG) - extern thread_local arma_rng_cxx11 arma_rng_cxx11_instance; - // namespace { thread_local arma_rng_cxx11 arma_rng_cxx11_instance; } +#if (defined(__MINGW32__) || defined(__MINGW64__) || defined(__CYGWIN__) || defined(__MSYS__) || defined(__MSYS2__)) + #undef ARMA_USE_THREAD_LOCAL +#endif + +#if defined(ARMA_FORCE_USE_THREAD_LOCAL) + #undef ARMA_USE_THREAD_LOCAL + #define ARMA_USE_THREAD_LOCAL +#endif + +#if (!defined(ARMA_USE_THREAD_LOCAL)) + #undef ARMA_GUARD_PRODUCER + #define ARMA_GUARD_PRODUCER +#endif + +#if (defined(ARMA_DONT_GUARD_PRODUCER) || defined(ARMA_DONT_USE_STD_MUTEX)) + #undef ARMA_GUARD_PRODUCER #endif @@ -39,19 +84,31 @@ class arma_rng public: #if defined(ARMA_RNG_ALT) - typedef arma_rng_alt::seed_type seed_type; - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) - typedef arma_rng_cxx11::seed_type seed_type; + typedef arma_rng_alt::seed_type seed_type; + #elif defined(ARMA_USE_CXX11_RNG) + typedef std::mt19937_64::result_type seed_type; #else - typedef arma_rng_cxx98::seed_type seed_type; + typedef arma_rng_cxx03::seed_type seed_type; #endif #if defined(ARMA_RNG_ALT) - static const int rng_method = 2; - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) - static const int rng_method = 1; + static constexpr int rng_method = 2; + #elif defined(ARMA_USE_CXX11_RNG) + static constexpr int rng_method = 1; #else - static const int rng_method = 0; + static constexpr int rng_method = 0; + #endif + + #if defined(ARMA_USE_CXX11_RNG) + inline static std::mt19937_64& get_producer(); + inline static void warmup_producer(std::mt19937_64& producer); + + inline static void lock_producer(); + inline static void unlock_producer(); + + #if defined(ARMA_GUARD_PRODUCER) + inline static std::mutex& get_producer_mutex(); + #endif #endif inline static void set_seed(const seed_type val); @@ -60,10 +117,106 @@ class arma_rng template struct randi; template struct randu; template struct randn; + template struct randg; }; +#if defined(ARMA_USE_CXX11_RNG) + +inline +std::mt19937_64& +arma_rng::get_producer() + { + #if defined(ARMA_USE_THREAD_LOCAL) + + // use a thread-safe RNG, with each thread having its own unique starting seed + + static std::atomic mt19937_64_producer_counter(0); + + static thread_local std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed + mt19937_64_producer_counter++ ); + + arma_rng::warmup_producer(mt19937_64_producer); + + #else + + // use a plain RNG in case we don't have thread_local + + static std::mt19937_64 mt19937_64_producer( std::mt19937_64::default_seed ); + + arma_rng::warmup_producer(mt19937_64_producer); + + #endif + + return mt19937_64_producer; + } + + +inline +void +arma_rng::warmup_producer(std::mt19937_64& producer) + { + #if defined(ARMA_WARMUP_PRODUCER) + + static std::atomic_flag warmup_done = ATOMIC_FLAG_INIT; // init to false + + if(warmup_done.test_and_set() == false) + { + typename std::mt19937_64::result_type junk = producer(); + + arma_ignore(junk); + } + + #else + + arma_ignore(producer); + + #endif + } + + +inline +void +arma_rng::lock_producer() + { + #if defined(ARMA_GUARD_PRODUCER) + + std::mutex& producer_mutex = arma_rng::get_producer_mutex(); + + producer_mutex.lock(); + + #endif + } + + +inline +void +arma_rng::unlock_producer() + { + #if defined(ARMA_GUARD_PRODUCER) + + std::mutex& producer_mutex = arma_rng::get_producer_mutex(); + + producer_mutex.unlock(); + + #endif + } + + +#if defined(ARMA_GUARD_PRODUCER) + inline + std::mutex& + arma_rng::get_producer_mutex() + { + static std::mutex producer_mutex; + + return producer_mutex; + } +#endif + +#endif + + inline void arma_rng::set_seed(const arma_rng::seed_type val) @@ -72,13 +225,15 @@ arma_rng::set_seed(const arma_rng::seed_type val) { arma_rng_alt::set_seed(val); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - arma_rng_cxx11_instance.set_seed(val); + arma_rng::lock_producer(); + arma_rng::get_producer().seed(val); + arma_rng::unlock_producer(); } #else { - arma_rng_cxx98::set_seed(val); + arma_rng_cxx03::set_seed(val); } #endif } @@ -94,23 +249,18 @@ arma_rng::set_seed_random() seed_type seed2 = seed_type(0); seed_type seed3 = seed_type(0); seed_type seed4 = seed_type(0); - seed_type seed5 = seed_type(0); bool have_seed = false; - #if defined(ARMA_USE_CXX11) + try { - try - { - std::random_device rd; - - if(rd.entropy() > double(0)) { seed1 = static_cast( rd() ); } - - if(seed1 != seed_type(0)) { have_seed = true; } - } - catch(...) {} + std::random_device rd; + + if(rd.entropy() > double(0)) { seed1 = static_cast( rd() ); } + + have_seed = (seed1 != seed_type(0)); } - #endif + catch(...) {} if(have_seed == false) @@ -129,12 +279,9 @@ arma_rng::set_seed_random() if(f.good()) { f.read((char*)(&(tmp.b[0])), sizeof(seed_type)); } - if(f.good()) - { - seed2 = tmp.a; + if(f.good()) { seed2 = tmp.a; } - if(seed2 != seed_type(0)) { have_seed = true; } - } + have_seed = (seed2 != seed_type(0)); } catch(...) {} } @@ -144,17 +291,11 @@ arma_rng::set_seed_random() { // get better-than-nothing seeds in case reading /dev/urandom failed - #if defined(ARMA_HAVE_GETTIMEOFDAY) - { - struct timeval posix_time; - - gettimeofday(&posix_time, 0); - - seed3 = static_cast(posix_time.tv_usec); - } - #endif + const std::chrono::system_clock::time_point tp_now = std::chrono::system_clock::now(); + + auto since_epoch_usec = std::chrono::duration_cast(tp_now.time_since_epoch()).count(); - seed4 = static_cast( std::time(NULL) & 0xFFFF ); + seed3 = static_cast( since_epoch_usec & 0xFFFF ); union { @@ -164,36 +305,48 @@ arma_rng::set_seed_random() tmp.a = (uword*)malloc(sizeof(uword)); - if(tmp.a != NULL) + if(tmp.a != nullptr) { - for(size_t i=0; i struct arma_rng::randi { - arma_inline + inline operator eT () { #if defined(ARMA_RNG_ALT) { return eT( arma_rng_alt::randi_val() ); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - return eT( arma_rng_cxx11_instance.randi_val() ); + constexpr double scale = double(std::numeric_limits::max()) / double(std::mt19937_64::max()); + + arma_rng::lock_producer(); + + const eT out = eT(double(arma_rng::get_producer()()) * scale); + + arma_rng::unlock_producer(); + + return out; } #else { - return eT( arma_rng_cxx98::randi_val() ); + return eT( arma_rng_cxx03::randi_val() ); } #endif } @@ -208,13 +361,13 @@ struct arma_rng::randi { return arma_rng_alt::randi_max_val(); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - return arma_rng_cxx11::randi_max_val(); + return std::numeric_limits::max(); } #else { - return arma_rng_cxx98::randi_max_val(); + return arma_rng_cxx03::randi_max_val(); } #endif } @@ -229,13 +382,30 @@ struct arma_rng::randi { arma_rng_alt::randi_fill(mem, N, a, b); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - arma_rng_cxx11_instance.randi_fill(mem, N, a, b); + std::uniform_int_distribution local_i_distr(a, b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i local_i_distr(a, b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i struct arma_rng::randu { - arma_inline + inline operator eT () { #if defined(ARMA_RNG_ALT) { return eT( arma_rng_alt::randu_val() ); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - return eT( arma_rng_cxx11_instance.randu_val() ); + constexpr double scale = double(1.0) / double(std::mt19937_64::max()); + + arma_rng::lock_producer(); + + const eT out = eT( double(arma_rng::get_producer()()) * scale ); + + arma_rng::unlock_producer(); + + return out; } #else { - return eT( arma_rng_cxx98::randu_val() ); + return eT( arma_rng_cxx03::randu_val() ); } #endif } @@ -270,21 +452,76 @@ struct arma_rng::randu void fill(eT* mem, const uword N) { - uword j; - - for(j=1; j < N; j+=2) + #if defined(ARMA_RNG_ALT) + { + for(uword i=0; i < N; ++i) { mem[i] = eT( arma_rng_alt::randu_val() ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else { - const eT tmp_i = eT( arma_rng::randu() ); - const eT tmp_j = eT( arma_rng::randu() ); + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randu_val() ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; - (*mem) = tmp_i; mem++; - (*mem) = tmp_j; mem++; + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(local_engine) ); } } - - if((j-1) < N) + #endif + } + + + inline + static + void + fill(eT* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_RNG_ALT) + { + const double r = b - a; + + for(uword i=0; i < N; ++i) { mem[i] = eT( arma_rng_alt::randu_val() * r + a ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(producer) ); } + + arma_rng::unlock_producer(); + } + #else { - (*mem) = eT( arma_rng::randu() ); + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randu_val() * (b - a) + a ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr(a,b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_u_distr(local_engine) ); } } + #endif } }; @@ -296,10 +533,36 @@ struct arma_rng::randu< std::complex > arma_inline operator std::complex () { - const T a = T( arma_rng::randu() ); - const T b = T( arma_rng::randu() ); - - return std::complex(a, b); + #if defined(ARMA_RNG_ALT) + { + const T a = T( arma_rng_alt::randu_val() ); + const T b = T( arma_rng_alt::randu_val() ); + + return std::complex(a, b); + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + const T a = T( local_u_distr(producer) ); + const T b = T( local_u_distr(producer) ); + + arma_rng::unlock_producer(); + + return std::complex(a, b); + } + #else + { + const T a = T( arma_rng_cxx03::randu_val() ); + const T b = T( arma_rng_cxx03::randu_val() ); + + return std::complex(a, b); + } + #endif } @@ -308,18 +571,139 @@ struct arma_rng::randu< std::complex > void fill(std::complex* mem, const uword N) { - for(uword i=0; i < N; ++i) + #if defined(ARMA_RNG_ALT) + { + for(uword i=0; i < N; ++i) + { + const T a = T( arma_rng_alt::randu_val() ); + const T b = T( arma_rng_alt::randu_val() ); + + mem[i] = std::complex(a, b); + } + } + #elif defined(ARMA_USE_CXX11_RNG) { - const T a = T( arma_rng::randu() ); - const T b = T( arma_rng::randu() ); + std::uniform_real_distribution local_u_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); - mem[i] = std::complex(a, b); + for(uword i=0; i < N; ++i) + { + const T a = T( local_u_distr(producer) ); + const T b = T( local_u_distr(producer) ); + + mem[i] = std::complex(a, b); + } + + arma_rng::unlock_producer(); } + #else + { + if(N == uword(1)) + { + const T a = T( arma_rng_cxx03::randu_val() ); + const T b = T( arma_rng_cxx03::randu_val() ); + + mem[0] = std::complex(a, b); + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) + { + const T a = T( local_u_distr(local_engine) ); + const T b = T( local_u_distr(local_engine) ); + + mem[i] = std::complex(a, b); + } + } + #endif + } + + + inline + static + void + fill(std::complex* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_RNG_ALT) + { + const double r = b - a; + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( arma_rng_alt::randu_val() * r + a ); + const T tmp2 = T( arma_rng_alt::randu_val() * r + a ); + + mem[i] = std::complex(tmp1, tmp2); + } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::uniform_real_distribution local_u_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( local_u_distr(producer) ); + const T tmp2 = T( local_u_distr(producer) ); + + mem[i] = std::complex(tmp1, tmp2); + } + + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) + { + const double r = b - a; + + const T tmp1 = T( arma_rng_cxx03::randu_val() * r + a); + const T tmp2 = T( arma_rng_cxx03::randu_val() * r + a); + + mem[0] = std::complex(tmp1, tmp2); + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr(a,b); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) + { + const T tmp1 = T( local_u_distr(local_engine) ); + const T tmp2 = T( local_u_distr(local_engine) ); + + mem[i] = std::complex(tmp1, tmp2); + } + } + #endif } }; +// + + + template struct arma_rng::randn { @@ -330,19 +714,27 @@ struct arma_rng::randn { return eT( arma_rng_alt::randn_val() ); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - return eT( arma_rng_cxx11_instance.randn_val() ); + std::normal_distribution local_n_distr; + + arma_rng::lock_producer(); + + const eT out = eT( local_n_distr(arma_rng::get_producer()) ); + + arma_rng::unlock_producer(); + + return out; } #else { - return eT( arma_rng_cxx98::randn_val() ); + return eT( arma_rng_cxx03::randn_val() ); } #endif } - arma_inline + inline static void dual_val(eT& out1, eT& out2) @@ -351,13 +743,22 @@ struct arma_rng::randn { arma_rng_alt::randn_dual_val(out1, out2); } - #elif defined(ARMA_USE_EXTERN_CXX11_RNG) + #elif defined(ARMA_USE_CXX11_RNG) { - arma_rng_cxx11_instance.randn_dual_val(out1, out2); + std::normal_distribution local_n_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + out1 = eT( local_n_distr(producer) ); + out2 = eT( local_n_distr(producer) ); + + arma_rng::unlock_producer(); } #else { - arma_rng_cxx98::randn_dual_val(out1, out2); + arma_rng_cxx03::randn_dual_val(out1, out2); } #endif } @@ -366,71 +767,110 @@ struct arma_rng::randn inline static void - fill_simple(eT* mem, const uword N) + fill(eT* mem, const uword N) { - uword i, j; - - for(i=0, j=1; j < N; i+=2, j+=2) + #if defined(ARMA_RNG_ALT) { - arma_rng::randn::dual_val( mem[i], mem[j] ); + // NOTE: old method to avoid regressions in user code that assumes specific sequence + + uword i, j; + + for(i=0, j=1; j < N; i+=2, j+=2) { arma_rng_alt::randn_dual_val( mem[i], mem[j] ); } + + if(i < N) { mem[i] = eT( arma_rng_alt::randn_val() ); } } - - if(i < N) + #elif defined(ARMA_USE_CXX11_RNG) { - mem[i] = eT( arma_rng::randn() ); + std::normal_distribution local_n_distr; + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(producer) ); } + + arma_rng::unlock_producer(); } + #else + { + if(N == uword(1)) { mem[0] = eT( arma_rng_cxx03::randn_val() ); return; } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(local_engine) ); } + } + #endif } inline static void - fill(eT* mem, const uword N) + fill(eT* mem, const uword N, const double mu, const double sd) { - #if defined(ARMA_USE_CXX11) && defined(ARMA_USE_OPENMP) + #if defined(ARMA_RNG_ALT) { - if((N < 1024) || omp_in_parallel()) { arma_rng::randn::fill_simple(mem, N); return; } - - typedef std::mt19937_64::result_type seed_type; - - const uword n_threads = uword( mp_thread_limit::get() ); + // NOTE: old method to avoid regressions in user code that assumes specific sequence - std::vector< std::mt19937_64 > engine(n_threads); - std::vector< std::normal_distribution > distr(n_threads); + uword i, j; - for(uword t=0; t < n_threads; ++t) + for(i=0, j=1; j < N; i+=2, j+=2) { - std::mt19937_64& t_engine = engine[t]; + eT val_i = eT(0); + eT val_j = eT(0); - t_engine.seed( seed_type(t) + seed_type(arma_rng::randi()) ); + arma_rng_alt::randn_dual_val( val_i, val_j ); + + mem[i] = (val_i * sd) + mu; + mem[j] = (val_j * sd) + mu; } - const uword chunk_size = N / n_threads; - - #pragma omp parallel for schedule(static) num_threads(int(n_threads)) - for(uword t=0; t < n_threads; ++t) + if(i < N) { - const uword start = (t+0) * chunk_size; - const uword endp1 = (t+1) * chunk_size; - - std::mt19937_64& t_engine = engine[t]; - std::normal_distribution& t_distr = distr[t]; - - for(uword i=start; i < endp1; ++i) { mem[i] = eT( t_distr(t_engine)); } + const eT val_i = eT( arma_rng_alt::randn_val() ); + + mem[i] = (val_i * sd) + mu; } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr(mu, sd); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); - std::mt19937_64& t0_engine = engine[0]; - std::normal_distribution& t0_distr = distr[0]; + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(producer) ); } - for(uword i=(n_threads*chunk_size); i < N; ++i) { mem[i] = eT( t0_distr(t0_engine)); } + arma_rng::unlock_producer(); } #else { - arma_rng::randn::fill_simple(mem, N); + if(N == uword(1)) + { + const eT val = eT( arma_rng_cxx03::randn_val() ); + + mem[0] = (val * sd) + mu; + + return; + } + + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr(mu, sd); + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { mem[i] = eT( local_n_distr(local_engine) ); } } #endif } - }; @@ -460,12 +900,21 @@ struct arma_rng::randn< std::complex > inline static void - fill_simple(std::complex* mem, const uword N) + dual_val(std::complex& out1, std::complex& out2) { - for(uword i=0; i < N; ++i) - { - mem[i] = std::complex( arma_rng::randn< std::complex >() ); - } + #if defined(_MSC_VER) + T a; + T b; + #else + T a(0); + T b(0); + #endif + + arma_rng::randn::dual_val(a,b); + out1 = std::complex(a,b); + + arma_rng::randn::dual_val(a,b); + out2 = std::complex(a,b); } @@ -474,58 +923,115 @@ struct arma_rng::randn< std::complex > void fill(std::complex* mem, const uword N) { - #if defined(ARMA_USE_CXX11) && defined(ARMA_USE_OPENMP) + #if defined(ARMA_RNG_ALT) { - if((N < 512) || omp_in_parallel()) { arma_rng::randn< std::complex >::fill_simple(mem, N); return; } - - typedef std::mt19937_64::result_type seed_type; + for(uword i=0; i < N; ++i) { mem[i] = std::complex( arma_rng::randn< std::complex >() ); } + } + #elif defined(ARMA_USE_CXX11_RNG) + { + std::normal_distribution local_n_distr; - const uword n_threads = uword( mp_thread_limit::get() ); + std::mt19937_64& producer = arma_rng::get_producer(); - std::vector< std::mt19937_64 > engine(n_threads); - std::vector< std::normal_distribution > distr(n_threads); + arma_rng::lock_producer(); - for(uword t=0; t < n_threads; ++t) + for(uword i=0; i < N; ++i) { - std::mt19937_64& t_engine = engine[t]; + const T a = T( local_n_distr(producer) ); + const T b = T( local_n_distr(producer) ); - t_engine.seed( seed_type(t) + seed_type(arma_rng::randi()) ); + mem[i] = std::complex(a,b); } - const uword chunk_size = N / n_threads; - - #pragma omp parallel for schedule(static) num_threads(int(n_threads)) - for(uword t=0; t < n_threads; ++t) + arma_rng::unlock_producer(); + } + #else + { + if(N == uword(1)) { - const uword start = (t+0) * chunk_size; - const uword endp1 = (t+1) * chunk_size; + T a = T(0); + T b = T(0); + + arma_rng_cxx03::randn_dual_val(a,b); - std::mt19937_64& t_engine = engine[t]; - std::normal_distribution& t_distr = distr[t]; + mem[0] = std::complex(a,b); - for(uword i=start; i < endp1; ++i) - { - const T val1 = T( t_distr(t_engine) ); - const T val2 = T( t_distr(t_engine) ); - - mem[i] = std::complex(val1, val2); - } + return; } - std::mt19937_64& t0_engine = engine[0]; - std::normal_distribution& t0_distr = distr[0]; + typedef typename std::mt19937_64::result_type local_seed_type; - for(uword i=(n_threads*chunk_size); i < N; ++i) + std::mt19937_64 local_engine; + std::normal_distribution local_n_distr; + + local_engine.seed( local_seed_type(std::rand()) ); + + for(uword i=0; i < N; ++i) { - const T val1 = T( t0_distr(t0_engine) ); - const T val2 = T( t0_distr(t0_engine) ); + const T a = T( local_n_distr(local_engine) ); + const T b = T( local_n_distr(local_engine) ); - mem[i] = std::complex(val1, val2); + mem[i] = std::complex(a,b); } } + #endif + } + + + inline + static + void + fill(std::complex* mem, const uword N, const double mu, const double sd) + { + arma_rng::randn< std::complex >::fill(mem, N); + + if( (mu == double(0)) && (sd == double(1)) ) { return; } + + for(uword i=0; i& val = mem[i]; + + mem[i] = std::complex( ((val.real() * sd) + mu), ((val.imag() * sd) + mu) ); + } + } + }; + + + +// + + + +template +struct arma_rng::randg + { + inline + static + void + fill(eT* mem, const uword N, const double a, const double b) + { + #if defined(ARMA_USE_CXX11_RNG) + { + std::gamma_distribution local_g_distr(a,b); + + std::mt19937_64& producer = arma_rng::get_producer(); + + arma_rng::lock_producer(); + + for(uword i=0; i >::fill_simple(mem, N); + typedef typename std::mt19937_64::result_type local_seed_type; + + std::mt19937_64 local_engine; + std::gamma_distribution local_g_distr(a,b); + + local_engine.seed( local_seed_type(arma_rng::randi()) ); + + for(uword i=0; i= double(1) ); + while( w >= double(1) ); return double( tmp1 * std::sqrt( (double(-2) * std::log(w)) / w) ); } @@ -114,7 +116,7 @@ arma_rng_cxx98::randn_val() template inline void -arma_rng_cxx98::randn_dual_val(eT& out1, eT& out2) +arma_rng_cxx03::randn_dual_val(eT& out1, eT& out2) { // make sure we are internally using at least floats typedef typename promote_type::result eTp; @@ -130,7 +132,7 @@ arma_rng_cxx98::randn_dual_val(eT& out1, eT& out2) w = tmp1*tmp1 + tmp2*tmp2; } - while ( w >= eTp(1) ); + while( w >= eTp(1) ); const eTp k = std::sqrt( (eTp(-2) * std::log(w)) / w); @@ -143,7 +145,7 @@ arma_rng_cxx98::randn_dual_val(eT& out1, eT& out2) template inline void -arma_rng_cxx98::randi_fill(eT* mem, const uword N, const int a, const int b) +arma_rng_cxx03::randi_fill(eT* mem, const uword N, const int a, const int b) { if( (a == 0) && (b == RAND_MAX) ) { @@ -169,7 +171,7 @@ arma_rng_cxx98::randi_fill(eT* mem, const uword N, const int a, const int b) inline int -arma_rng_cxx98::randi_max_val() +arma_rng_cxx03::randi_max_val() { #if (RAND_MAX == 32767) return ( (32767 << 15) + 32767); diff --git a/src/armadillo_bits/arma_rng_cxx11.hpp b/src/armadillo_bits/arma_rng_cxx11.hpp deleted file mode 100644 index d3cd8d84..00000000 --- a/src/armadillo_bits/arma_rng_cxx11.hpp +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -//! \addtogroup arma_rng_cxx11 -//! @{ - - -#if defined(ARMA_USE_CXX11) - - -class arma_rng_cxx11 - { - public: - - typedef std::mt19937_64::result_type seed_type; - - inline void set_seed(const seed_type val); - - arma_inline int randi_val(); - arma_inline double randu_val(); - arma_inline double randn_val(); - - template - arma_inline void randn_dual_val(eT& out1, eT& out2); - - template - inline void randi_fill(eT* mem, const uword N, const int a, const int b); - - inline static int randi_max_val(); - - template - inline void randg_fill_simple(eT* mem, const uword N, const double a, const double b); - - template - inline void randg_fill(eT* mem, const uword N, const double a, const double b); - - - private: - - arma_aligned std::mt19937_64 engine; // typedef for std::mersenne_twister_engine with preset parameters - - arma_aligned std::uniform_int_distribution i_distr; // by default uses a=0, b=std::numeric_limits::max() - - arma_aligned std::uniform_real_distribution u_distr; // by default uses [0,1) interval - - arma_aligned std::normal_distribution n_distr; // by default uses mean=0.0 and stddev=1.0 - }; - - - -inline -void -arma_rng_cxx11::set_seed(const arma_rng_cxx11::seed_type val) - { - engine.seed(val); - - i_distr.reset(); - u_distr.reset(); - n_distr.reset(); - } - - - -arma_inline -int -arma_rng_cxx11::randi_val() - { - return i_distr(engine); - } - - - -arma_inline -double -arma_rng_cxx11::randu_val() - { - return u_distr(engine); - } - - - -arma_inline -double -arma_rng_cxx11::randn_val() - { - return n_distr(engine); - } - - - -template -arma_inline -void -arma_rng_cxx11::randn_dual_val(eT& out1, eT& out2) - { - out1 = eT( n_distr(engine) ); - out2 = eT( n_distr(engine) ); - } - - - -template -inline -void -arma_rng_cxx11::randi_fill(eT* mem, const uword N, const int a, const int b) - { - std::uniform_int_distribution local_i_distr(a, b); - - for(uword i=0; i::max(); - } - - - -template -inline -void -arma_rng_cxx11::randg_fill_simple(eT* mem, const uword N, const double a, const double b) - { - std::gamma_distribution g_distr(a,b); - - for(uword i=0; i -inline -void -arma_rng_cxx11::randg_fill(eT* mem, const uword N, const double a, const double b) - { - #if defined(ARMA_USE_OPENMP) - { - if((N < 512) || omp_in_parallel()) { (*this).randg_fill_simple(mem, N, a, b); return; } - - typedef std::mt19937_64 motor_type; - typedef std::mt19937_64::result_type ovum_type; - typedef std::gamma_distribution distr_type; - - const uword n_threads = uword( mp_thread_limit::get() ); - - std::vector g_motor(n_threads); - std::vector g_distr(n_threads); - - const distr_type g_distr_base(a,b); - - for(uword t=0; t < n_threads; ++t) - { - motor_type& g_motor_t = g_motor[t]; - distr_type& g_distr_t = g_distr[t]; - - g_motor_t.seed( ovum_type(t) + ovum_type((*this).randi_val()) ); - - g_distr_t.param( g_distr_base.param() ); - } - - const uword chunk_size = N / n_threads; - - #pragma omp parallel for schedule(static) num_threads(int(n_threads)) - for(uword t=0; t < n_threads; ++t) - { - const uword start = (t+0) * chunk_size; - const uword endp1 = (t+1) * chunk_size; - - motor_type& g_motor_t = g_motor[t]; - distr_type& g_distr_t = g_distr[t]; - - for(uword i=start; i < endp1; ++i) { mem[i] = eT( g_distr_t(g_motor_t)); } - } - - motor_type& g_motor_0 = g_motor[0]; - distr_type& g_distr_0 = g_distr[0]; - - for(uword i=(n_threads*chunk_size); i < N; ++i) { mem[i] = eT( g_distr_0(g_motor_0)); } - } - #else - { - (*this).randg_fill_simple(mem, N, a, b); - } - #endif - } - - -#endif - - -//! @} diff --git a/src/armadillo_bits/arma_static_check.hpp b/src/armadillo_bits/arma_static_check.hpp index 7207a3e0..4368d46e 100644 --- a/src/armadillo_bits/arma_static_check.hpp +++ b/src/armadillo_bits/arma_static_check.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,49 +20,11 @@ //! @{ +#undef arma_static_check +#define arma_static_check(condition, message) static_assert( !(condition), message ) -template -struct arma_type_check_cxx1998 - { - arma_inline - static - void - apply() - { - static const char - junk[ ERROR___TYPE_MISMATCH_OR_UNSUPPORTED_TYPE ? -1 : +1 ]; - } - }; - - - -template<> -struct arma_type_check_cxx1998 - { - arma_inline - static - void - apply() - { - } - }; - - - -#if defined(ARMA_USE_CXX11) - - #define arma_static_check(condition, message) static_assert( !(condition), #message ) - - #define arma_type_check(condition) static_assert( !(condition), "error: type mismatch or unsupported type" ) - -#else - - #define arma_static_check(condition, message) static const char message[ (condition) ? -1 : +1 ] - - #define arma_type_check(condition) arma_type_check_cxx1998::apply() - -#endif - +#undef arma_type_check +#define arma_type_check(condition) static_assert( !(condition), "error: type mismatch or unsupported type" ) //! @} diff --git a/src/armadillo_bits/arma_str.hpp b/src/armadillo_bits/arma_str.hpp index d50a5dbd..29f84a6d 100644 --- a/src/armadillo_bits/arma_str.hpp +++ b/src/armadillo_bits/arma_str.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,54 +22,57 @@ namespace arma_str { - - #if ( defined(ARMA_USE_CXX11) || defined(ARMA_HAVE_SNPRINTF) ) + class char_buffer + { + public: - #define arma_snprintf std::snprintf + static constexpr uword n_chars_prealloc = 1024; - #else + char* mem = nullptr; + uword n_chars = 0; - // better-than-nothing emulation of C99 snprintf(), - // with correct return value and null-terminated output string. - // note that _snprintf() provided by MS is not a good substitute for snprintf() + char local_mem[n_chars_prealloc]; inline - int - arma_snprintf(char* out, size_t size, const char* fmt, ...) + ~char_buffer() { - size_t i; - - for(i=0; i n_chars_prealloc) { std::free(mem); } - if(size > 0) - out[size-1] = char(0); - - return int(i); + mem = nullptr; + n_chars = 0; } - #endif - - class format - { - public: - - format(const char* in_fmt) - : A(in_fmt) + inline + char_buffer() { + mem = &(local_mem[0]); + n_chars = n_chars_prealloc; + + if(n_chars > 0) { mem[0] = char(0); } } - format(const std::string& in_fmt) - : A(in_fmt) + inline + void + set_size(const uword new_n_chars) { + if(n_chars > n_chars_prealloc) { std::free(mem); } + + mem = (new_n_chars <= n_chars_prealloc) ? &(local_mem[0]) : (char*)std::malloc(new_n_chars); + n_chars = (new_n_chars <= n_chars_prealloc) ? n_chars_prealloc : new_n_chars; + + if(n_chars > 0) { mem[0] = char(0); } } + }; + + + class format + { + public: + const std::string fmt; - const std::string A; + inline format(const char* in_fmt) : fmt(in_fmt) { } + inline format(const std::string& in_fmt) : fmt(in_fmt) { } private: format(); @@ -80,15 +85,11 @@ namespace arma_str { public: - basic_format(const T1& in_A, const T2& in_B) - : A(in_A) - , B(in_B) - { - } - const T1& A; const T2& B; + inline basic_format(const T1& in_A, const T2& in_B) : A(in_A) , B(in_B) { } + private: basic_format(); }; @@ -120,49 +121,30 @@ namespace arma_str std::string str(const basic_format< format, T2>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - std::string out; + char_buffer buf; + + bool status = false; - do + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.c_str(), X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.fmt.c_str(), X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -174,49 +156,30 @@ namespace arma_str std::string str(const basic_format< basic_format< format, T2>, T3>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - + char_buffer buf; std::string out; - do + bool status = false; + + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.A.c_str(), X.A.B, X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.fmt.c_str(), X.A.B, X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -228,49 +191,30 @@ namespace arma_str std::string str(const basic_format< basic_format< basic_format< format, T2>, T3>, T4>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - + char_buffer buf; std::string out; - do + bool status = false; + + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.A.A.c_str(), X.A.A.B, X.A.B, X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.fmt.c_str(), X.A.A.B, X.A.B, X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -282,49 +226,30 @@ namespace arma_str std::string str(const basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - + char_buffer buf; std::string out; - do + bool status = false; + + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.A.A.A.c_str(), X.A.A.A.B, X.A.A.B, X.A.B, X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.fmt.c_str(), X.A.A.A.B, X.A.A.B, X.A.B, X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -336,49 +261,30 @@ namespace arma_str std::string str(const basic_format< basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>, T6>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - + char_buffer buf; std::string out; - do + bool status = false; + + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.A.A.A.A.c_str(), X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.A.fmt.c_str(), X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -390,49 +296,30 @@ namespace arma_str std::string str(const basic_format< basic_format< basic_format< basic_format< basic_format< basic_format< format, T2>, T3>, T4>, T5>, T6>, T7>& X) { - char local_buffer[1024]; - char* buffer = local_buffer; - - int buffer_size = 1024; - int required_size = buffer_size; - - bool using_local_buffer = true; - + char_buffer buf; std::string out; - do + bool status = false; + + while(status == false) { - if(using_local_buffer == false) - { - buffer = new char[size_t(buffer_size)]; - } - - required_size = arma_snprintf(buffer, size_t(buffer_size), X.A.A.A.A.A.A.A.c_str(), X.A.A.A.A.A.B, X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); + int required_size = (std::snprintf)(buf.mem, size_t(buf.n_chars), X.A.A.A.A.A.A.fmt.c_str(), X.A.A.A.A.A.B, X.A.A.A.A.B, X.A.A.A.B, X.A.A.B, X.A.B, X.B); if(required_size < 0) { break; } - if(required_size < buffer_size) + if(uword(required_size) >= buf.n_chars) { - if(required_size > 0) - { - out = buffer; - } + if(buf.n_chars > char_buffer::n_chars_prealloc) { break; } + + buf.set_size(1 + uword(required_size)); } else { - buffer_size *= 2; + status = true; } - if(using_local_buffer) - { - using_local_buffer = false; - } - else - { - delete[] buffer; - } - - } while( (required_size >= buffer_size) ); + if(status) { out = buf.mem; } + } return out; } @@ -442,7 +329,7 @@ namespace arma_str template struct format_metaprog { - static const uword depth = 0; + static constexpr uword depth = 0; inline static @@ -459,7 +346,7 @@ namespace arma_str template struct format_metaprog< basic_format > { - static const uword depth = 1 + format_metaprog::depth; + static constexpr uword depth = 1 + format_metaprog::depth; inline static @@ -511,7 +398,7 @@ namespace arma_str inline static const T1& - str_wrapper(const T1& x, const typename string_only::result* junk = 0) + str_wrapper(const T1& x, const typename string_only::result* junk = nullptr) { arma_ignore(junk); @@ -524,7 +411,7 @@ namespace arma_str inline static const T1* - str_wrapper(const T1* x, const typename char_only::result* junk = 0) + str_wrapper(const T1* x, const typename char_only::result* junk = nullptr) { arma_ignore(junk); @@ -537,7 +424,7 @@ namespace arma_str inline static std::string - str_wrapper(const T1& x, const typename basic_format_only::result* junk = 0) + str_wrapper(const T1& x, const typename basic_format_only::result* junk = nullptr) { arma_ignore(junk); diff --git a/src/armadillo_bits/arma_version.hpp b/src/armadillo_bits/arma_version.hpp index d9fbbf6a..a335bb3b 100644 --- a/src/armadillo_bits/arma_version.hpp +++ b/src/armadillo_bits/arma_version.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,18 +21,18 @@ -#define ARMA_VERSION_MAJOR 9 -#define ARMA_VERSION_MINOR 800 -#define ARMA_VERSION_PATCH 1 -#define ARMA_VERSION_NAME "Horizon Scraper" +#define ARMA_VERSION_MAJOR 12 +#define ARMA_VERSION_MINOR 6 +#define ARMA_VERSION_PATCH 7 +#define ARMA_VERSION_NAME "Cortisol Retox" struct arma_version { - static const unsigned int major = ARMA_VERSION_MAJOR; - static const unsigned int minor = ARMA_VERSION_MINOR; - static const unsigned int patch = ARMA_VERSION_PATCH; + static constexpr unsigned int major = ARMA_VERSION_MAJOR; + static constexpr unsigned int minor = ARMA_VERSION_MINOR; + static constexpr unsigned int patch = ARMA_VERSION_PATCH; static inline diff --git a/src/armadillo_bits/arrayops_bones.hpp b/src/armadillo_bits/arrayops_bones.hpp index cc14e4dd..0beec3ae 100644 --- a/src/armadillo_bits/arrayops_bones.hpp +++ b/src/armadillo_bits/arrayops_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,48 +25,47 @@ class arrayops public: template - arma_hot arma_inline static void + arma_inline static void copy(eT* dest, const eT* src, const uword n_elem); - - template - arma_cold inline static void - copy_small(eT* dest, const eT* src, const uword n_elem); - - template - arma_hot inline static void + inline static void fill_zeros(eT* dest, const uword n_elem); - template arma_hot inline static void replace(eT* mem, const uword n_elem, const eT old_val, const eT new_val); - template arma_hot inline static void - clean(eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk = 0); - + clean(eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk = nullptr); template arma_hot inline static void clean(std::complex* mem, const uword n_elem, const T abs_limit); + template + inline static void + clamp(eT* mem, const uword n_elem, const eT min_val, const eT max_val, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void + clamp(std::complex* mem, const uword n_elem, const std::complex& min_val, const std::complex& max_val); + // // array = convert(array) template - arma_hot arma_inline static void - convert_cx_scalar(out_eT& out, const in_eT& in, const typename arma_not_cx::result* junk1 = 0, const typename arma_not_cx< in_eT>::result* junk2 = 0); + arma_inline static void + convert_cx_scalar(out_eT& out, const in_eT& in, const typename arma_not_cx::result* junk1 = nullptr, const typename arma_not_cx< in_eT>::result* junk2 = nullptr); template - arma_hot arma_inline static void - convert_cx_scalar(out_eT& out, const std::complex& in, const typename arma_not_cx::result* junk = 0); + arma_inline static void + convert_cx_scalar(out_eT& out, const std::complex& in, const typename arma_not_cx::result* junk = nullptr); template - arma_hot arma_inline static void + arma_inline static void convert_cx_scalar(std::complex& out, const std::complex< in_T>& in); template @@ -132,12 +133,12 @@ class arrayops template arma_hot inline static void - inplace_set_base(eT* dest, const eT val, const uword n_elem); + inplace_set_simple(eT* dest, const eT val, const uword n_elem); template - arma_cold inline static + arma_hot inline static void - inplace_set_small(eT* dest, const eT val, const uword n_elem); + inplace_set_base(eT* dest, const eT val, const uword n_elem); template arma_hot inline static @@ -197,6 +198,16 @@ class arrayops eT product(const eT* src, const uword n_elem); + template + arma_hot inline static + bool + is_zero(const eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk = nullptr); + + template + arma_hot inline static + bool + is_zero(const std::complex* mem, const uword n_elem, const T abs_limit); + template arma_hot inline static bool diff --git a/src/armadillo_bits/arrayops_meat.hpp b/src/armadillo_bits/arrayops_meat.hpp index b5b8b95d..57f1a1d4 100644 --- a/src/armadillo_bits/arrayops_meat.hpp +++ b/src/armadillo_bits/arrayops_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,75 +22,39 @@ template -arma_hot arma_inline void arrayops::copy(eT* dest, const eT* src, const uword n_elem) { - if(is_cx::no) - { - if(n_elem <= 9) - { - arrayops::copy_small(dest, src, n_elem); - } - else - { - std::memcpy(dest, src, n_elem*sizeof(eT)); - } - } - else - { - if(n_elem > 0) { std::memcpy(dest, src, n_elem*sizeof(eT)); } - } - } - - - -template -arma_cold -inline -void -arrayops::copy_small(eT* dest, const eT* src, const uword n_elem) - { - switch(n_elem) - { - case 9: dest[ 8] = src[ 8]; - // fallthrough - case 8: dest[ 7] = src[ 7]; - // fallthrough - case 7: dest[ 6] = src[ 6]; - // fallthrough - case 6: dest[ 5] = src[ 5]; - // fallthrough - case 5: dest[ 4] = src[ 4]; - // fallthrough - case 4: dest[ 3] = src[ 3]; - // fallthrough - case 3: dest[ 2] = src[ 2]; - // fallthrough - case 2: dest[ 1] = src[ 1]; - // fallthrough - case 1: dest[ 0] = src[ 0]; - // fallthrough - default: ; - } + if( (dest == src) || (n_elem == 0) ) { return; } + + std::memcpy(dest, src, n_elem*sizeof(eT)); } template -arma_hot inline void arrayops::fill_zeros(eT* dest, const uword n_elem) { - arrayops::inplace_set(dest, eT(0), n_elem); + typedef typename get_pod_type::result pod_type; + + if(n_elem == 0) { return; } + + if(std::numeric_limits::is_integer || std::numeric_limits::is_iec559) + { + std::memset((void*)dest, 0, sizeof(eT)*n_elem); + } + else + { + arrayops::inplace_set_simple(dest, eT(0), n_elem); + } } template -arma_hot inline void arrayops::replace(eT* mem, const uword n_elem, const eT old_val, const eT new_val) @@ -116,7 +82,6 @@ arrayops::replace(eT* mem, const uword n_elem, const eT old_val, const eT new_va template -arma_hot inline void arrayops::clean(eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk) @@ -127,14 +92,13 @@ arrayops::clean(eT* mem, const uword n_elem, const eT abs_limit, const typename { eT& val = mem[i]; - val = (std::abs(val) <= abs_limit) ? eT(0) : val; + val = (eop_aux::arma_abs(val) <= abs_limit) ? eT(0) : val; } } template -arma_hot inline void arrayops::clean(std::complex* mem, const uword n_elem, const T abs_limit) @@ -164,8 +128,53 @@ arrayops::clean(std::complex* mem, const uword n_elem, const T abs_limit) +template +inline +void +arrayops::clamp(eT* mem, const uword n_elem, const eT min_val, const eT max_val, const typename arma_not_cx::result* junk) + { + arma_ignore(junk); + + for(uword i=0; i max_val) ? max_val : val); + } + } + + + +template +inline +void +arrayops::clamp(std::complex* mem, const uword n_elem, const std::complex& min_val, const std::complex& max_val) + { + typedef typename std::complex eT; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + val = std::complex(val_real,val_imag); + } + } + + + template -arma_hot arma_inline void arrayops::convert_cx_scalar @@ -185,7 +194,6 @@ arrayops::convert_cx_scalar template -arma_hot arma_inline void arrayops::convert_cx_scalar @@ -197,13 +205,16 @@ arrayops::convert_cx_scalar { arma_ignore(junk); - out = out_eT( in.real() ); + const in_T val = in.real(); + + const bool conversion_ok = (std::is_integral::value && std::is_floating_point::value) ? arma_isfinite(val) : true; + + out = conversion_ok ? out_eT(val) : out_eT(0); } template -arma_hot arma_inline void arrayops::convert_cx_scalar @@ -220,7 +231,6 @@ arrayops::convert_cx_scalar template -arma_hot inline void arrayops::convert(out_eT* dest, const in_eT* src, const uword n_elem) @@ -234,6 +244,7 @@ arrayops::convert(out_eT* dest, const in_eT* src, const uword n_elem) return; } + const bool check_finite = (std::is_integral::value && std::is_floating_point::value); uword j; @@ -245,15 +256,26 @@ arrayops::convert(out_eT* dest, const in_eT* src, const uword n_elem) // dest[i] = out_eT( tmp_i ); // dest[j] = out_eT( tmp_j ); - (*dest) = (is_signed::value) - ? out_eT( tmp_i ) - : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ); + const bool ok_i = check_finite ? arma_isfinite(tmp_i) : true; + const bool ok_j = check_finite ? arma_isfinite(tmp_j) : true; + + (*dest) = ok_i + ? ( + (is_signed::value) + ? out_eT( tmp_i ) + : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ) + ) + : out_eT(0); dest++; - (*dest) = (is_signed::value) - ? out_eT( tmp_j ) - : ( cond_rel< is_signed::value >::lt(tmp_j, in_eT(0)) ? out_eT(0) : out_eT(tmp_j) ); + (*dest) = ok_j + ? ( + (is_signed::value) + ? out_eT( tmp_j ) + : ( cond_rel< is_signed::value >::lt(tmp_j, in_eT(0)) ? out_eT(0) : out_eT(tmp_j) ) + ) + : out_eT(0); dest++; } @@ -263,16 +285,21 @@ arrayops::convert(out_eT* dest, const in_eT* src, const uword n_elem) // dest[i] = out_eT( tmp_i ); - (*dest) = (is_signed::value) - ? out_eT( tmp_i ) - : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ); + const bool ok_i = check_finite ? arma_isfinite(tmp_i) : true; + + (*dest) = ok_i + ? ( + (is_signed::value) + ? out_eT( tmp_i ) + : ( cond_rel< is_signed::value >::lt(tmp_i, in_eT(0)) ? out_eT(0) : out_eT(tmp_i) ) + ) + : out_eT(0); } } template -arma_hot inline void arrayops::convert_cx(out_eT* dest, const in_eT* src, const uword n_elem) @@ -294,7 +321,6 @@ arrayops::convert_cx(out_eT* dest, const in_eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_plus(eT* dest, const eT* src, const uword n_elem) @@ -332,7 +358,6 @@ arrayops::inplace_plus(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_minus(eT* dest, const eT* src, const uword n_elem) @@ -370,7 +395,6 @@ arrayops::inplace_minus(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_mul(eT* dest, const eT* src, const uword n_elem) @@ -408,7 +432,6 @@ arrayops::inplace_mul(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_div(eT* dest, const eT* src, const uword n_elem) @@ -446,7 +469,6 @@ arrayops::inplace_div(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_plus_base(eT* dest, const eT* src, const uword n_elem) @@ -482,7 +504,6 @@ arrayops::inplace_plus_base(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_minus_base(eT* dest, const eT* src, const uword n_elem) @@ -518,7 +539,6 @@ arrayops::inplace_minus_base(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_mul_base(eT* dest, const eT* src, const uword n_elem) @@ -554,7 +574,6 @@ arrayops::inplace_mul_base(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_div_base(eT* dest, const eT* src, const uword n_elem) @@ -590,43 +609,42 @@ arrayops::inplace_div_base(eT* dest, const eT* src, const uword n_elem) template -arma_hot inline void arrayops::inplace_set(eT* dest, const eT val, const uword n_elem) { - typedef typename get_pod_type::result pod_type; - - if( (n_elem <= 9) && (is_cx::no) ) + if(val == eT(0)) { - arrayops::inplace_set_small(dest, val, n_elem); + arrayops::fill_zeros(dest, n_elem); } else { - if( (val == eT(0)) && (std::numeric_limits::is_integer || (std::numeric_limits::is_iec559 && is_real::value)) ) - { - if(n_elem > 0) { std::memset((void*)dest, 0, sizeof(eT)*n_elem); } - } - else - { - if(memory::is_aligned(dest)) - { - memory::mark_as_aligned(dest); - - arrayops::inplace_set_base(dest, val, n_elem); - } - else - { - arrayops::inplace_set_base(dest, val, n_elem); - } - } + arrayops::inplace_set_simple(dest, val, n_elem); + } + } + + + +template +inline +void +arrayops::inplace_set_simple(eT* dest, const eT val, const uword n_elem) + { + if(memory::is_aligned(dest)) + { + memory::mark_as_aligned(dest); + + arrayops::inplace_set_base(dest, val, n_elem); + } + else + { + arrayops::inplace_set_base(dest, val, n_elem); } } template -arma_hot inline void arrayops::inplace_set_base(eT* dest, const eT val, const uword n_elem) @@ -658,40 +676,7 @@ arrayops::inplace_set_base(eT* dest, const eT val, const uword n_elem) -template -arma_cold -inline -void -arrayops::inplace_set_small(eT* dest, const eT val, const uword n_elem) - { - switch(n_elem) - { - case 9: dest[ 8] = val; - // fallthrough - case 8: dest[ 7] = val; - // fallthrough - case 7: dest[ 6] = val; - // fallthrough - case 6: dest[ 5] = val; - // fallthrough - case 5: dest[ 4] = val; - // fallthrough - case 4: dest[ 3] = val; - // fallthrough - case 3: dest[ 2] = val; - // fallthrough - case 2: dest[ 1] = val; - // fallthrough - case 1: dest[ 0] = val; - // fallthrough - default:; - } - } - - - template -arma_hot inline void arrayops::inplace_set_fixed(eT* dest, const eT val) @@ -705,7 +690,6 @@ arrayops::inplace_set_fixed(eT* dest, const eT val) template -arma_hot inline void arrayops::inplace_plus(eT* dest, const eT val, const uword n_elem) @@ -725,7 +709,6 @@ arrayops::inplace_plus(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_minus(eT* dest, const eT val, const uword n_elem) @@ -745,7 +728,6 @@ arrayops::inplace_minus(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_mul(eT* dest, const eT val, const uword n_elem) @@ -765,7 +747,6 @@ arrayops::inplace_mul(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_div(eT* dest, const eT val, const uword n_elem) @@ -785,7 +766,6 @@ arrayops::inplace_div(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_plus_base(eT* dest, const eT val, const uword n_elem) @@ -818,7 +798,6 @@ arrayops::inplace_plus_base(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_minus_base(eT* dest, const eT val, const uword n_elem) @@ -851,7 +830,6 @@ arrayops::inplace_minus_base(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_mul_base(eT* dest, const eT val, const uword n_elem) @@ -884,7 +862,6 @@ arrayops::inplace_mul_base(eT* dest, const eT val, const uword n_elem) template -arma_hot inline void arrayops::inplace_div_base(eT* dest, const eT val, const uword n_elem) @@ -917,12 +894,11 @@ arrayops::inplace_div_base(eT* dest, const eT val, const uword n_elem) template -arma_hot inline eT arrayops::accumulate(const eT* src, const uword n_elem) { - #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0) + #if defined(__FAST_MATH__) { eT acc = eT(0); @@ -964,7 +940,6 @@ arrayops::accumulate(const eT* src, const uword n_elem) template -arma_hot inline eT arrayops::product(const eT* src, const uword n_elem) @@ -991,7 +966,70 @@ arrayops::product(const eT* src, const uword n_elem) template -arma_hot +inline +bool +arrayops::is_zero(const eT* mem, const uword n_elem, const eT abs_limit, const typename arma_not_cx::result* junk) + { + arma_ignore(junk); + + if(n_elem == 0) { return false; } + + if(abs_limit == eT(0)) + { + for(uword i=0; i abs_limit) { return false; } + } + } + + return true; + } + + + +template +inline +bool +arrayops::is_zero(const std::complex* mem, const uword n_elem, const T abs_limit) + { + typedef typename std::complex eT; + + if(n_elem == 0) { return false; } + + if(abs_limit == T(0)) + { + for(uword i=0; i abs_limit) { return false; } + if(std::abs(std::imag(val)) > abs_limit) { return false; } + } + } + + return true; + } + + + +template inline bool arrayops::is_finite(const eT* src, const uword n_elem) @@ -1003,18 +1041,13 @@ arrayops::is_finite(const eT* src, const uword n_elem) const eT val_i = (*src); src++; const eT val_j = (*src); src++; - if( (arma_isfinite(val_i) == false) || (arma_isfinite(val_j) == false) ) - { - return false; - } + if(arma_isfinite(val_i) == false) { return false; } + if(arma_isfinite(val_j) == false) { return false; } } if((j-1) < n_elem) { - if(arma_isfinite(*src) == false) - { - return false; - } + if(arma_isfinite(*src) == false) { return false; } } return true; @@ -1023,7 +1056,6 @@ arrayops::is_finite(const eT* src, const uword n_elem) template -arma_hot inline bool arrayops::has_inf(const eT* src, const uword n_elem) @@ -1049,7 +1081,6 @@ arrayops::has_inf(const eT* src, const uword n_elem) template -arma_hot inline bool arrayops::has_nan(const eT* src, const uword n_elem) diff --git a/src/armadillo_bits/auxlib_bones.hpp b/src/armadillo_bits/auxlib_bones.hpp index fa352333..63292dd9 100644 --- a/src/armadillo_bits/auxlib_bones.hpp +++ b/src/armadillo_bits/auxlib_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,59 +20,53 @@ //! @{ -//! interface functions for accessing decompositions in LAPACK and ATLAS +//! low-level interface functions for accessing LAPACK class auxlib { public: + // + // inv - template - struct pos - { - static const uword n2 = row + col*2; - static const uword n3 = row + col*3; - static const uword n4 = row + col*4; - }; + template + inline static bool inv(Mat& A); + template + inline static bool inv(Mat& out, const Mat& X); - // - // inv + template + inline static bool inv_rcond(Mat& A, typename get_pod_type::result& out_rcond); template - inline static bool inv(Mat& out, const Mat& A); + inline static bool inv_tr(Mat& A, const uword layout); template - arma_cold inline static bool inv_tiny(Mat& out, const Mat& X); + inline static bool inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, const uword layout); - template - inline static bool inv_tr(Mat& out, const Base& X, const uword layout); + template + inline static bool inv_sympd(Mat& A, bool& out_sympd_state); - template - inline static bool inv_sympd(Mat& out, const Base& X); + template + inline static bool inv_sympd(Mat& out, const Mat& X); template - arma_cold inline static bool inv_sympd_tiny(Mat& out, const Mat& X); + inline static bool inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond); + template + inline static bool inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond); // - // det - - template - inline static eT det(const Base& X); + // det and log_det template - arma_cold inline static eT det_tinymat(const Mat& X, const uword N); + inline static bool det(eT& out_val, Mat& A); template - inline static eT det_lapack(const Mat& X, const bool make_copy); - - - // - // log_det + inline static bool log_det(eT& out_val, typename get_pod_type::result& out_sign, Mat& A); - template - inline static bool log_det(eT& out_val, typename get_pod_type::result& out_sign, const Base& X); + template + inline static bool log_det_sympd(typename get_pod_type::result& out_val, Mat& A); // @@ -106,6 +102,26 @@ class auxlib inline static bool eig_gen_balance(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base< std::complex, T1 >& expr); + // + // eig_gen_twosided + + template + inline static bool eig_gen_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& expr); + + template + inline static bool eig_gen_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& expr); + + + // + // eig_gen_twosided_balance + + template + inline static bool eig_gen_twosided_balance(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& expr); + + template + inline static bool eig_gen_twosided_balance(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& expr); + + // // eig_pair @@ -116,14 +132,24 @@ class auxlib inline static bool eig_pair(Mat< std::complex >& vals, Mat< std::complex >& vecs, const bool vecs_on, const Base< std::complex, T1 >& A_expr, const Base< std::complex, T2 >& B_expr); + // + // eig_pair_twosided + + template + inline static bool eig_pair_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base& A_expr, const Base& B_expr); + + template + inline static bool eig_pair_twosided(Mat< std::complex >& vals, Mat< std::complex >& lvecs, Mat< std::complex >& rvecs, const Base< std::complex, T1 >& A_expr, const Base< std::complex, T2 >& B_expr); + + // // eig_sym - template - inline static bool eig_sym(Col& eigval, const Base& X); + template + inline static bool eig_sym(Col& eigval, Mat& A); - template - inline static bool eig_sym(Col& eigval, const Base,T1>& X); + template + inline static bool eig_sym(Col& eigval, Mat< std::complex >& A); template inline static bool eig_sym(Col& eigval, Mat& eigvec, const Mat& X); @@ -156,14 +182,17 @@ class auxlib template inline static bool chol_band_common(Mat& X, const uword KD, const uword layout); - + template + inline static bool chol_pivot(Mat& X, Mat& P, const uword layout); + + // // hessenberg decomposition - + template inline static bool hess(Mat& H, const Base& X, Col& tao); - - + + // // qr @@ -173,78 +202,70 @@ class auxlib template inline static bool qr_econ(Mat& Q, Mat& R, const Base& X); - - // - // svd - template - inline static bool svd(Col& S, const Base& X, uword& n_rows, uword& n_cols); + inline static bool qr_pivot(Mat& Q, Mat& R, Mat& P, const Base& X); - template - inline static bool svd(Col& S, const Base, T1>& X, uword& n_rows, uword& n_cols); + template + inline static bool qr_pivot(Mat< std::complex >& Q, Mat< std::complex >& R, Mat& P, const Base,T1>& X); - template - inline static bool svd(Col& S, const Base& X); - template - inline static bool svd(Col& S, const Base, T1>& X); + // + // svd - template - inline static bool svd(Mat& U, Col& S, Mat& V, const Base& X); + template + inline static bool svd(Col& S, Mat& A); - template - inline static bool svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X); + template + inline static bool svd(Col& S, Mat< std::complex >& A); - template - inline static bool svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const char mode); - template - inline static bool svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X, const char mode); + template + inline static bool svd(Mat& U, Col& S, Mat& V, Mat& A); + template + inline static bool svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); - template - inline static bool svd_dc(Col& S, const Base& X, uword& n_rows, uword& n_cols); + template + inline static bool svd_econ(Mat& U, Col& S, Mat& V, Mat& A, const char mode); - template - inline static bool svd_dc(Col& S, const Base, T1>& X, uword& n_rows, uword& n_cols); + template + inline static bool svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A, const char mode); - template - inline static bool svd_dc(Col& S, const Base& X); - template - inline static bool svd_dc(Col& S, const Base, T1>& X); + template + inline static bool svd_dc(Col& S, Mat& A); + template + inline static bool svd_dc(Col& S, Mat< std::complex >& A); - template - inline static bool svd_dc(Mat& U, Col& S, Mat& V, const Base& X); - template - inline static bool svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X); + template + inline static bool svd_dc(Mat& U, Col& S, Mat& V, Mat& A); - template - inline static bool svd_dc_econ(Mat& U, Col& S, Mat& V, const Base& X); + template + inline static bool svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); - template - inline static bool svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X); + template + inline static bool svd_dc_econ(Mat& U, Col& S, Mat& V, Mat& A); + + template + inline static bool svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A); // // solve - template - arma_cold inline static bool solve_square_tiny(Mat& out, const Mat& A, const Base& B_expr); - template inline static bool solve_square_fast(Mat& out, Mat& A, const Base& B_expr); template - inline static bool solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool allow_ugly); + inline static bool solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); template - inline static bool solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate); template - inline static bool solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate); // @@ -255,21 +276,26 @@ class auxlib inline static bool solve_sympd_fast_common(Mat& out, Mat& A, const Base& B_expr); template - inline static bool solve_sympd_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool allow_ugly); + inline static bool solve_sympd_rcond(Mat& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); template - inline static bool solve_sympd_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr, const bool allow_ugly); + inline static bool solve_sympd_rcond(Mat< std::complex >& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr); template - inline static bool solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate); template - inline static bool solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate); // template - inline static bool solve_approx_fast(Mat& out, Mat& A, const Base& B_expr); + inline static bool solve_rect_fast(Mat& out, Mat& A, const Base& B_expr); + + template + inline static bool solve_rect_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr); + + // template inline static bool solve_approx_svd(Mat& out, Mat& A, const Base& B_expr); @@ -283,7 +309,7 @@ class auxlib inline static bool solve_trimat_fast(Mat& out, const Mat& A, const Base& B_expr, const uword layout); template - inline static bool solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout, const bool allow_ugly); + inline static bool solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout); // @@ -297,19 +323,19 @@ class auxlib inline static bool solve_band_fast_common(Mat& out, const Mat& A, const uword KL, const uword KU, const Base& B_expr); template - inline static bool solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool allow_ugly); + inline static bool solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr); template - inline static bool solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr, const bool allow_ugly); + inline static bool solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr); template - inline static bool solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool allow_ugly); + inline static bool solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr); template - inline static bool solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate); template - inline static bool solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly); + inline static bool solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate); // @@ -330,13 +356,13 @@ class auxlib inline static bool schur(Mat& U, Mat& S, const Base& X, const bool calc_U = true); template - inline static bool schur(Mat >& U, Mat >& S, const Base,T1>& X, const bool calc_U = true); + inline static bool schur(Mat< std::complex >& U, Mat< std::complex >& S, const Base,T1>& X, const bool calc_U = true); template - inline static bool schur(Mat >& U, Mat >& S, const bool calc_U = true); + inline static bool schur(Mat< std::complex >& U, Mat< std::complex >& S, const bool calc_U = true); // - // syl (solution of the Sylvester equation AX + XB = C) + // solve the Sylvester equation AX + XB = C template inline static bool syl(Mat& X, const Mat& A, const Mat& B, const Mat& C); @@ -402,14 +428,20 @@ class auxlib template inline static bool crippled_lapack(const Base&); - template - inline static typename T1::pod_type epsilon_lapack(const Base&); - template inline static bool rudimentary_sym_check(const Mat& X); template inline static bool rudimentary_sym_check(const Mat< std::complex >& X); + + template + inline static typename get_pod_type::result norm1_gen(const Mat& A); + + template + inline static typename get_pod_type::result norm1_sym(const Mat& A); + + template + inline static typename get_pod_type::result norm1_band(const Mat& A, const uword KL, const uword KU); }; diff --git a/src/armadillo_bits/auxlib_meat.hpp b/src/armadillo_bits/auxlib_meat.hpp index 9d9edc74..373aaa4e 100644 --- a/src/armadillo_bits/auxlib_meat.hpp +++ b/src/armadillo_bits/auxlib_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,49 +24,35 @@ template inline bool -auxlib::inv(Mat& out, const Mat& A) +auxlib::inv(Mat& A) { arma_extra_debug_sigprint(); - out = A; - - if(out.is_empty()) { return true; } + if(A.is_empty()) { return true; } - #if defined(ARMA_USE_ATLAS) + #if defined(ARMA_USE_LAPACK) { - arma_debug_assert_atlas_size(out); + arma_debug_assert_blas_size(A); - podarray ipiv(out.n_rows); + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; - int info = 0; + podarray ipiv(A.n_rows); - arma_extra_debug_print("atlas::clapack_getrf()"); - info = atlas::clapack_getrf(atlas::CblasColMajor, out.n_rows, out.n_cols, out.memptr(), out.n_rows, ipiv.memptr()); + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n, &n, A.memptr(), &lda, ipiv.memptr(), &info); if(info != 0) { return false; } - arma_extra_debug_print("atlas::clapack_getri()"); - info = atlas::clapack_getri(atlas::CblasColMajor, out.n_rows, out.memptr(), out.n_rows, ipiv.memptr()); - - return (info == 0); - } - #elif defined(ARMA_USE_LAPACK) - { - arma_debug_assert_blas_size(out); - - blas_int n_rows = blas_int(out.n_rows); - blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n_rows); - blas_int info = 0; - - podarray ipiv(out.n_rows); - - if(n_rows > 16) + if(n > 16) { - eT work_query[2]; - blas_int lwork_query = -1; + eT work_query[2] = {}; + blas_int lwork_query = -1; arma_extra_debug_print("lapack::getri()"); - lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), &work_query[0], &lwork_query, &info); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); if(info != 0) { return false; } @@ -75,20 +63,15 @@ auxlib::inv(Mat& out, const Mat& A) podarray work( static_cast(lwork) ); - arma_extra_debug_print("lapack::getrf()"); - lapack::getrf(&n_rows, &n_rows, out.memptr(), &n_rows, ipiv.memptr(), &info); - - if(info != 0) { return false; } - arma_extra_debug_print("lapack::getri()"); - lapack::getri(&n_rows, out.memptr(), &n_rows, ipiv.memptr(), work.memptr(), &lwork, &info); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); return (info == 0); } #else { - out.soft_reset(); - arma_stop_logic_error("inv(): use of ATLAS or LAPACK must be enabled"); + arma_ignore(A); + arma_stop_logic_error("inv(): use of LAPACK must be enabled"); return false; } #endif @@ -97,174 +80,117 @@ auxlib::inv(Mat& out, const Mat& A) template -arma_cold inline bool -auxlib::inv_tiny(Mat& out, const Mat& X) +auxlib::inv(Mat& out, const Mat& X) { arma_extra_debug_sigprint(); - const uword N = X.n_rows; + out = X; - out.set_size(N,N); + return auxlib::inv(out); + } + + + +template +inline +bool +auxlib::inv_rcond(Mat& A, typename get_pod_type::result& out_rcond) + { + arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; - const T det_min = std::numeric_limits::epsilon(); + out_rcond = T(0); - bool calc_ok = false; + if(A.is_empty()) { return true; } - const eT* Xm = X.memptr(); - eT* outm = out.memptr(); - - switch(N) + #if defined(ARMA_USE_LAPACK) { - case 0: - calc_ok = true; - break; + arma_debug_assert_blas_size(A); + + char norm_id = '1'; + blas_int n = blas_int(A.n_rows); + blas_int lda = blas_int(A.n_rows); + blas_int lwork = (std::max)(blas_int(podarray_prealloc_n_elem::val), n); + blas_int info = 0; + T norm_val = T(0); + + podarray junk(1); + podarray ipiv(A.n_rows); + + arma_extra_debug_print("lapack::lange()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &n, &n, A.memptr(), &lda, junk.memptr()); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n, &n, A.memptr(), &lda, ipiv.memptr(), &info); + + if(info != 0) { return false; } + + out_rcond = auxlib::lu_rcond(A, norm_val); - case 1: + if(n > 16) { - outm[0] = eT(1) / Xm[0]; + eT work_query[2] = {}; + blas_int lwork_query = -1; - calc_ok = true; - }; - break; + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), &work_query[0], &lwork_query, &info); - case 2: - { - const eT a = Xm[pos<0,0>::n2]; - const eT b = Xm[pos<0,1>::n2]; - const eT c = Xm[pos<1,0>::n2]; - const eT d = Xm[pos<1,1>::n2]; + if(info != 0) { return false; } - const eT det_val = (a*d - b*c); + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); - if(std::abs(det_val) >= det_min) - { - outm[pos<0,0>::n2] = d / det_val; - outm[pos<0,1>::n2] = -b / det_val; - outm[pos<1,0>::n2] = -c / det_val; - outm[pos<1,1>::n2] = a / det_val; - - calc_ok = true; - } - }; - break; + lwork = (std::max)(lwork_proposed, lwork); + } - case 3: - { - const eT det_val = auxlib::det_tinymat(X,3); - - if(std::abs(det_val) >= det_min) - { - outm[pos<0,0>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]) / det_val; - outm[pos<1,0>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,2>::n3]) / det_val; - outm[pos<2,0>::n3] = (Xm[pos<2,1>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,1>::n3]) / det_val; - - outm[pos<0,1>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]) / det_val; - outm[pos<1,1>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,2>::n3]) / det_val; - outm[pos<2,1>::n3] = -(Xm[pos<2,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,1>::n3]) / det_val; - - outm[pos<0,2>::n3] = (Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]) / det_val; - outm[pos<1,2>::n3] = -(Xm[pos<1,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,2>::n3]) / det_val; - outm[pos<2,2>::n3] = (Xm[pos<1,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,1>::n3]) / det_val; - - const eT check_val = Xm[pos<0,0>::n3]*outm[pos<0,0>::n3] + Xm[pos<0,1>::n3]*outm[pos<1,0>::n3] + Xm[pos<0,2>::n3]*outm[pos<2,0>::n3]; - - const T max_diff = (is_float::value) ? T(1e-4) : T(1e-10); // empirically determined; may need tuning - - if(std::abs(T(1) - check_val) < max_diff) { calc_ok = true; } - } - }; - break; + podarray work( static_cast(lwork) ); - case 4: - { - const eT det_val = auxlib::det_tinymat(X,4); - - if(std::abs(det_val) >= det_min) - { - outm[pos<0,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<1,0>::n4] = ( Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<2,0>::n4] = ( Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<3,0>::n4] = ( Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / det_val; - - outm[pos<0,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<1,1>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<2,1>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,3>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<3,1>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<2,2>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<2,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<2,1>::n4]*Xm[pos<3,2>::n4] ) / det_val; - - outm[pos<0,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<1,2>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<2,2>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<3,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,3>::n4] ) / det_val; - outm[pos<3,2>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<3,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<3,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<3,2>::n4] ) / det_val; - - outm[pos<0,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / det_val; - outm[pos<1,3>::n4] = ( Xm[pos<0,2>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,2>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,3>::n4] ) / det_val; - outm[pos<2,3>::n4] = ( Xm[pos<0,3>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,3>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,3>::n4]*Xm[pos<2,1>::n4] + Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,3>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,3>::n4] ) / det_val; - outm[pos<3,3>::n4] = ( Xm[pos<0,1>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,0>::n4] - Xm[pos<0,2>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,0>::n4] + Xm[pos<0,2>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,0>::n4]*Xm[pos<1,2>::n4]*Xm[pos<2,1>::n4] - Xm[pos<0,1>::n4]*Xm[pos<1,0>::n4]*Xm[pos<2,2>::n4] + Xm[pos<0,0>::n4]*Xm[pos<1,1>::n4]*Xm[pos<2,2>::n4] ) / det_val; - - const eT check_val = Xm[pos<0,0>::n4]*outm[pos<0,0>::n4] + Xm[pos<0,1>::n4]*outm[pos<1,0>::n4] + Xm[pos<0,2>::n4]*outm[pos<2,0>::n4] + Xm[pos<0,3>::n4]*outm[pos<3,0>::n4]; - - const T max_diff = (is_float::value) ? T(1e-4) : T(1e-10); // empirically determined; may need tuning - - if(std::abs(T(1) - check_val) < max_diff) { calc_ok = true; } - } - }; - break; + arma_extra_debug_print("lapack::getri()"); + lapack::getri(&n, A.memptr(), &lda, ipiv.memptr(), work.memptr(), &lwork, &info); - default: - ; + return (info == 0); } - - return calc_ok; + #else + { + arma_ignore(A); + arma_stop_logic_error("inv_rcond(): use of LAPACK must be enabled"); + return false; + } + #endif } -template +template inline bool -auxlib::inv_tr(Mat& out, const Base& X, const uword layout) +auxlib::inv_tr(Mat& A, const uword layout) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - out = X.get_ref(); + if(A.is_empty()) { return true; } - arma_debug_check( (out.is_square() == false), "inv(): given matrix must be square sized" ); - - if(out.is_empty()) { return true; } - - arma_debug_assert_blas_size(out); + arma_debug_assert_blas_size(A); char uplo = (layout == 0) ? 'U' : 'L'; char diag = 'N'; - blas_int n = blas_int(out.n_rows); + blas_int n = blas_int(A.n_rows); blas_int info = 0; arma_extra_debug_print("lapack::trtri()"); - lapack::trtri(&uplo, &diag, &n, out.memptr(), &n, &info); + lapack::trtri(&uplo, &diag, &n, A.memptr(), &n, &info); if(info != 0) { return false; } - if(layout == 0) - { - out = trimatu(out); // upper triangular - } - else - { - out = trimatl(out); // lower triangular - } - return true; } #else { - arma_ignore(out); - arma_ignore(X); + arma_ignore(A); arma_ignore(layout); arma_stop_logic_error("inv(): use of LAPACK must be enabled"); return false; @@ -274,90 +200,90 @@ auxlib::inv_tr(Mat& out, const Base& X, const uword layout) -template +template inline bool -auxlib::inv_sympd(Mat& out, const Base& X) +auxlib::inv_tr_rcond(Mat& A, typename get_pod_type::result& out_rcond, const uword layout) { arma_extra_debug_sigprint(); - out = X.get_ref(); - - arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized" ); - - if(out.is_empty()) { return true; } - - // if(auxlib::rudimentary_sym_check(out) == false) - // { - // if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - // if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - // return false; - // } - - if((arma_config::debug) && (auxlib::rudimentary_sym_check(out) == false)) - { - if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - } - - if(out.n_rows <= 4) - { - Mat tmp; - - const bool status = auxlib::inv_sympd_tiny(tmp, out); - - if(status == true) { out = tmp; return true; } - } - - #if defined(ARMA_USE_ATLAS) + #if defined(ARMA_USE_LAPACK) { - arma_debug_assert_atlas_size(out); + typedef typename get_pod_type::result T; - int info = 0; + if(A.is_empty()) { return true; } - arma_extra_debug_print("atlas::clapack_potrf()"); - info = atlas::clapack_potrf(atlas::CblasColMajor, atlas::CblasLower, out.n_rows, out.memptr(), out.n_rows); + out_rcond = auxlib::rcond_trimat(A, layout); - if(info != 0) { return false; } + arma_debug_assert_blas_size(A); - arma_extra_debug_print("atlas::clapack_potri()"); - info = atlas::clapack_potri(atlas::CblasColMajor, atlas::CblasLower, out.n_rows, out.memptr(), out.n_rows); + char uplo = (layout == 0) ? 'U' : 'L'; + char diag = 'N'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; - if(info != 0) { return false; } + arma_extra_debug_print("lapack::trtri()"); + lapack::trtri(&uplo, &diag, &n, A.memptr(), &n, &info); - out = symmatl(out); + if(info != 0) { out_rcond = T(0); return false; } return true; } - #elif defined(ARMA_USE_LAPACK) + #else + { + arma_ignore(A); + arma_ignore(out_rcond); + arma_ignore(layout); + arma_stop_logic_error("inv(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::inv_sympd(Mat& A, bool& out_sympd_state) + { + arma_extra_debug_sigprint(); + + out_sympd_state = false; + + if(A.is_empty()) { return true; } + + #if defined(ARMA_USE_LAPACK) { - arma_debug_assert_blas_size(out); + arma_debug_assert_blas_size(A); char uplo = 'L'; - blas_int n = blas_int(out.n_rows); + blas_int n = blas_int(A.n_rows); blas_int info = 0; // NOTE: for complex matrices, zpotrf() assumes the matrix is hermitian (not simply symmetric) arma_extra_debug_print("lapack::potrf()"); - lapack::potrf(&uplo, &n, out.memptr(), &n, &info); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); if(info != 0) { return false; } + out_sympd_state = true; + arma_extra_debug_print("lapack::potri()"); - lapack::potri(&uplo, &n, out.memptr(), &n, &info); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); if(info != 0) { return false; } - out = symmatl(out); + A = symmatl(A); return true; } #else { - arma_ignore(out); - arma_ignore(X); - arma_stop_logic_error("inv_sympd(): use of ATLAS or LAPACK must be enabled"); + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_stop_logic_error("inv_sympd(): use of LAPACK must be enabled"); return false; } #endif @@ -366,318 +292,241 @@ auxlib::inv_sympd(Mat& out, const Base& X) template -arma_cold inline bool -auxlib::inv_sympd_tiny(Mat& out, const Mat& X) +auxlib::inv_sympd(Mat& out, const Mat& X) { arma_extra_debug_sigprint(); - // if(sympd_helper::guess_sympd(X) == false) { return false; } + out = X; - return auxlib::inv_tiny(out, X); + bool sympd_state_junk = false; + + return auxlib::inv_sympd(out, sympd_state_junk); } -template +template inline -eT -auxlib::det(const Base& X) +bool +auxlib::inv_sympd_rcond(Mat& A, bool& out_sympd_state, eT& out_rcond) { arma_extra_debug_sigprint(); - typedef typename get_pod_type::result T; - - const bool make_copy = (is_Mat::value) ? true : false; - - const unwrap tmp(X.get_ref()); - const Mat& A = tmp.M; + out_sympd_state = false; - arma_debug_check( (A.is_square() == false), "det(): given matrix must be square sized" ); + if(A.is_empty()) { return true; } - const uword N = A.n_rows; - - if(N <= 4) + #if defined(ARMA_USE_LAPACK) { - const eT det_val = auxlib::det_tinymat(A, N); + typedef typename get_pod_type::result T; - const T det_min = std::numeric_limits::epsilon(); + arma_debug_assert_blas_size(A); - if(std::abs(det_val) >= det_min) { return det_val; } - } - - return auxlib::det_lapack(A, make_copy); - } - - - -template -arma_cold -inline -eT -auxlib::det_tinymat(const Mat& X, const uword N) - { - arma_extra_debug_sigprint(); - - const eT* Xm = X.memptr(); - - switch(N) - { - case 0: - return eT(1); - break; + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + T norm_val = T(0); - case 1: - return Xm[0]; - break; + podarray work(A.n_rows); - case 2: - { - return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); - } - break; + arma_extra_debug_print("lapack::lansy()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); - case 3: - { - // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); - // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); - // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); - // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); - // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); - // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); - // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); - - const eT val1 = Xm[pos<0,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]); - const eT val2 = Xm[pos<1,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]); - const eT val3 = Xm[pos<2,0>::n3]*(Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]); - - return ( val1 - val2 + val3 ); - } - break; + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); - case 4: - { - const eT val = \ - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ - - Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,0>::n4] \ - - Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ - + Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,0>::n4] \ - + Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ - - Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,0>::n4] \ - - Xm[pos<0,3>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ - + Xm[pos<0,2>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,1>::n4] \ - + Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ - - Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,1>::n4] \ - - Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ - + Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,1>::n4] \ - + Xm[pos<0,3>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ - - Xm[pos<0,1>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,2>::n4] \ - - Xm[pos<0,3>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ - + Xm[pos<0,0>::n4] * Xm[pos<1,3>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,2>::n4] \ - + Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ - - Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,3>::n4] * Xm[pos<3,2>::n4] \ - - Xm[pos<0,2>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ - + Xm[pos<0,1>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,0>::n4] * Xm[pos<3,3>::n4] \ - + Xm[pos<0,2>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ - - Xm[pos<0,0>::n4] * Xm[pos<1,2>::n4] * Xm[pos<2,1>::n4] * Xm[pos<3,3>::n4] \ - - Xm[pos<0,1>::n4] * Xm[pos<1,0>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ - + Xm[pos<0,0>::n4] * Xm[pos<1,1>::n4] * Xm[pos<2,2>::n4] * Xm[pos<3,3>::n4] \ - ; - - return val; - } - break; + if(info != 0) { out_rcond = eT(0); return false; } + + out_sympd_state = true; + + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); + + if(arma_isnan(out_rcond)) { return false; } + + arma_extra_debug_print("lapack::potri()"); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); - default: - return eT(0); - ; + if(info != 0) { return false; } + + A = symmatl(A); + + return true; } + #else + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); + return false; + } + #endif } -//! determinant of a matrix -template +template inline -eT -auxlib::det_lapack(const Mat& X, const bool make_copy) +bool +auxlib::inv_sympd_rcond(Mat< std::complex >& A, bool& out_sympd_state, T& out_rcond) { arma_extra_debug_sigprint(); - Mat X_copy; + out_sympd_state = false; - if(make_copy) { X_copy = X; } + if(A.is_empty()) { return true; } - Mat& tmp = (make_copy) ? X_copy : const_cast< Mat& >(X); - - if(tmp.is_empty()) { return eT(1); } - - #if defined(ARMA_USE_ATLAS) + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + return false; + } + #elif defined(ARMA_USE_LAPACK) { - arma_debug_assert_atlas_size(tmp); + arma_debug_assert_blas_size(A); - podarray ipiv(tmp.n_rows); + char norm_id = '1'; + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + T norm_val = T(0); - arma_extra_debug_print("atlas::clapack_getrf()"); - //const int info = - atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); + podarray work(A.n_rows); - // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero - eT val = tmp.at(0,0); - for(uword i=1; i < tmp.n_rows; ++i) - { - val *= tmp.at(i,i); - } + arma_extra_debug_print("lapack::lanhe()"); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); - int sign = +1; - for(uword i=0; i < tmp.n_rows; ++i) - { - if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 - { - sign *= -1; - } - } + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); - return ( (sign < 0) ? -val : val ); - } - #elif defined(ARMA_USE_LAPACK) - { - arma_debug_assert_blas_size(tmp); + if(info != 0) { out_rcond = T(0); return false; } - podarray ipiv(tmp.n_rows); + out_sympd_state = true; - blas_int info = 0; - blas_int n_rows = blas_int(tmp.n_rows); - blas_int n_cols = blas_int(tmp.n_cols); + out_rcond = auxlib::lu_rcond_sympd(A, norm_val); - arma_extra_debug_print("lapack::getrf()"); - lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); + if(arma_isnan(out_rcond)) { return false; } - // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero - eT val = tmp.at(0,0); - for(uword i=1; i < tmp.n_rows; ++i) - { - val *= tmp.at(i,i); - } + arma_extra_debug_print("lapack::potri()"); + lapack::potri(&uplo, &n, A.memptr(), &n, &info); - blas_int sign = +1; - for(uword i=0; i < tmp.n_rows; ++i) - { - if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 - { - sign *= -1; - } - } + if(info != 0) { return false; } - return ( (sign < 0) ? -val : val ); + A = symmatl(A); + + return true; } #else { - arma_stop_logic_error("det(): use of ATLAS or LAPACK must be enabled"); - return eT(0); + arma_ignore(A); + arma_ignore(out_sympd_state); + arma_ignore(out_rcond); + arma_stop_logic_error("inv_sympd_rcond(): use LAPACK must be enabled"); + return false; } #endif } -//! log determinant of a matrix -template +//! determinant of a matrix +template inline bool -auxlib::log_det(eT& out_val, typename get_pod_type::result& out_sign, const Base& X) +auxlib::det(eT& out_val, Mat& A) { arma_extra_debug_sigprint(); - typedef typename get_pod_type::result T; + if(A.is_empty()) { out_val = eT(1); return true; } - #if defined(ARMA_USE_ATLAS) + #if defined(ARMA_USE_LAPACK) { - Mat tmp(X.get_ref()); - arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix must be square sized" ); - - if(tmp.is_empty()) - { - out_val = eT(0); - out_sign = T(1); - return true; - } + arma_debug_assert_blas_size(A); - arma_debug_assert_atlas_size(tmp); + podarray ipiv(A.n_rows); - podarray ipiv(tmp.n_rows); + blas_int info = 0; + blas_int n_rows = blas_int(A.n_rows); + blas_int n_cols = blas_int(A.n_cols); - arma_extra_debug_print("atlas::clapack_getrf()"); - const int info = atlas::clapack_getrf(atlas::CblasColMajor, tmp.n_rows, tmp.n_cols, tmp.memptr(), tmp.n_rows, ipiv.memptr()); + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n_rows, &n_cols, A.memptr(), &n_rows, ipiv.memptr(), &info); if(info < 0) { return false; } - // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero - - sword sign = (is_cx::no) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; - eT val = (is_cx::no) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); - - for(uword i=1; i < tmp.n_rows; ++i) - { - const eT x = tmp.at(i,i); - - sign *= (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; - val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); - } + // on output A appears to be L+U_alt, where U_alt is U with the main diagonal set to zero + eT val = A.at(0,0); + for(uword i=1; i < A.n_rows; ++i) { val *= A.at(i,i); } - for(uword i=0; i < tmp.n_rows; ++i) + blas_int sign = +1; + for(uword i=0; i < A.n_rows; ++i) { - if( int(i) != ipiv.mem[i] ) // NOTE: no adjustment required, as the clapack version of getrf() assumes counting from 0 - { - sign *= -1; - } + // NOTE: adjustment of -1 is required as Fortran counts from 1 + if( blas_int(i) != (ipiv.mem[i] - 1) ) { sign *= -1; } } - out_val = val; - out_sign = T(sign); + out_val = (sign < 0) ? eT(-val) : eT(val); return true; } - #elif defined(ARMA_USE_LAPACK) + #else { - Mat tmp(X.get_ref()); - arma_debug_check( (tmp.is_square() == false), "log_det(): given matrix must be square sized" ); - - if(tmp.is_empty()) - { - out_val = eT(0); - out_sign = T(1); - return true; - } - - arma_debug_assert_blas_size(tmp); + arma_ignore(out_val); + arma_ignore(A); + arma_stop_logic_error("det(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! log determinant of a matrix +template +inline +bool +auxlib::log_det(eT& out_val, typename get_pod_type::result& out_sign, Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_empty()) { out_val = eT(0); out_sign = T(1); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); - podarray ipiv(tmp.n_rows); + podarray ipiv(A.n_rows); blas_int info = 0; - blas_int n_rows = blas_int(tmp.n_rows); - blas_int n_cols = blas_int(tmp.n_cols); + blas_int n_rows = blas_int(A.n_rows); + blas_int n_cols = blas_int(A.n_cols); arma_extra_debug_print("lapack::getrf()"); - lapack::getrf(&n_rows, &n_cols, tmp.memptr(), &n_rows, ipiv.memptr(), &info); + lapack::getrf(&n_rows, &n_cols, A.memptr(), &n_rows, ipiv.memptr(), &info); if(info < 0) { return false; } - // on output tmp appears to be L+U_alt, where U_alt is U with the main diagonal set to zero + // on output A appears to be L+U_alt, where U_alt is U with the main diagonal set to zero - sword sign = (is_cx::no) ? ( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; - eT val = (is_cx::no) ? std::log( (access::tmp_real( tmp.at(0,0) ) < T(0)) ? tmp.at(0,0)*T(-1) : tmp.at(0,0) ) : std::log( tmp.at(0,0) ); + sword sign = (is_cx::no) ? ( (access::tmp_real( A.at(0,0) ) < T(0)) ? -1 : +1 ) : +1; + eT val = (is_cx::no) ? std::log( (access::tmp_real( A.at(0,0) ) < T(0)) ? A.at(0,0)*T(-1) : A.at(0,0) ) : std::log( A.at(0,0) ); - for(uword i=1; i < tmp.n_rows; ++i) + for(uword i=1; i < A.n_rows; ++i) { - const eT x = tmp.at(i,i); + const eT x = A.at(i,i); sign *= (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); } - for(uword i=0; i < tmp.n_rows; ++i) + for(uword i=0; i < A.n_rows; ++i) { if( blas_int(i) != (ipiv.mem[i] - 1) ) // NOTE: adjustment of -1 is required as Fortran counts from 1 { @@ -692,13 +541,54 @@ auxlib::log_det(eT& out_val, typename get_pod_type::result& out_sign, const } #else { - arma_ignore(X); + arma_ignore(A); + arma_ignore(out_val); + arma_ignore(out_sign); + arma_stop_logic_error("log_det(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::log_det_sympd(typename get_pod_type::result& out_val, Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_empty()) { out_val = T(0); return true; } + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_assert_blas_size(A); + + char uplo = 'L'; + blas_int n = blas_int(A.n_rows); + blas_int info = 0; + + arma_extra_debug_print("lapack::potrf()"); + lapack::potrf(&uplo, &n, A.memptr(), &n, &info); + + if(info != 0) { return false; } + + T val = T(0); - out_val = eT(0); - out_sign = T(0); + for(uword i=0; i < A.n_rows; ++i) { val += std::log( access::tmp_real(A.at(i,i)) ); } - arma_stop_logic_error("log_det(): use of ATLAS or LAPACK must be enabled"); + out_val = T(2) * val; + return true; + } + #else + { + arma_ignore(out_val); + arma_ignore(A); + arma_stop_logic_error("log_det_sympd(): use of LAPACK must be enabled"); return false; } #endif @@ -719,47 +609,26 @@ auxlib::lu(Mat& L, Mat& U, podarray& ipiv, const Base& const uword U_n_rows = U.n_rows; const uword U_n_cols = U.n_cols; - if(U.is_empty()) - { - L.set_size(U_n_rows, 0); - U.set_size(0, U_n_cols); - ipiv.reset(); - return true; - } + if(U.is_empty()) { L.set_size(U_n_rows, 0); U.set_size(0, U_n_cols); ipiv.reset(); return true; } - #if defined(ARMA_USE_ATLAS) || defined(ARMA_USE_LAPACK) + #if defined(ARMA_USE_LAPACK) { - #if defined(ARMA_USE_ATLAS) - { - arma_debug_assert_atlas_size(U); - - ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); - - arma_extra_debug_print("atlas::clapack_getrf()"); - int info = atlas::clapack_getrf(atlas::CblasColMajor, U_n_rows, U_n_cols, U.memptr(), U_n_rows, ipiv.memptr()); - - if(info < 0) { return false; } - } - #elif defined(ARMA_USE_LAPACK) - { - arma_debug_assert_blas_size(U); - - ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); - - blas_int info = 0; - - blas_int n_rows = blas_int(U_n_rows); - blas_int n_cols = blas_int(U_n_cols); - - arma_extra_debug_print("lapack::getrf()"); - lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); - - if(info < 0) { return false; } - - // take into account that Fortran counts from 1 - arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); - } - #endif + arma_debug_assert_blas_size(U); + + ipiv.set_size( (std::min)(U_n_rows, U_n_cols) ); + + blas_int info = 0; + + blas_int n_rows = blas_int(U_n_rows); + blas_int n_cols = blas_int(U_n_cols); + + arma_extra_debug_print("lapack::getrf()"); + lapack::getrf(&n_rows, &n_cols, U.memptr(), &n_rows, ipiv.memptr(), &info); + + if(info < 0) { return false; } + + // take into account that Fortran counts from 1 + arrayops::inplace_minus(ipiv.memptr(), blas_int(1), ipiv.n_elem); L.copy_size(U); @@ -770,7 +639,7 @@ auxlib::lu(Mat& L, Mat& U, podarray& ipiv, const Base& L.at(row,col) = eT(0); } - if( L.in_range(col,col) == true ) + if( L.in_range(col,col) ) { L.at(col,col) = eT(1); } @@ -786,7 +655,7 @@ auxlib::lu(Mat& L, Mat& U, podarray& ipiv, const Base& } #else { - arma_stop_logic_error("lu(): use of ATLAS or LAPACK must be enabled"); + arma_stop_logic_error("lu(): use of LAPACK must be enabled"); return false; } #endif @@ -939,18 +808,13 @@ auxlib::eig_gen arma_debug_assert_blas_size(X); - if(X.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } - if(X.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } vals.set_size(X.n_rows, 1); - Mat tmp(1,1); + Mat tmp(1, 1, arma_nozeros_indicator()); if(vecs_on) { @@ -967,7 +831,7 @@ auxlib::eig_gen T* vr = (vecs_on) ? tmp.memptr() : junk.memptr(); blas_int ldvl = blas_int(1); blas_int ldvr = (vecs_on) ? blas_int(tmp.n_rows) : blas_int(1); - blas_int lwork = (vecs_on) ? (3 * ((std::max)(blas_int(1), 4*N)) ) : (3 * ((std::max)(blas_int(1), 3*N)) ); + blas_int lwork = 64*N; // lwork_min = (vecs_on) ? (std::max)(blas_int(1), 4*N) : (std::max)(blas_int(1), 3*N) blas_int info = 0; podarray work( static_cast(lwork) ); @@ -1052,14 +916,9 @@ auxlib::eig_gen arma_debug_assert_blas_size(X); - if(X.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } - if(X.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } vals.set_size(X.n_rows, 1); @@ -1074,7 +933,7 @@ auxlib::eig_gen eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); blas_int ldvl = blas_int(1); blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); - blas_int lwork = 3 * ((std::max)(blas_int(1), 2*N)); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 2*N) blas_int info = 0; podarray work( static_cast(lwork) ); @@ -1124,18 +983,13 @@ auxlib::eig_gen_balance arma_debug_assert_blas_size(X); - if(X.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } - if(X.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } vals.set_size(X.n_rows, 1); - Mat tmp(1,1); + Mat tmp(1, 1, arma_nozeros_indicator()); if(vecs_on) { @@ -1157,7 +1011,7 @@ auxlib::eig_gen_balance blas_int ilo = blas_int(0); blas_int ihi = blas_int(0); T abnrm = T(0); - blas_int lwork = 3 * ((std::max)(blas_int(1), blas_int(2*N))); + blas_int lwork = 64*N; // lwork_min = (vecs_on) ? (std::max)(blas_int(1), 2*N) : (std::max)(blas_int(1), 3*N) blas_int info = blas_int(0); podarray scale(X.n_rows); @@ -1186,22 +1040,380 @@ auxlib::eig_gen_balance { for(uword j=0; j < X.n_rows; ++j) { - if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + vecs.at(i,j) = std::complex( tmp.at(i,j), tmp.at(i,j+1) ); + vecs.at(i,j+1) = std::complex( tmp.at(i,j), -tmp.at(i,j+1) ); + } + + ++j; + } + else + { + for(uword i=0; i(tmp.at(i,j), T(0)); + } + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! eigen decomposition of general square matrix (complex, balance given matrix) +template +inline +bool +auxlib::eig_gen_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& vecs, + const bool vecs_on, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_CRIPPLED_LAPACK) + { + arma_extra_debug_print("auxlib::eig_gen_balance(): redirecting to auxlib::eig_gen() due to crippled LAPACK"); + + return auxlib::eig_gen(vals, vecs, vecs_on, expr); + } + #elif defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); vecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + if(vecs_on) { vecs.set_size(X.n_rows, X.n_rows); } + + podarray junk(1); + + char bal = 'B'; + char jobvl = 'N'; + char jobvr = (vecs_on) ? 'V' : 'N'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + eT* vl = junk.memptr(); + eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); + blas_int ldvl = blas_int(1); + blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(2*N)) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geevx() -- START"); + lapack::cx_geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals.memptr(), vl, &ldvl, vr, &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geevx() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(vecs); + arma_ignore(vecs_on); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (real) +template +inline +bool +auxlib::eig_gen_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + Mat ltmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + Mat rtmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 4*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geev() -- START"); + lapack::geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, work.memptr(), &lwork, &info); + arma_extra_debug_print("lapack::geev() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) + { + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); + } + ++j; + } + else + { + for(uword i=0; i(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (complex) +template +inline +bool +auxlib::eig_gen_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base< std::complex, T1 >& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 2*N) + blas_int info = 0; + + podarray work( static_cast(lwork) ); + podarray< T> rwork( static_cast(2*N) ); + + arma_extra_debug_print("lapack::cx_geev() -- START"); + lapack::cx_geev(&jobvl, &jobvr, &N, X.memptr(), &N, vals.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + arma_extra_debug_print("lapack::cx_geev() -- END"); + + return (info == 0); + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(expr); + arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigen decomposition of general square matrix (real, balance given matrix) +template +inline +bool +auxlib::eig_gen_twosided_balance + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + + Mat X = expr.get_ref(); + + arma_debug_check( (X.is_square() == false), "eig_gen(): given matrix must be square sized" ); + + arma_debug_assert_blas_size(X); + + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } + + vals.set_size(X.n_rows, 1); + + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); + + Mat ltmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + Mat rtmp(X.n_rows, X.n_rows, arma_nozeros_indicator()); + + char bal = 'B'; + char jobvl = 'V'; + char jobvr = 'V'; + char sense = 'N'; + blas_int N = blas_int(X.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int ilo = blas_int(0); + blas_int ihi = blas_int(0); + T abnrm = T(0); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(3*N)) + blas_int info = blas_int(0); + + podarray scale(X.n_rows); + podarray rconde(X.n_rows); + podarray rcondv(X.n_rows); + + podarray work( static_cast(lwork) ); + podarray iwork( uword(1) ); // iwork not used by lapack::geevx() as sense = 'N' + + podarray vals_real(X.n_rows); + podarray vals_imag(X.n_rows); + + arma_extra_debug_print("lapack::geevx() -- START"); + lapack::geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals_real.memptr(), vals_imag.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, iwork.memptr(), &info); + arma_extra_debug_print("lapack::geevx() -- END"); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + std::complex* vals_mem = vals.memptr(); + + for(uword i=0; i < X.n_rows; ++i) { vals_mem[i] = std::complex(vals_real[i], vals_imag[i]); } + + for(uword j=0; j < X.n_rows; ++j) + { + if( (j < (X.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < X.n_rows; ++i) { - for(uword i=0; i < X.n_rows; ++i) - { - vecs.at(i,j) = std::complex( tmp.at(i,j), tmp.at(i,j+1) ); - vecs.at(i,j+1) = std::complex( tmp.at(i,j), -tmp.at(i,j+1) ); - } - - ++j; + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); } - else + ++j; + } + else + { + for(uword i=0; i(tmp.at(i,j), T(0)); - } + lvecs.at(i,j) = std::complex(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); } } } @@ -1211,8 +1423,8 @@ auxlib::eig_gen_balance #else { arma_ignore(vals); - arma_ignore(vecs); - arma_ignore(vecs_on); + arma_ignore(lvecs); + arma_ignore(rvecs); arma_ignore(expr); arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); return false; @@ -1222,15 +1434,15 @@ auxlib::eig_gen_balance -//! eigen decomposition of general square matrix (complex, balance given matrix) +//! two-sided eigen decomposition of general square matrix (complex, balance given matrix) template inline bool -auxlib::eig_gen_balance +auxlib::eig_gen_twosided_balance ( - Mat< std::complex >& vals, - Mat< std::complex >& vecs, - const bool vecs_on, + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, const Base< std::complex, T1 >& expr ) { @@ -1238,9 +1450,9 @@ auxlib::eig_gen_balance #if defined(ARMA_CRIPPLED_LAPACK) { - arma_extra_debug_print("auxlib::eig_gen_balance(): redirecting to auxlib::eig_gen() due to crippled LAPACK"); + arma_extra_debug_print("auxlib::eig_gen_twosided_balance(): redirecting to auxlib::eig_gen() due to crippled LAPACK"); - return auxlib::eig_gen(vals, vecs, vecs_on, expr); + return auxlib::eig_gen(vals, lvecs, rvecs, expr); } #elif defined(ARMA_USE_LAPACK) { @@ -1253,34 +1465,26 @@ auxlib::eig_gen_balance arma_debug_assert_blas_size(X); - if(X.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(X.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } - if(X.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && X.internal_has_nonfinite()) { return false; } vals.set_size(X.n_rows, 1); - if(vecs_on) { vecs.set_size(X.n_rows, X.n_rows); } - - podarray junk(1); + lvecs.set_size(X.n_rows, X.n_rows); + rvecs.set_size(X.n_rows, X.n_rows); char bal = 'B'; - char jobvl = 'N'; - char jobvr = (vecs_on) ? 'V' : 'N'; + char jobvl = 'V'; + char jobvr = 'V'; char sense = 'N'; blas_int N = blas_int(X.n_rows); - eT* vl = junk.memptr(); - eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); - blas_int ldvl = blas_int(1); - blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); blas_int ilo = blas_int(0); blas_int ihi = blas_int(0); T abnrm = T(0); - blas_int lwork = 3 * ((std::max)(blas_int(1), blas_int(2*N))); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), blas_int(2*N)) blas_int info = blas_int(0); podarray scale(X.n_rows); @@ -1291,7 +1495,7 @@ auxlib::eig_gen_balance podarray< T> rwork( static_cast(2*N) ); arma_extra_debug_print("lapack::cx_geevx() -- START"); - lapack::cx_geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals.memptr(), vl, &ldvl, vr, &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); + lapack::cx_geevx(&bal, &jobvl, &jobvr, &sense, &N, X.memptr(), &N, vals.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, &ilo, &ihi, scale.memptr(), &abnrm, rconde.memptr(), rcondv.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); arma_extra_debug_print("lapack::cx_geevx() -- END"); return (info == 0); @@ -1299,8 +1503,8 @@ auxlib::eig_gen_balance #else { arma_ignore(vals); - arma_ignore(vecs); - arma_ignore(vecs_on); + arma_ignore(lvecs); + arma_ignore(rvecs); arma_ignore(expr); arma_stop_logic_error("eig_gen(): use of LAPACK must be enabled"); return false; @@ -1310,7 +1514,7 @@ auxlib::eig_gen_balance -//! eigendecomposition of general square real matrix pair (real) +//! eigendecomposition of general square matrix pair (real) template inline bool @@ -1339,19 +1543,14 @@ auxlib::eig_pair arma_debug_assert_blas_size(A); - if(A.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(A.is_empty()) { vals.reset(); vecs.reset(); return true; } - if(A.is_finite() == false) { return false; } - if(B.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } vals.set_size(A.n_rows, 1); - Mat tmp(1,1); + Mat tmp(1, 1, arma_nozeros_indicator()); if(vecs_on) { @@ -1368,7 +1567,7 @@ auxlib::eig_pair T* vr = (vecs_on) ? tmp.memptr() : junk.memptr(); blas_int ldvl = blas_int(1); blas_int ldvr = (vecs_on) ? blas_int(tmp.n_rows) : blas_int(1); - blas_int lwork = 3 * ((std::max)(blas_int(1), 8*N)); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 8*N) blas_int info = 0; podarray alphar(A.n_rows); @@ -1410,7 +1609,7 @@ auxlib::eig_pair } } - if(beta_has_zero) { arma_debug_warn("eig_pair(): given matrices appear ill-conditioned"); } + if(beta_has_zero) { arma_debug_warn_level(1, "eig_pair(): given matrices appear ill-conditioned"); } if(vecs_on) { @@ -1453,7 +1652,7 @@ auxlib::eig_pair -//! eigendecomposition of general square real matrix pair (complex) +//! eigendecomposition of general square matrix pair (complex) template inline bool @@ -1482,15 +1681,10 @@ auxlib::eig_pair arma_debug_assert_blas_size(A); - if(A.is_empty()) - { - vals.reset(); - vecs.reset(); - return true; - } + if(A.is_empty()) { vals.reset(); vecs.reset(); return true; } - if(A.is_finite() == false) { return false; } - if(B.is_finite() == false) { return false; } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } vals.set_size(A.n_rows, 1); @@ -1505,7 +1699,7 @@ auxlib::eig_pair eT* vr = (vecs_on) ? vecs.memptr() : junk.memptr(); blas_int ldvl = blas_int(1); blas_int ldvr = (vecs_on) ? blas_int(vecs.n_rows) : blas_int(1); - blas_int lwork = 3 * ((std::max)(blas_int(1),2*N)); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) blas_int info = 0; podarray alpha(A.n_rows); @@ -1536,7 +1730,7 @@ auxlib::eig_pair beta_has_zero = (beta_has_zero || (beta_val == zero)); } - if(beta_has_zero) { arma_debug_warn("eig_pair(): given matrices appear ill-conditioned"); } + if(beta_has_zero) { arma_debug_warn_level(1, "eig_pair(): given matrices appear ill-conditioned"); } return true; } @@ -1555,37 +1749,252 @@ auxlib::eig_pair -//! eigenvalues of a symmetric real matrix -template +//! two-sided eigendecomposition of general square matrix pair (real) +template inline bool -auxlib::eig_sym(Col& eigval, const Base& X) +auxlib::eig_pair_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base& A_expr, + const Base& B_expr + ) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + typedef typename T1::pod_type T; + typedef std::complex eT; - arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix must be square sized" ); + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); - if(A.is_empty()) + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + lvecs.set_size(A.n_rows, A.n_rows); + rvecs.set_size(A.n_rows, A.n_rows); + + Mat ltmp(A.n_rows, A.n_rows, arma_nozeros_indicator()); + Mat rtmp(A.n_rows, A.n_rows, arma_nozeros_indicator()); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(A.n_rows); + blas_int ldvl = blas_int(ltmp.n_rows); + blas_int ldvr = blas_int(rtmp.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1), 8*N) + blas_int info = 0; + + podarray alphar(A.n_rows); + podarray alphai(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + + arma_extra_debug_print("lapack::ggev()"); + lapack::ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alphar.memptr(), alphai.memptr(), beta.memptr(), ltmp.memptr(), &ldvl, rtmp.memptr(), &ldvr, work.memptr(), &lwork, &info); + + if(info != 0) { return false; } + + arma_extra_debug_print("reformatting eigenvalues and eigenvectors"); + + eT* vals_mem = vals.memptr(); + const T* alphar_mem = alphar.memptr(); + const T* alphai_mem = alphai.memptr(); + const T* beta_mem = beta.memptr(); + + bool beta_has_zero = false; + + for(uword j=0; j(re, im); + + if( (alphai_val > T(0)) && (j < (A.n_rows-1)) ) + { + ++j; + vals_mem[j] = std::complex(re,-im); // force exact conjugate + } + } + + if(beta_has_zero) { arma_debug_warn_level(1, "eig_pair(): given matrices appear ill-conditioned"); } + + for(uword j=0; j < A.n_rows; ++j) + { + if( (j < (A.n_rows-1)) && (vals_mem[j] == std::conj(vals_mem[j+1])) ) + { + for(uword i=0; i < A.n_rows; ++i) + { + lvecs.at(i,j) = std::complex( ltmp.at(i,j), ltmp.at(i,j+1) ); + lvecs.at(i,j+1) = std::complex( ltmp.at(i,j), -ltmp.at(i,j+1) ); + rvecs.at(i,j) = std::complex( rtmp.at(i,j), rtmp.at(i,j+1) ); + rvecs.at(i,j+1) = std::complex( rtmp.at(i,j), -rtmp.at(i,j+1) ); + } + ++j; + } + else + { + for(uword i=0; i(ltmp.at(i,j), T(0)); + rvecs.at(i,j) = std::complex(rtmp.at(i,j), T(0)); + } + } + } + + return true; + } + #else + { + arma_ignore(vals); + arma_ignore(lvecs); + arma_ignore(rvecs); + arma_ignore(A_expr); + arma_ignore(B_expr); + arma_stop_logic_error("eig_pair(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! two-sided eigendecomposition of general square matrix pair (complex) +template +inline +bool +auxlib::eig_pair_twosided + ( + Mat< std::complex >& vals, + Mat< std::complex >& lvecs, + Mat< std::complex >& rvecs, + const Base< std::complex, T1 >& A_expr, + const Base< std::complex, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::pod_type T; + typedef typename std::complex eT; + + Mat A(A_expr.get_ref()); + Mat B(B_expr.get_ref()); + + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "eig_pair(): given matrices must be square sized" ); + + arma_debug_check( (A.n_rows != B.n_rows), "eig_pair(): given matrices must have the same size" ); + + arma_debug_assert_blas_size(A); + + if(A.is_empty()) { vals.reset(); lvecs.reset(); rvecs.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } + + vals.set_size(A.n_rows, 1); + + lvecs.set_size(A.n_rows, A.n_rows); + rvecs.set_size(A.n_rows, A.n_rows); + + char jobvl = 'V'; + char jobvr = 'V'; + blas_int N = blas_int(A.n_rows); + blas_int ldvl = blas_int(lvecs.n_rows); + blas_int ldvr = blas_int(rvecs.n_rows); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) + blas_int info = 0; + + podarray alpha(A.n_rows); + podarray beta(A.n_rows); + + podarray work( static_cast(lwork) ); + podarray rwork( static_cast(8*N) ); + + arma_extra_debug_print("lapack::cx_ggev()"); + lapack::cx_ggev(&jobvl, &jobvr, &N, A.memptr(), &N, B.memptr(), &N, alpha.memptr(), beta.memptr(), lvecs.memptr(), &ldvl, rvecs.memptr(), &ldvr, work.memptr(), &lwork, rwork.memptr(), &info); + + if(info != 0) { return false; } + + eT* vals_mem = vals.memptr(); + const eT* alpha_mem = alpha.memptr(); + const eT* beta_mem = beta.memptr(); + + const std::complex zero(T(0), T(0)); + + bool beta_has_zero = false; + + for(uword i=0; i +inline +bool +auxlib::eig_sym(Col& eigval, Mat& A) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix must be square sized" ); + + if(A.is_empty()) { eigval.reset(); return true; } if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false)) { - arma_debug_warn("eig_sym(): given matrix is not symmetric"); + arma_debug_warn_level(1, "eig_sym(): given matrix is not symmetric"); } + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(A)) { return false; } + arma_debug_assert_blas_size(A); eigval.set_size(A.n_rows); @@ -1594,7 +2003,7 @@ auxlib::eig_sym(Col& eigval, const Base& X) char uplo = 'U'; blas_int N = blas_int(A.n_rows); - blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) ); + blas_int lwork = (64+2)*N; // lwork_min = (std::max)(blas_int(1), 3*N-1) blas_int info = 0; podarray work( static_cast(lwork) ); @@ -1607,7 +2016,7 @@ auxlib::eig_sym(Col& eigval, const Base& X) #else { arma_ignore(eigval); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); return false; } @@ -1617,10 +2026,10 @@ auxlib::eig_sym(Col& eigval, const Base& X) //! eigenvalues of a hermitian complex matrix -template +template inline bool -auxlib::eig_sym(Col& eigval, const Base,T1>& X) +auxlib::eig_sym(Col& eigval, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -1628,27 +2037,17 @@ auxlib::eig_sym(Col& eigval, const Base,T1>& X) { typedef typename std::complex eT; - Mat A(X.get_ref()); - arma_debug_check( (A.is_square() == false), "eig_sym(): given matrix must be square sized" ); - if(A.is_empty()) - { - eigval.reset(); - return true; - } - - // if(auxlib::rudimentary_sym_check(A) == false) - // { - // arma_debug_warn("eig_sym(): given matrix is not hermitian"); - // return false; - // } + if(A.is_empty()) { eigval.reset(); return true; } if((arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false)) { - arma_debug_warn("eig_sym(): given matrix is not hermitian"); + arma_debug_warn_level(1, "eig_sym(): given matrix is not hermitian"); } + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(A)) { return false; } + arma_debug_assert_blas_size(A); eigval.set_size(A.n_rows); @@ -1657,11 +2056,11 @@ auxlib::eig_sym(Col& eigval, const Base,T1>& X) char uplo = 'U'; blas_int N = blas_int(A.n_rows); - blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) ); + blas_int lwork = (64+1)*N; // lwork_min = (std::max)(blas_int(1), 2*N-1) blas_int info = 0; podarray work( static_cast(lwork) ); - podarray rwork( static_cast( (std::max)(blas_int(1), 3*N-2) ) ); + podarray rwork( static_cast( (std::max)(blas_int(1), 3*N) ) ); arma_extra_debug_print("lapack::heev()"); lapack::heev(&jobz, &uplo, &N, A.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); @@ -1671,7 +2070,7 @@ auxlib::eig_sym(Col& eigval, const Base,T1>& X) #else { arma_ignore(eigval); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("eig_sym(): use of LAPACK must be enabled"); return false; } @@ -1690,16 +2089,13 @@ auxlib::eig_sym(Col& eigval, Mat& eigvec, const Mat& X) #if defined(ARMA_USE_LAPACK) { - eigvec = X; + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); - arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix must be square sized" ); + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } - if(eigvec.is_empty()) - { - eigval.reset(); - eigvec.reset(); - return true; - } + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } arma_debug_assert_blas_size(eigvec); @@ -1709,7 +2105,7 @@ auxlib::eig_sym(Col& eigval, Mat& eigvec, const Mat& X) char uplo = 'U'; blas_int N = blas_int(eigvec.n_rows); - blas_int lwork = 3 * ( (std::max)(blas_int(1), 3*N-1) ); + blas_int lwork = (64+2)*N; // lwork_min = (std::max)(blas_int(1), 3*N-1) blas_int info = 0; podarray work( static_cast(lwork) ); @@ -1744,16 +2140,13 @@ auxlib::eig_sym(Col& eigval, Mat< std::complex >& eigvec, const Mat< std:: { typedef typename std::complex eT; - eigvec = X; + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); - arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix must be square sized" ); + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } - if(eigvec.is_empty()) - { - eigval.reset(); - eigvec.reset(); - return true; - } + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } arma_debug_assert_blas_size(eigvec); @@ -1763,11 +2156,11 @@ auxlib::eig_sym(Col& eigval, Mat< std::complex >& eigvec, const Mat< std:: char uplo = 'U'; blas_int N = blas_int(eigvec.n_rows); - blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*N-1) ); + blas_int lwork = (64+1)*N; // lwork_min = (std::max)(blas_int(1), 2*N-1) blas_int info = 0; podarray work( static_cast(lwork) ); - podarray rwork( static_cast((std::max)(blas_int(1), 3*N-2)) ); + podarray rwork( static_cast((std::max)(blas_int(1), 3*N)) ); arma_extra_debug_print("lapack::heev()"); lapack::heev(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &info); @@ -1797,16 +2190,13 @@ auxlib::eig_sym_dc(Col& eigval, Mat& eigvec, const Mat& X) #if defined(ARMA_USE_LAPACK) { - eigvec = X; + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); - arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix must be square sized" ); + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } - if(eigvec.is_empty()) - { - eigval.reset(); - eigvec.reset(); - return true; - } + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } arma_debug_assert_blas_size(eigvec); @@ -1815,16 +2205,39 @@ auxlib::eig_sym_dc(Col& eigval, Mat& eigvec, const Mat& X) char jobz = 'V'; char uplo = 'U'; - blas_int N = blas_int(eigvec.n_rows); - blas_int lwork = 2 * (1 + 6*N + 2*(N*N)); - blas_int liwork = 3 * (3 + 5*N); - blas_int info = 0; + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork_min = 1 + 6*N + 2*(N*N); + blas_int liwork_min = 3 + 5*N; + blas_int info = 0; + + blas_int lwork_proposed = 0; + blas_int liwork_proposed = 0; + + if(N >= 32) + { + eT work_query[2] = {}; + blas_int iwork_query[2] = {}; + + blas_int lwork_query = -1; + blas_int liwork_query = -1; + + arma_extra_debug_print("lapack::syevd()"); + lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), &work_query[0], &lwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + liwork_proposed = iwork_query[0]; + } - podarray work( static_cast( lwork) ); - podarray iwork( static_cast(liwork) ); + blas_int lwork_final = (std::max)( lwork_proposed, lwork_min); + blas_int liwork_final = (std::max)(liwork_proposed, liwork_min); + + podarray work( static_cast( lwork_final) ); + podarray iwork( static_cast(liwork_final) ); arma_extra_debug_print("lapack::syevd()"); - lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, iwork.memptr(), &liwork, &info); + lapack::syevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork_final, iwork.memptr(), &liwork_final, &info); return (info == 0); } @@ -1853,16 +2266,13 @@ auxlib::eig_sym_dc(Col& eigval, Mat< std::complex >& eigvec, const Mat< st { typedef typename std::complex eT; - eigvec = X; + arma_debug_check( (X.is_square() == false), "eig_sym(): given matrix must be square sized" ); - arma_debug_check( (eigvec.is_square() == false), "eig_sym(): given matrix must be square sized" ); + if(arma_config::check_nonfinite && trimat_helper::has_nonfinite_triu(X)) { return false; } - if(eigvec.is_empty()) - { - eigval.reset(); - eigvec.reset(); - return true; - } + eigvec = X; + + if(eigvec.is_empty()) { eigval.reset(); eigvec.reset(); return true; } arma_debug_assert_blas_size(eigvec); @@ -1871,18 +2281,46 @@ auxlib::eig_sym_dc(Col& eigval, Mat< std::complex >& eigvec, const Mat< st char jobz = 'V'; char uplo = 'U'; - blas_int N = blas_int(eigvec.n_rows); - blas_int lwork = 2 * (2*N + N*N); - blas_int lrwork = 2 * (1 + 5*N + 2*(N*N)); - blas_int liwork = 3 * (3 + 5*N); - blas_int info = 0; + blas_int N = blas_int(eigvec.n_rows); + blas_int lwork_min = 2*N + N*N; + blas_int lrwork_min = 1 + 5*N + 2*(N*N); + blas_int liwork_min = 3 + 5*N; + blas_int info = 0; - podarray work( static_cast(lwork) ); - podarray rwork( static_cast(lrwork) ); - podarray iwork( static_cast(liwork) ); + blas_int lwork_proposed = 0; + blas_int lrwork_proposed = 0; + blas_int liwork_proposed = 0; + + if(N >= 32) + { + eT work_query[2] = {}; + T rwork_query[2] = {}; + blas_int iwork_query[2] = {}; + + blas_int lwork_query = -1; + blas_int lrwork_query = -1; + blas_int liwork_query = -1; + + arma_extra_debug_print("lapack::heevd()"); + lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), &work_query[0], &lwork_query, &rwork_query[0], &lrwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + lrwork_proposed = static_cast( rwork_query[0] ); + liwork_proposed = iwork_query[0]; + } + + blas_int lwork_final = (std::max)( lwork_proposed, lwork_min); + blas_int lrwork_final = (std::max)(lrwork_proposed, lrwork_min); + blas_int liwork_final = (std::max)(liwork_proposed, liwork_min); + + podarray work( static_cast( lwork_final) ); + podarray< T> rwork( static_cast(lrwork_final) ); + podarray iwork( static_cast(liwork_final) ); arma_extra_debug_print("lapack::heevd()"); - lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork, rwork.memptr(), &lrwork, iwork.memptr(), &liwork, &info); + lapack::heevd(&jobz, &uplo, &N, eigvec.memptr(), &N, eigval.memptr(), work.memptr(), &lwork_final, rwork.memptr(), &lrwork_final, iwork.memptr(), &liwork_final, &info); return (info == 0); } @@ -1906,18 +2344,7 @@ auxlib::chol_simple(Mat& X) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_ATLAS) - { - arma_debug_assert_atlas_size(X); - - int info = 0; - - arma_extra_debug_print("atlas::clapack_potrf()"); - info = atlas::clapack_potrf(atlas::CblasColMajor, atlas::CblasUpper, X.n_rows, X.memptr(), X.n_rows); - - return (info == 0); - } - #elif defined(ARMA_USE_LAPACK) + #if defined(ARMA_USE_LAPACK) { arma_debug_assert_blas_size(X); @@ -1934,7 +2361,7 @@ auxlib::chol_simple(Mat& X) { arma_ignore(X); - arma_stop_logic_error("chol(): use of ATLAS or LAPACK must be enabled"); + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); return false; } #endif @@ -1949,22 +2376,7 @@ auxlib::chol(Mat& X, const uword layout) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_ATLAS) - { - arma_debug_assert_atlas_size(X); - - int info = 0; - - arma_extra_debug_print("atlas::clapack_potrf()"); - info = atlas::clapack_potrf(atlas::CblasColMajor, ((layout == 0) ? atlas::CblasUpper : atlas::CblasLower), X.n_rows, X.memptr(), X.n_rows); - - if(info != 0) { return false; } - - X = (layout == 0) ? trimatu(X) : trimatl(X); // trimatu() and trimatl() return the same type - - return true; - } - #elif defined(ARMA_USE_LAPACK) + #if defined(ARMA_USE_LAPACK) { arma_debug_assert_blas_size(X); @@ -1986,7 +2398,7 @@ auxlib::chol(Mat& X, const uword layout) arma_ignore(X); arma_ignore(layout); - arma_stop_logic_error("chol(): use of ATLAS or LAPACK must be enabled"); + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); return false; } #endif @@ -2077,6 +2489,61 @@ auxlib::chol_band_common(Mat& X, const uword KD, const uword layout) } + +template +inline +bool +auxlib::chol_pivot(Mat& X, Mat& P, const uword layout) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename get_pod_type::result T; + + arma_debug_assert_blas_size(X); + + char uplo = (layout == 0) ? 'U' : 'L'; + blas_int n = blas_int(X.n_rows); + blas_int rank = 0; + T tol = T(-1); + blas_int info = 0; + + podarray ipiv( X.n_rows); + podarray work(2*X.n_rows); + + ipiv.zeros(); + + arma_extra_debug_print("lapack::pstrf()"); + lapack::pstrf(&uplo, &n, X.memptr(), &n, ipiv.memptr(), &rank, &tol, work.memptr(), &info); + + if(info != 0) { return false; } + + X = (layout == 0) ? trimatu(X) : trimatl(X); // trimatu() and trimatl() return the same type + + P.set_size(X.n_rows, 1); + + for(uword i=0; i < X.n_rows; ++i) + { + P[i] = uword(ipiv[i] - 1); // take into account that Fortran counts from 1 + } + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(P); + arma_ignore(layout); + + arma_stop_logic_error("chol(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + // // hessenberg decomposition template @@ -2092,10 +2559,7 @@ auxlib::hess(Mat& H, const Base& X, Col& tao) arma_debug_check( (H.is_square() == false), "hess(): given matrix must be square sized" ); - if(H.is_empty()) - { - return true; - } + if(H.is_empty()) { return true; } arma_debug_assert_blas_size(H); @@ -2118,14 +2582,195 @@ auxlib::hess(Mat& H, const Base& X, Col& tao) return (info == 0); } - return true; + return true; + } + #else + { + arma_ignore(H); + arma_ignore(X); + arma_ignore(tao); + arma_stop_logic_error("hess(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr(Mat& Q, Mat& R, const Base& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + R = X.get_ref(); + + const uword R_n_rows = R.n_rows; + const uword R_n_cols = R.n_cols; + + if(R.is_empty()) { Q.eye(R_n_rows, R_n_rows); return true; } + + arma_debug_assert_blas_size(R); + + blas_int m = static_cast(R_n_rows); + blas_int n = static_cast(R_n_cols); + blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + Q.set_size(R_n_rows, R_n_rows); + + arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); + + // + // construct R + + for(uword col=0; col < R_n_cols; ++col) + { + for(uword row=(col+1); row < R_n_rows; ++row) + { + R.at(row,col) = eT(0); + } + } + + + if( (is_float::value) || (is_double::value) ) + { + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + else + if( (is_cx_float::value) || (is_cx_double::value) ) + { + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + + return (info == 0); + } + #else + { + arma_ignore(Q); + arma_ignore(R); + arma_ignore(X); + arma_stop_logic_error("qr(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + if(is_Mat::value) + { + const unwrap tmp(X.get_ref()); + const Mat& M = tmp.M; + + if(M.n_rows < M.n_cols) { return auxlib::qr(Q, R, X); } + } + + Q = X.get_ref(); + + const uword Q_n_rows = Q.n_rows; + const uword Q_n_cols = Q.n_cols; + + if( Q_n_rows <= Q_n_cols ) { return auxlib::qr(Q, R, Q); } + + if(Q.is_empty()) { Q.set_size(Q_n_rows, 0); R.set_size(0, Q_n_cols); return true; } + + arma_debug_assert_blas_size(Q); + + blas_int m = static_cast(Q_n_rows); + blas_int n = static_cast(Q_n_cols); + blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int k = (std::min)(m,n); + blas_int info = 0; + + podarray tau( static_cast(k) ); + + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::geqrf()"); + lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + + if(info != 0) { return false; } + + R.zeros(Q_n_cols, Q_n_cols); + + // + // construct R + + for(uword col=0; col < Q_n_cols; ++col) + { + for(uword row=0; row <= col; ++row) + { + R.at(row,col) = Q.at(row,col); + } + } + + if( (is_float::value) || (is_double::value) ) + { + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + else + if( (is_cx_float::value) || (is_cx_double::value) ) + { + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); + } + + return (info == 0); } #else { - arma_ignore(H); + arma_ignore(Q); + arma_ignore(R); arma_ignore(X); - arma_ignore(tao); - arma_stop_logic_error("hess(): use of LAPACK must be enabled"); + arma_stop_logic_error("qr_econ(): use of LAPACK must be enabled"); return false; } #endif @@ -2136,7 +2781,7 @@ auxlib::hess(Mat& H, const Base& X, Col& tao) template inline bool -auxlib::qr(Mat& Q, Mat& R, const Base& X) +auxlib::qr_pivot(Mat& Q, Mat& R, Mat& P, const Base& X) { arma_extra_debug_sigprint(); @@ -2150,6 +2795,11 @@ auxlib::qr(Mat& Q, Mat& R, const Base& X) if(R.is_empty()) { Q.eye(R_n_rows, R_n_rows); + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) { P.at(col) = col; } + return true; } @@ -2157,29 +2807,30 @@ auxlib::qr(Mat& Q, Mat& R, const Base& X) blas_int m = static_cast(R_n_rows); blas_int n = static_cast(R_n_cols); - blas_int lwork = 0; - blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int lwork_min = (std::max)(blas_int(3*n + 1), (std::max)(m,n)); // take into account requirements of geqp3() and orgqr() blas_int k = (std::min)(m,n); blas_int info = 0; - podarray tau( static_cast(k) ); + podarray tau( static_cast(k) ); + podarray jpvt( R_n_cols ); - eT work_query[2]; - blas_int lwork_query = -1; + jpvt.zeros(); - arma_extra_debug_print("lapack::geqrf()"); - lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqp3()"); + lapack::geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), &work_query[0], &lwork_query, &info); if(info != 0) { return false; } blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - lwork = (std::max)(lwork_proposed, lwork_min); + podarray work( static_cast(lwork_final) ); - podarray work( static_cast(lwork) ); - - arma_extra_debug_print("lapack::geqrf()"); - lapack::geqrf(&m, &n, R.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); + arma_extra_debug_print("lapack::geqp3()"); + lapack::geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), work.memptr(), &lwork_final, &info); if(info != 0) { return false; } @@ -2188,28 +2839,19 @@ auxlib::qr(Mat& Q, Mat& R, const Base& X) arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); // - // construct R + // construct R and P + + P.set_size(R_n_cols, 1); for(uword col=0; col < R_n_cols; ++col) { - for(uword row=(col+1); row < R_n_rows; ++row) - { - R.at(row,col) = eT(0); - } + for(uword row=(col+1); row < R_n_rows; ++row) { R.at(row,col) = eT(0); } + + P.at(col) = jpvt[col] - 1; // take into account that Fortran counts from 1 } - - if( (is_float::value) || (is_double::value) ) - { - arma_extra_debug_print("lapack::orgqr()"); - lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); - } - else - if( (is_cx_float::value) || (is_cx_double::value) ) - { - arma_extra_debug_print("lapack::ungqr()"); - lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); - } + arma_extra_debug_print("lapack::orgqr()"); + lapack::orgqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); return (info == 0); } @@ -2217,6 +2859,7 @@ auxlib::qr(Mat& Q, Mat& R, const Base& X) { arma_ignore(Q); arma_ignore(R); + arma_ignore(P); arma_ignore(X); arma_stop_logic_error("qr(): use of LAPACK must be enabled"); return false; @@ -2226,102 +2869,83 @@ auxlib::qr(Mat& Q, Mat& R, const Base& X) -template +template inline -bool -auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) +bool +auxlib::qr_pivot(Mat< std::complex >& Q, Mat< std::complex >& R, Mat& P, const Base,T1>& X) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - if(is_Mat::value) - { - const unwrap tmp(X.get_ref()); - const Mat& M = tmp.M; - - if(M.n_rows < M.n_cols) - { - return auxlib::qr(Q, R, X); - } - } - - Q = X.get_ref(); + typedef typename std::complex eT; - const uword Q_n_rows = Q.n_rows; - const uword Q_n_cols = Q.n_cols; + R = X.get_ref(); - if( Q_n_rows <= Q_n_cols ) - { - return auxlib::qr(Q, R, Q); - } + const uword R_n_rows = R.n_rows; + const uword R_n_cols = R.n_cols; - if(Q.is_empty()) + if(R.is_empty()) { - Q.set_size(Q_n_rows, 0 ); - R.set_size(0, Q_n_cols); + Q.eye(R_n_rows, R_n_rows); + + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) { P.at(col) = col; } + return true; } - arma_debug_assert_blas_size(Q); + arma_debug_assert_blas_size(R); - blas_int m = static_cast(Q_n_rows); - blas_int n = static_cast(Q_n_cols); - blas_int lwork = 0; - blas_int lwork_min = (std::max)(blas_int(1), (std::max)(m,n)); // take into account requirements of geqrf() _and_ orgqr()/ungqr() + blas_int m = static_cast(R_n_rows); + blas_int n = static_cast(R_n_cols); + blas_int lwork_min = (std::max)(blas_int(3*n + 1), (std::max)(m,n)); // take into account requirements of geqp3() and ungqr() blas_int k = (std::min)(m,n); blas_int info = 0; - podarray tau( static_cast(k) ); + podarray tau( static_cast(k) ); + podarray< T> rwork( 2*R_n_cols ); + podarray jpvt( R_n_cols ); - eT work_query[2]; - blas_int lwork_query = -1; + jpvt.zeros(); - arma_extra_debug_print("lapack::geqrf()"); - lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), &work_query[0], &lwork_query, &info); + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::geqp3()"); + lapack::cx_geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), &work_query[0], &lwork_query, rwork.memptr(), &info); if(info != 0) { return false; } blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - lwork = (std::max)(lwork_proposed, lwork_min); - - podarray work( static_cast(lwork) ); + podarray work( static_cast(lwork_final) ); - arma_extra_debug_print("lapack::geqrf()"); - lapack::geqrf(&m, &n, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); + arma_extra_debug_print("lapack::geqp3()"); + lapack::cx_geqp3(&m, &n, R.memptr(), &m, jpvt.memptr(), tau.memptr(), work.memptr(), &lwork_final, rwork.memptr(), &info); if(info != 0) { return false; } - R.set_size(Q_n_cols, Q_n_cols); + Q.set_size(R_n_rows, R_n_rows); + + arrayops::copy( Q.memptr(), R.memptr(), (std::min)(Q.n_elem, R.n_elem) ); // - // construct R + // construct R and P - for(uword col=0; col < Q_n_cols; ++col) + P.set_size(R_n_cols, 1); + + for(uword col=0; col < R_n_cols; ++col) { - for(uword row=0; row <= col; ++row) - { - R.at(row,col) = Q.at(row,col); - } + for(uword row=(col+1); row < R_n_rows; ++row) { R.at(row,col) = eT(0); } - for(uword row=(col+1); row < Q_n_cols; ++row) - { - R.at(row,col) = eT(0); - } + P.at(col) = jpvt[col] - 1; // take into account that Fortran counts from 1 } - if( (is_float::value) || (is_double::value) ) - { - arma_extra_debug_print("lapack::orgqr()"); - lapack::orgqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); - } - else - if( (is_cx_float::value) || (is_cx_double::value) ) - { - arma_extra_debug_print("lapack::ungqr()"); - lapack::ungqr(&m, &n, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork, &info); - } + arma_extra_debug_print("lapack::ungqr()"); + lapack::ungqr(&m, &m, &k, Q.memptr(), &m, tau.memptr(), work.memptr(), &lwork_final, &info); return (info == 0); } @@ -2329,8 +2953,9 @@ auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) { arma_ignore(Q); arma_ignore(R); + arma_ignore(P); arma_ignore(X); - arma_stop_logic_error("qr_econ(): use of LAPACK must be enabled"); + arma_stop_logic_error("qr(): use of LAPACK must be enabled"); return false; } #endif @@ -2338,71 +2963,66 @@ auxlib::qr_econ(Mat& Q, Mat& R, const Base& X) -template +template inline bool -auxlib::svd(Col& S, const Base& X, uword& X_n_rows, uword& X_n_cols) +auxlib::svd(Col& S, Mat& A) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); - - X_n_rows = A.n_rows; - X_n_cols = A.n_cols; + if(A.is_empty()) { S.reset(); return true; } - if(A.is_empty()) - { - S.reset(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); - Mat U(1, 1); - Mat V(1, A.n_cols); + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, A.n_cols, arma_nozeros_indicator()); char jobu = 'N'; char jobvt = 'N'; - blas_int m = A.n_rows; - blas_int n = A.n_cols; - blas_int min_mn = (std::min)(m,n); - blas_int lda = A.n_rows; - blas_int ldu = U.n_rows; - blas_int ldvt = V.n_rows; - blas_int lwork = 0; - blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; S.set_size( static_cast(min_mn) ); - eT work_query[2]; - blas_int lwork_query = -1; - - arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); - - if(info != 0) { return false; } + blas_int lwork_proposed = 0; - blas_int lwork_proposed = static_cast( work_query[0] ); + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } - lwork = (std::max)(lwork_proposed, lwork_min); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - podarray work( static_cast(lwork) ); + podarray work( static_cast(lwork_final) ); arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); return (info == 0); } #else { arma_ignore(S); - arma_ignore(X); - arma_ignore(X_n_rows); - arma_ignore(X_n_cols); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -2411,10 +3031,10 @@ auxlib::svd(Col& S, const Base& X, uword& X_n_rows, uword& X_n_cols) -template +template inline bool -auxlib::svd(Col& S, const Base, T1>& X, uword& X_n_rows, uword& X_n_cols) +auxlib::svd(Col& S, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -2422,66 +3042,59 @@ auxlib::svd(Col& S, const Base, T1>& X, uword& X_n_rows, uwor { typedef std::complex eT; - Mat A(X.get_ref()); + if(A.is_empty()) { S.reset(); return true; } - X_n_rows = A.n_rows; - X_n_cols = A.n_cols; - - if(A.is_empty()) - { - S.reset(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); - Mat U(1, 1); - Mat V(1, A.n_cols); + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, A.n_cols, arma_nozeros_indicator()); char jobu = 'N'; char jobvt = 'N'; - blas_int m = A.n_rows; - blas_int n = A.n_cols; - blas_int min_mn = (std::min)(m,n); - blas_int lda = A.n_rows; - blas_int ldu = U.n_rows; - blas_int ldvt = V.n_rows; - blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn+(std::max)(m,n) ) ); - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), 2*min_mn+(std::max)(m,n) ); + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); - podarray< T> rwork( static_cast(5*min_mn) ); - - blas_int lwork_tmp = -1; // let gesvd_() calculate the optimum size of the workspace - - arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_tmp, rwork.memptr(), &info); - - if(info != 0) { return false; } + podarray rwork( static_cast(5*min_mn) ); - blas_int proposed_lwork = static_cast(real(work[0])); + blas_int lwork_proposed = 0; - if(proposed_lwork > lwork) + if(A.n_elem >= 256) { - lwork = proposed_lwork; - work.set_size( static_cast(lwork) ); + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); } + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), &info); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); return (info == 0); } #else { arma_ignore(S); - arma_ignore(X); - arma_ignore(X_n_rows); - arma_ignore(X_n_cols); - + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -2490,50 +3103,18 @@ auxlib::svd(Col& S, const Base, T1>& X, uword& X_n_rows, uwor -template -inline -bool -auxlib::svd(Col& S, const Base& X) - { - arma_extra_debug_sigprint(); - - uword junk; - return auxlib::svd(S, X, junk, junk); - } - - - -template -inline -bool -auxlib::svd(Col& S, const Base, T1>& X) - { - arma_extra_debug_sigprint(); - - uword junk; - return auxlib::svd(S, X, junk, junk); - } - - - -template +template inline bool -auxlib::svd(Mat& U, Col& S, Mat& V, const Base& X) +auxlib::svd(Mat& U, Col& S, Mat& V, Mat& A) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } - if(A.is_empty()) - { - U.eye(A.n_rows, A.n_rows); - S.reset(); - V.eye(A.n_cols, A.n_cols); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -2543,35 +3124,39 @@ auxlib::svd(Mat& U, Col& S, Mat& V, const Base& X) char jobu = 'A'; char jobvt = 'A'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); - blas_int lwork = 0; - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; S.set_size( static_cast(min_mn) ); - // let gesvd_() calculate the optimum size of the workspace - eT work_query[2]; - blas_int lwork_query = -1; - - arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); - - if(info != 0) { return false; } + blas_int lwork_proposed = 0; - blas_int lwork_proposed = static_cast( work_query[0] ); + if(A.n_elem >= 1024) + { + // query to find optimum size of workspace + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } - lwork = (std::max)(lwork_proposed, lwork_min); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - podarray work( static_cast(lwork) ); + podarray work( static_cast(lwork_final) ); arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); if(info != 0) { return false; } @@ -2584,7 +3169,7 @@ auxlib::svd(Mat& U, Col& S, Mat& V, const Base& X) arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -2593,10 +3178,10 @@ auxlib::svd(Mat& U, Col& S, Mat& V, const Base& X) -template +template inline bool -auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X) +auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -2604,15 +3189,9 @@ auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, con { typedef std::complex eT; - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } - if(A.is_empty()) - { - U.eye(A.n_rows, A.n_rows); - S.reset(); - V.eye(A.n_cols, A.n_cols); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -2622,37 +3201,40 @@ auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, con char jobu = 'A'; char jobvt = 'A'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork = 3 * ( (std::max)(blas_int(1), 2*min_mn + (std::max)(m,n) ) ); - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = (std::max)( blas_int(1), 2*min_mn + (std::max)(m,n) ); + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); - podarray rwork( static_cast(5*min_mn) ); - - blas_int lwork_tmp = -1; // let gesvd_() calculate the optimum size of the workspace - - arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_tmp, rwork.memptr(), &info); + podarray rwork( static_cast(5*min_mn) ); - if(info != 0) { return false; } - - blas_int proposed_lwork = static_cast(real(work[0])); + blas_int lwork_proposed = 0; - if(proposed_lwork > lwork) + if(A.n_elem >= 256) { - lwork = proposed_lwork; - work.set_size( static_cast(lwork) ); + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); } + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), &info); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); if(info != 0) { return false; } @@ -2665,7 +3247,7 @@ auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, con arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -2674,24 +3256,18 @@ auxlib::svd(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, con -template +template inline bool -auxlib::svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const char mode) +auxlib::svd_econ(Mat& U, Col& S, Mat& V, Mat& A, const char mode) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(); S.reset(); V.eye(); return true; } - if(A.is_empty()) - { - U.eye(); - S.reset(); - V.eye(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -2745,29 +3321,30 @@ auxlib::svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const } - blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) ); - blas_int info = 0; - - - podarray work( static_cast(lwork) ); - - blas_int lwork_tmp = -1; // let gesvd_() calculate the optimum size of the workspace - - arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_tmp, &info); - - if(info != 0) { return false; } + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; - blas_int proposed_lwork = static_cast(work[0]); + blas_int lwork_proposed = 0; - if(proposed_lwork > lwork) + if(A.n_elem >= 1024) { - lwork = proposed_lwork; - work.set_size( static_cast(lwork) ); + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::gesvd()"); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); } + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::gesvd()"); - lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, &info); + lapack::gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, &info); if(info != 0) { return false; } @@ -2780,7 +3357,7 @@ auxlib::svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_ignore(mode); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; @@ -2790,10 +3367,10 @@ auxlib::svd_econ(Mat& U, Col& S, Mat& V, const Base& X, const -template +template inline bool -auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X, const char mode) +auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A, const char mode) { arma_extra_debug_sigprint(); @@ -2801,15 +3378,9 @@ auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V { typedef std::complex eT; - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(); S.reset(); V.eye(); return true; } - if(A.is_empty()) - { - U.eye(); - S.reset(); - V.eye(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -2862,30 +3433,32 @@ auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V V.set_size( static_cast(ldvt), static_cast(n) ); } - blas_int lwork = 3 * ( (std::max)(blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ) ); - blas_int info = 0; - - - podarray work( static_cast(lwork ) ); - podarray rwork( static_cast(5*min_mn) ); - - blas_int lwork_tmp = -1; // let gesvd_() calculate the optimum size of the workspace - - arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_tmp, rwork.memptr(), &info); + blas_int lwork_min = (std::max)( blas_int(1), (std::max)( (3*min_mn + (std::max)(m,n)), 5*min_mn ) ); + blas_int info = 0; - if(info != 0) { return false; } + podarray rwork( static_cast(5*min_mn) ); - blas_int proposed_lwork = static_cast(real(work[0])); + blas_int lwork_proposed = 0; - if(proposed_lwork > lwork) + if(A.n_elem >= 256) { - lwork = proposed_lwork; - work.set_size( static_cast(lwork) ); + eT work_query[2] = {}; + blas_int lwork_query = -1; // query to find optimum size of workspace + + arma_extra_debug_print("lapack::cx_gesvd()"); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); } + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesvd()"); - lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), &info); + lapack::cx_gesvd(&jobu, &jobvt, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), &info); if(info != 0) { return false; } @@ -2898,7 +3471,7 @@ auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_ignore(mode); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; @@ -2908,58 +3481,68 @@ auxlib::svd_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V -template +template inline bool -auxlib::svd_dc(Col& S, const Base& X, uword& X_n_rows, uword& X_n_cols) +auxlib::svd_dc(Col& S, Mat& A) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + if(A.is_empty()) { S.reset(); return true; } - X_n_rows = A.n_rows; - X_n_cols = A.n_cols; - - if(A.is_empty()) - { - S.reset(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); - Mat U(1, 1); - Mat V(1, 1); + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, 1, arma_nozeros_indicator()); char jobz = 'N'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork = 3 * ( 3*min_mn + std::max( std::max(m,n), 7*min_mn ) ); - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = 3*min_mn + (std::max)( max_mn, 7*min_mn ); + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( work_query[0] ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::gesdd()"); - lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, iwork.memptr(), &info); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); return (info == 0); } #else { arma_ignore(S); - arma_ignore(X); - arma_ignore(X_n_rows); - arma_ignore(X_n_cols); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -2968,10 +3551,10 @@ auxlib::svd_dc(Col& S, const Base& X, uword& X_n_rows, uword& X_n_col -template +template inline bool -auxlib::svd_dc(Col& S, const Base, T1>& X, uword& X_n_rows, uword& X_n_cols) +auxlib::svd_dc(Col& S, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -2979,50 +3562,60 @@ auxlib::svd_dc(Col& S, const Base, T1>& X, uword& X_n_rows, u { typedef std::complex eT; - Mat A(X.get_ref()); - - X_n_rows = A.n_rows; - X_n_cols = A.n_cols; + if(A.is_empty()) { S.reset(); return true; } - if(A.is_empty()) - { - S.reset(); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); - Mat U(1, 1); - Mat V(1, 1); + Mat U(1, 1, arma_nozeros_indicator()); + Mat V(1, 1, arma_nozeros_indicator()); char jobz = 'N'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork = 3 * (2*min_mn + std::max(m,n)); - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = 2*min_mn + max_mn; + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); - podarray rwork( static_cast(7*min_mn) ); // LAPACK 3.4.2 docs state 5*min(m,n), while zgesdd() seems to write past the end + podarray rwork( static_cast(7*min_mn) ); // from LAPACK 3.8 docs: LAPACK <= v3.6 needs 7*mn podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesdd()"); - lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); return (info == 0); } #else { arma_ignore(S); - arma_ignore(X); - arma_ignore(X_n_rows); - arma_ignore(X_n_cols); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -3031,50 +3624,18 @@ auxlib::svd_dc(Col& S, const Base, T1>& X, uword& X_n_rows, u -template -inline -bool -auxlib::svd_dc(Col& S, const Base& X) - { - arma_extra_debug_sigprint(); - - uword junk; - return auxlib::svd_dc(S, X, junk, junk); - } - - - -template -inline -bool -auxlib::svd_dc(Col& S, const Base, T1>& X) - { - arma_extra_debug_sigprint(); - - uword junk; - return auxlib::svd_dc(S, X, junk, junk); - } - - - -template +template inline bool -auxlib::svd_dc(Mat& U, Col& S, Mat& V, const Base& X) +auxlib::svd_dc(Mat& U, Col& S, Mat& V, Mat& A) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } - if(A.is_empty()) - { - U.eye(A.n_rows, A.n_rows); - S.reset(); - V.eye(A.n_cols, A.n_cols); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -3083,25 +3644,43 @@ auxlib::svd_dc(Mat& U, Col& S, Mat& V, const Base& X) char jobz = 'A'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int max_mn = (std::max)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork1 = 3*min_mn*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 4*min_mn ); - blas_int lwork2 = 3*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 3*min_mn + max_mn ); - blas_int lwork = 2 * ((std::max)(lwork1, lwork2)); // due to differences between lapack 3.1 and 3.4 - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork1 = 3*min_mn*min_mn + (std::max)(max_mn, 4*min_mn*min_mn + 4*min_mn); // as per LAPACK 3.2 docs + blas_int lwork2 = 4*min_mn*min_mn + 6*min_mn + max_mn; // as per LAPACK 3.8 docs; consistent with LAPACK 3.4 docs + blas_int lwork_min = (std::max)(lwork1, lwork2); // due to differences between LAPACK 3.2 and 3.8 + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::gesdd()"); - lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, iwork.memptr(), &info); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); if(info != 0) { return false; } @@ -3114,7 +3693,7 @@ auxlib::svd_dc(Mat& U, Col& S, Mat& V, const Base& X) arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -3123,10 +3702,10 @@ auxlib::svd_dc(Mat& U, Col& S, Mat& V, const Base& X) -template +template inline bool -auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X) +auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -3134,15 +3713,9 @@ auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, { typedef std::complex eT; - Mat A(X.get_ref()); + if(A.is_empty()) { U.eye(A.n_rows, A.n_rows); S.reset(); V.eye(A.n_cols, A.n_cols); return true; } - if(A.is_empty()) - { - U.eye(A.n_rows, A.n_rows); - S.reset(); - V.eye(A.n_cols, A.n_cols); - return true; - } + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -3151,27 +3724,43 @@ auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, char jobz = 'A'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int max_mn = (std::max)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = blas_int(U.n_rows); - blas_int ldvt = blas_int(V.n_rows); - blas_int lwork = 2 * (min_mn*min_mn + 2*min_mn + max_mn); - blas_int lrwork1 = 5*min_mn*min_mn + 7*min_mn; - blas_int lrwork2 = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); - blas_int lrwork = (std::max)(lrwork1, lrwork2); // due to differences between lapack 3.1 and 3.4 - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = blas_int(U.n_rows); + blas_int ldvt = blas_int(V.n_rows); + blas_int lwork_min = min_mn*min_mn + 2*min_mn + max_mn; // as per LAPACK 3.2, 3.4, 3.8 docs + blas_int lrwork = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); // as per LAPACK 3.4 docs; LAPACK 3.8 uses 5*min_mn+5 instead of 5*min_mn+7 + blas_int info = 0; S.set_size( static_cast(min_mn) ); - podarray work( static_cast(lwork ) ); podarray rwork( static_cast(lrwork ) ); podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesdd()"); - lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); if(info != 0) { return false; } @@ -3184,7 +3773,7 @@ auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -3193,32 +3782,32 @@ auxlib::svd_dc(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, -template +template inline bool -auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, const Base& X) +auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, Mat& A) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_LAPACK) { - Mat A(X.get_ref()); + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); char jobz = 'S'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int max_mn = (std::max)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = m; - blas_int ldvt = min_mn; - blas_int lwork1 = 3*min_mn*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 4*min_mn ); - blas_int lwork2 = 3*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 3*min_mn + max_mn ); - blas_int lwork = 2 * ((std::max)(lwork1, lwork2)); // due to differences between lapack 3.1 and 3.4 - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = m; + blas_int ldvt = min_mn; + blas_int lwork1 = 3*min_mn*min_mn + (std::max)( max_mn, 4*min_mn*min_mn + 4*min_mn ); // as per LAPACK 3.2 docs + blas_int lwork2 = 4*min_mn*min_mn + 6*min_mn + max_mn; // as per LAPACK 3.4 docs; LAPACK 3.8 requires 4*min_mn*min_mn + 7*min_mn + blas_int lwork_min = (std::max)(lwork1, lwork2); // due to differences between LAPACK 3.2 and 3.4 + blas_int info = 0; if(A.is_empty()) { @@ -3234,11 +3823,29 @@ auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, const Base& X) V.set_size( static_cast(min_mn), static_cast(n) ); - podarray work( static_cast(lwork ) ); podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 1024) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::gesdd()"); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast(work_query[0]); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::gesdd()"); - lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, iwork.memptr(), &info); + lapack::gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, iwork.memptr(), &info); if(info != 0) { return false; } @@ -3251,7 +3858,7 @@ auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, const Base& X) arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -3260,10 +3867,10 @@ auxlib::svd_dc_econ(Mat& U, Col& S, Mat& V, const Base& X) -template +template inline bool -auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, const Base< std::complex, T1>& X) +auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex >& V, Mat< std::complex >& A) { arma_extra_debug_sigprint(); @@ -3271,24 +3878,22 @@ auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex > { typedef std::complex eT; - Mat A(X.get_ref()); + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); char jobz = 'S'; - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int min_mn = (std::min)(m,n); - blas_int max_mn = (std::max)(m,n); - blas_int lda = blas_int(A.n_rows); - blas_int ldu = m; - blas_int ldvt = min_mn; - blas_int lwork = 2 * (min_mn*min_mn + 2*min_mn + max_mn); - blas_int lrwork1 = 5*min_mn*min_mn + 7*min_mn; - blas_int lrwork2 = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); - blas_int lrwork = (std::max)(lrwork1, lrwork2); // due to differences between lapack 3.1 and 3.4 - blas_int info = 0; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int max_mn = (std::max)(m,n); + blas_int lda = blas_int(A.n_rows); + blas_int ldu = m; + blas_int ldvt = min_mn; + blas_int lwork_min = min_mn*min_mn + 2*min_mn + max_mn; // as per LAPACK 3.2 docs + blas_int lrwork = min_mn * ((std::max)(5*min_mn+7, 2*max_mn + 2*min_mn+1)); // LAPACK 3.8 uses 5*min_mn+5 instead of 5*min_mn+7 + blas_int info = 0; if(A.is_empty()) { @@ -3304,12 +3909,30 @@ auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex > V.set_size( static_cast(min_mn), static_cast(n) ); - podarray work( static_cast(lwork ) ); podarray rwork( static_cast(lrwork ) ); podarray iwork( static_cast(8*min_mn) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= 256) + { + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::cx_gesdd()"); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + arma_extra_debug_print("lapack::cx_gesdd()"); - lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info); + lapack::cx_gesdd(&jobz, &m, &n, A.memptr(), &lda, S.memptr(), U.memptr(), &ldu, V.memptr(), &ldvt, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); if(info != 0) { return false; } @@ -3322,7 +3945,7 @@ auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex > arma_ignore(U); arma_ignore(S); arma_ignore(V); - arma_ignore(X); + arma_ignore(A); arma_stop_logic_error("svd(): use of LAPACK must be enabled"); return false; } @@ -3331,61 +3954,6 @@ auxlib::svd_dc_econ(Mat< std::complex >& U, Col& S, Mat< std::complex > -//! solve a system of linear equations via explicit inverse (tiny matrices) -template -arma_cold -inline -bool -auxlib::solve_square_tiny(Mat& out, const Mat& A, const Base& B_expr) - { - arma_extra_debug_sigprint(); - - // NOTE: assuming A has a size <= 4x4 - - typedef typename T1::elem_type eT; - - const uword A_n_rows = A.n_rows; - - Mat A_inv(A_n_rows, A_n_rows); - - const bool status = auxlib::inv_tiny(A_inv, A); - - if(status == false) { return false; } - - const quasi_unwrap UB(B_expr.get_ref()); - const Mat& B = UB.M; - - const uword B_n_rows = B.n_rows; - const uword B_n_cols = B.n_cols; - - arma_debug_check( (A_n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); - - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_cols, B_n_cols); - return true; - } - - if(UB.is_alias(out)) - { - Mat tmp(A_n_rows, B_n_cols); - - gemm_emul::apply(tmp, A_inv, B); - - out.steal_mem(tmp); - } - else - { - out.set_size(A_n_rows, B_n_cols); - - gemm_emul::apply(out, A_inv, B); - } - - return true; - } - - - //! solve a system of linear equations via LU decomposition template inline @@ -3394,52 +3962,28 @@ auxlib::solve_square_fast(Mat& out, Mat ipiv(A_n_rows + 2); // +2 for paranoia: old versions of Atlas might be trashing memory - - arma_extra_debug_print("atlas::clapack_gesv()"); - int info = atlas::clapack_gesv(atlas::CblasColMajor, A_n_rows, B_n_cols, A.memptr(), A_n_rows, ipiv.memptr(), out.memptr(), A_n_rows); + typedef typename T1::elem_type eT; - return (info == 0); - } - #elif defined(ARMA_USE_LAPACK) - { arma_debug_assert_blas_size(A); - blas_int n = blas_int(A_n_rows); // assuming A is square - blas_int lda = blas_int(A_n_rows); + blas_int n = blas_int(A.n_rows); // assuming A is square + blas_int lda = blas_int(A.n_rows); blas_int ldb = blas_int(B_n_rows); blas_int nrhs = blas_int(B_n_cols); blas_int info = blas_int(0); - podarray ipiv(A_n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory + podarray ipiv(A.n_rows + 2); // +2 for paranoia: some versions of Lapack might be trashing memory arma_extra_debug_print("lapack::gesv()"); lapack::gesv(&n, &nrhs, A.memptr(), &lda, ipiv.memptr(), out.memptr(), &ldb, &info); @@ -3448,7 +3992,7 @@ auxlib::solve_square_fast(Mat& out, Mat& out, Mat inline bool -auxlib::solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool allow_ugly) +auxlib::solve_square_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) { arma_extra_debug_sigprint(); @@ -3475,14 +4019,10 @@ auxlib::solve_square_rcond(Mat& out, typename T1::pod_ty const uword B_n_rows = out.n_rows; const uword B_n_cols = out.n_cols; - - arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); - - if(A.is_empty() || out.is_empty()) - { - out.zeros(A.n_cols, B_n_cols); - return true; - } + + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); + + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } arma_debug_assert_blas_size(A); @@ -3499,7 +4039,7 @@ auxlib::solve_square_rcond(Mat& out, typename T1::pod_ty podarray ipiv(A.n_rows + 2); // +2 for paranoia arma_extra_debug_print("lapack::lange()"); - norm_val = lapack::lange(&norm_id, &n, &n, A.memptr(), &lda, junk.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &n, &n, A.memptr(), &lda, junk.memptr()); arma_extra_debug_print("lapack::getrf()"); lapack::getrf(&n, &n, A.memptr(), &n, ipiv.memptr(), &info); @@ -3513,8 +4053,6 @@ auxlib::solve_square_rcond(Mat& out, typename T1::pod_ty out_rcond = auxlib::lu_rcond(A, norm_val); - if( (allow_ugly == false) && (out_rcond < auxlib::epsilon_lapack(A)) ) { return false; } - return true; } #else @@ -3523,7 +4061,6 @@ auxlib::solve_square_rcond(Mat& out, typename T1::pod_ty arma_ignore(out_rcond); arma_ignore(A); arma_ignore(B_expr); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -3536,7 +4073,7 @@ auxlib::solve_square_rcond(Mat& out, typename T1::pod_ty template inline bool -auxlib::solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_square_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -3556,13 +4093,9 @@ auxlib::solve_square_refine(Mat& out, typename T1::pod_ty const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } arma_debug_assert_blas_size(A,B); @@ -3580,7 +4113,7 @@ auxlib::solve_square_refine(Mat& out, typename T1::pod_ty blas_int info = blas_int(0); eT rcond = eT(0); - Mat AF(A.n_rows, A.n_rows); + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); podarray IPIV( A.n_rows); podarray R( A.n_rows); @@ -3615,7 +4148,7 @@ auxlib::solve_square_refine(Mat& out, typename T1::pod_ty out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -3624,7 +4157,6 @@ auxlib::solve_square_refine(Mat& out, typename T1::pod_ty arma_ignore(A); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -3637,7 +4169,7 @@ auxlib::solve_square_refine(Mat& out, typename T1::pod_ty template inline bool -auxlib::solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_square_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -3658,13 +4190,9 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } arma_debug_assert_blas_size(A,B); @@ -3682,7 +4210,7 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ blas_int info = blas_int(0); T rcond = T(0); - Mat AF(A.n_rows, A.n_rows); + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); podarray IPIV( A.n_rows); podarray< T> R( A.n_rows); @@ -3717,7 +4245,7 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -3726,7 +4254,6 @@ auxlib::solve_square_refine(Mat< std::complex >& out, typ arma_ignore(A); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -3764,51 +4291,25 @@ auxlib::solve_sympd_fast_common(Mat& out, Mat(atlas::CblasColMajor, atlas::CblasLower, A_n_rows, B_n_cols, A.memptr(), A_n_rows, out.memptr(), B_n_rows); - - return (info == 0); - } - #elif defined(ARMA_USE_LAPACK) + #if defined(ARMA_USE_LAPACK) { typedef typename T1::elem_type eT; arma_debug_assert_blas_size(A, out); char uplo = 'L'; - blas_int n = blas_int(A_n_rows); // assuming A is square + blas_int n = blas_int(A.n_rows); // assuming A is square blas_int nrhs = blas_int(B_n_cols); - blas_int lda = blas_int(A_n_rows); + blas_int lda = blas_int(A.n_rows); blas_int ldb = blas_int(B_n_rows); blas_int info = blas_int(0); @@ -3822,7 +4323,7 @@ auxlib::solve_sympd_fast_common(Mat& out, Mat& out, Mat inline bool -auxlib::solve_sympd_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool allow_ugly) +auxlib::solve_sympd_rcond(Mat& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) { arma_extra_debug_sigprint(); @@ -3843,20 +4344,17 @@ auxlib::solve_sympd_rcond(Mat& out, typename T1::pod_type typedef typename T1::elem_type eT; typedef typename T1::pod_type T; - out_rcond = T(0); + out_sympd_state = false; + out_rcond = T(0); out = B_expr.get_ref(); const uword B_n_rows = out.n_rows; const uword B_n_cols = out.n_cols; - arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); - if(A.is_empty() || out.is_empty()) - { - out.zeros(A.n_cols, B_n_cols); - return true; - } + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } arma_debug_assert_blas_size(A, out); @@ -3870,13 +4368,15 @@ auxlib::solve_sympd_rcond(Mat& out, typename T1::pod_type podarray work(A.n_rows); arma_extra_debug_print("lapack::lansy()"); - norm_val = lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); arma_extra_debug_print("lapack::potrf()"); lapack::potrf(&uplo, &n, A.memptr(), &n, &info); if(info != 0) { return false; } + out_sympd_state = true; + arma_extra_debug_print("lapack::potrs()"); lapack::potrs(&uplo, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); @@ -3884,17 +4384,15 @@ auxlib::solve_sympd_rcond(Mat& out, typename T1::pod_type out_rcond = auxlib::lu_rcond_sympd(A, norm_val); - if( (allow_ugly == false) && (out_rcond < auxlib::epsilon_lapack(A)) ) { return false; } - return true; } #else { arma_ignore(out); + arma_ignore(out_sympd_state); arma_ignore(out_rcond); arma_ignore(A); arma_ignore(B_expr); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -3907,7 +4405,7 @@ auxlib::solve_sympd_rcond(Mat& out, typename T1::pod_type template inline bool -auxlib::solve_sympd_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr, const bool allow_ugly) +auxlib::solve_sympd_rcond(Mat< std::complex >& out, bool& out_sympd_state, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base< std::complex,T1>& B_expr) { arma_extra_debug_sigprint(); @@ -3915,27 +4413,26 @@ auxlib::solve_sympd_rcond(Mat< std::complex >& out, typen { arma_extra_debug_print("auxlib::solve_sympd_rcond(): redirecting to auxlib::solve_square_rcond() due to crippled LAPACK"); - return auxlib::solve_square_rcond(out, out_rcond, A, B_expr, allow_ugly); + out_sympd_state = false; + + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); } #elif defined(ARMA_USE_LAPACK) { typedef typename T1::pod_type T; typedef typename std::complex eT; - out_rcond = T(0); + out_sympd_state = false; + out_rcond = T(0); out = B_expr.get_ref(); const uword B_n_rows = out.n_rows; const uword B_n_cols = out.n_cols; - arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); - if(A.is_empty() || out.is_empty()) - { - out.zeros(A.n_cols, B_n_cols); - return true; - } + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } arma_debug_assert_blas_size(A, out); @@ -3949,13 +4446,15 @@ auxlib::solve_sympd_rcond(Mat< std::complex >& out, typen podarray work(A.n_rows); arma_extra_debug_print("lapack::lanhe()"); - norm_val = lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &n, work.memptr()); arma_extra_debug_print("lapack::potrf()"); lapack::potrf(&uplo, &n, A.memptr(), &n, &info); if(info != 0) { return false; } + out_sympd_state = true; + arma_extra_debug_print("lapack::potrs()"); lapack::potrs(&uplo, &n, &nrhs, A.memptr(), &n, out.memptr(), &n, &info); @@ -3963,17 +4462,15 @@ auxlib::solve_sympd_rcond(Mat< std::complex >& out, typen out_rcond = auxlib::lu_rcond_sympd(A, norm_val); - if( (allow_ugly == false) && (out_rcond < auxlib::epsilon_lapack(A)) ) { return false; } - return true; } #else { arma_ignore(out); + arma_ignore(out_sympd_state); arma_ignore(out_rcond); arma_ignore(A); arma_ignore(B_expr); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -3986,7 +4483,7 @@ auxlib::solve_sympd_rcond(Mat< std::complex >& out, typen template inline bool -auxlib::solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_sympd_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -4006,13 +4503,9 @@ auxlib::solve_sympd_refine(Mat& out, typename T1::pod_typ const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } arma_debug_assert_blas_size(A,B); @@ -4030,7 +4523,7 @@ auxlib::solve_sympd_refine(Mat& out, typename T1::pod_typ blas_int info = blas_int(0); eT rcond = eT(0); - Mat AF(A.n_rows, A.n_rows); + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); podarray S( A.n_rows); podarray FERR( B.n_cols); @@ -4044,9 +4537,10 @@ auxlib::solve_sympd_refine(Mat& out, typename T1::pod_typ // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + // NOTE: lapack::posvx() sets rcond to zero if A is not sympd out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -4055,7 +4549,6 @@ auxlib::solve_sympd_refine(Mat& out, typename T1::pod_typ arma_ignore(A); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -4068,7 +4561,7 @@ auxlib::solve_sympd_refine(Mat& out, typename T1::pod_typ template inline bool -auxlib::solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_sympd_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const Base,T1>& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -4076,7 +4569,7 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type { arma_extra_debug_print("auxlib::solve_sympd_refine(): redirecting to auxlib::solve_square_refine() due to crippled LAPACK"); - return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate, allow_ugly); + return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate); } #elif defined(ARMA_USE_LAPACK) { @@ -4095,13 +4588,9 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type const Mat& B = (use_copy) ? B_tmp : UB_M_as_Mat; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } arma_debug_assert_blas_size(A,B); @@ -4119,7 +4608,7 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type blas_int info = blas_int(0); T rcond = T(0); - Mat AF(A.n_rows, A.n_rows); + Mat AF(A.n_rows, A.n_rows, arma_nozeros_indicator()); podarray< T> S( A.n_rows); podarray< T> FERR( B.n_cols); @@ -4133,9 +4622,10 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type // NOTE: using const_cast(B.memptr()) to allow B to be overwritten for equilibration; // NOTE: B is created as a copy of B_expr if equilibration is enabled; otherwise B is a reference to B_expr + // NOTE: lapack::cx_posvx() sets rcond to zero if A is not sympd out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -4144,7 +4634,6 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type arma_ignore(A); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -4157,7 +4646,7 @@ auxlib::solve_sympd_refine(Mat< std::complex >& out, type template inline bool -auxlib::solve_approx_fast(Mat& out, Mat& A, const Base& B_expr) +auxlib::solve_rect_fast(Mat& out, Mat& A, const Base& B_expr) { arma_extra_debug_sigprint(); @@ -4168,17 +4657,107 @@ auxlib::solve_approx_fast(Mat& out, Mat U(B_expr.get_ref()); const Mat& B = U.M; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + arma_debug_assert_blas_size(A,B); + + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); + + if(arma::size(tmp) == arma::size(B)) { - out.zeros(A.n_cols, B.n_cols); - return true; + tmp = B; + } + else + { + tmp.zeros(); + tmp(0,0, arma::size(B)) = B; + } + + char trans = 'N'; + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + blas_int nrhs = blas_int(B.n_cols); + blas_int min_mn = (std::min)(m,n); + blas_int lwork_min = (std::max)(blas_int(1), min_mn + (std::max)(min_mn, nrhs)); + blas_int info = 0; + + blas_int lwork_proposed = 0; + + if(A.n_elem >= ((is_cx::yes) ? uword(256) : uword(1024))) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, &work_query[0], &lwork_query, &info ); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork_final, &info ); + + if(info != 0) { return false; } + + if(tmp.n_rows == A.n_cols) + { + out.steal_mem(tmp); } + else + { + out = tmp.head_rows(A.n_cols); + } + + return true; + } + #else + { + arma_ignore(out); + arma_ignore(A); + arma_ignore(B_expr); + arma_stop_logic_error("solve(): use of LAPACK must be enabled"); + return false; + } + #endif + } + + + +//! solve a non-square full-rank system via QR or LQ decomposition with rcond estimate (experimental) +template +inline +bool +auxlib::solve_rect_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_LAPACK) + { + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out_rcond = T(0); + + const unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); + + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } arma_debug_assert_blas_size(A,B); - Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols ); + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); if(arma::size(tmp) == arma::size(B)) { @@ -4190,23 +4769,82 @@ auxlib::solve_approx_fast(Mat& out, Mat work( static_cast(lwork) ); + blas_int lwork_proposed = 0; + + if(A.n_elem >= ((is_cx::yes) ? uword(256) : uword(1024))) + { + eT work_query[2] = {}; + blas_int lwork_query = -1; + + arma_extra_debug_print("lapack::gels()"); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, &work_query[0], &lwork_query, &info ); + + if(info != 0) { return false; } + + lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + } + + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); + + podarray work( static_cast(lwork_final) ); arma_extra_debug_print("lapack::gels()"); - lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork, &info ); + lapack::gels( &trans, &m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, work.memptr(), &lwork_final, &info ); if(info != 0) { return false; } + if(A.n_rows >= A.n_cols) + { + arma_extra_debug_print("estimating rcond via R"); + + // xGELS docs: for M >= N, A contains details of its QR decomposition as returned by xGEQRF + // xGEQRF docs: elements on and above the diagonal contain the min(M,N)-by-N upper trapezoidal matrix R + + Mat R(A.n_cols, A.n_cols, arma_zeros_indicator()); + + for(uword col=0; col < A.n_cols; ++col) + { + for(uword row=0; row <= col; ++row) + { + R.at(row,col) = A.at(row,col); + } + } + + // determine quality of solution + out_rcond = auxlib::rcond_trimat(R, 0); // 0: upper triangular; 1: lower triangular + } + else + if(A.n_rows < A.n_cols) + { + arma_extra_debug_print("estimating rcond via L"); + + // xGELS docs: for M < N, A contains details of its LQ decomposition as returned by xGELQF + // xGELQF docs: elements on and below the diagonal contain the M-by-min(M,N) lower trapezoidal matrix L + + Mat L(A.n_rows, A.n_rows, arma_zeros_indicator()); + + for(uword col=0; col < A.n_rows; ++col) + { + for(uword row=col; row < A.n_rows; ++row) + { + L.at(row,col) = A.at(row,col); + } + } + + // determine quality of solution + out_rcond = auxlib::rcond_trimat(L, 1); // 0: upper triangular; 1: lower triangular + } + if(tmp.n_rows == A.n_cols) { out.steal_mem(tmp); @@ -4221,6 +4859,7 @@ auxlib::solve_approx_fast(Mat& out, Mat& out, Mat U(B_expr.get_ref()); const Mat& B = U.M; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_cols, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A,B); - Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols ); + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); if(arma::size(tmp) == arma::size(B)) { @@ -4267,18 +4905,18 @@ auxlib::solve_approx_svd(Mat& out, Mat::epsilon(); + blas_int rank = blas_int(0); + blas_int info = blas_int(0); - podarray S(min_mn); + podarray S( static_cast(min_mn) ); // NOTE: with LAPACK 3.8, can use the workspace query to also obtain liwork, // NOTE: which makes the call to lapack::laenv() redundant @@ -4301,13 +4939,15 @@ auxlib::solve_approx_svd(Mat& out, Mat iwork( static_cast(liwork) ); - eT work_query[2]; - blas_int lwork_query = blas_int(-1); + blas_int lwork_min = blas_int(12)*min_mn + blas_int(2)*min_mn*smlsiz + blas_int(8)*min_mn*nlvl + min_mn*nrhs + smlsiz_p1*smlsiz_p1; + + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); arma_extra_debug_print("lapack::gelsd()"); lapack::gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, &work_query[0], &lwork_query, iwork.memptr(), &info); @@ -4316,12 +4956,13 @@ auxlib::solve_approx_svd(Mat& out, Mat( access::tmp_real(work_query[0]) ); + blas_int lwork_proposed = static_cast( access::tmp_real(work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - podarray work( static_cast(lwork) ); + podarray work( static_cast(lwork_final) ); arma_extra_debug_print("lapack::gelsd()"); - lapack::gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork, iwork.memptr(), &info); + lapack::gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork_final, iwork.memptr(), &info); if(info != 0) { return false; } @@ -4364,17 +5005,16 @@ auxlib::solve_approx_svd(Mat< std::complex >& out, Mat< s const unwrap U(B_expr.get_ref()); const Mat& B = U.M; - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_cols, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_cols, B.n_cols); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A,B); - Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols ); + Mat tmp( (std::max)(A.n_rows, A.n_cols), B.n_cols, arma_nozeros_indicator() ); if(arma::size(tmp) == arma::size(B)) { @@ -4386,18 +5026,18 @@ auxlib::solve_approx_svd(Mat< std::complex >& out, Mat< s tmp(0,0, arma::size(B)) = B; } - blas_int m = blas_int(A.n_rows); - blas_int n = blas_int(A.n_cols); - blas_int nrhs = blas_int(B.n_cols); - blas_int lda = blas_int(A.n_rows); - blas_int ldb = blas_int(tmp.n_rows); - T rcond = T(-1); // -1 means "use machine precision" - blas_int rank = blas_int(0); - blas_int info = blas_int(0); - - const uword min_mn = (std::min)(A.n_rows, A.n_cols); + blas_int m = blas_int(A.n_rows); + blas_int n = blas_int(A.n_cols); + blas_int min_mn = (std::min)(m, n); + blas_int nrhs = blas_int(B.n_cols); + blas_int lda = blas_int(A.n_rows); + blas_int ldb = blas_int(tmp.n_rows); + //T rcond = T(-1); // -1 means "use machine precision" + T rcond = (std::max)(A.n_rows, A.n_cols) * std::numeric_limits::epsilon(); + blas_int rank = blas_int(0); + blas_int info = blas_int(0); - podarray S(min_mn); + podarray S( static_cast(min_mn) ); blas_int ispec = blas_int(9); @@ -4414,10 +5054,10 @@ auxlib::solve_approx_svd(Mat< std::complex >& out, Mat< s blas_int laenv_result = (arma_config::hidden_args) ? blas_int(lapack::laenv(&ispec, name, opts, &n1, &n2, &n3, &n4, 6, 1)) : blas_int(0); - blas_int smlsiz = (std::max)( blas_int(25), laenv_result ); + blas_int smlsiz = (std::max)( blas_int(25), laenv_result ); blas_int smlsiz_p1 = blas_int(1) + smlsiz; - blas_int nlvl = (std::max)( blas_int(0), blas_int(1) + blas_int( std::log(double(min_mn) / double(smlsiz_p1))/double(0.69314718055994530942) ) ); + blas_int nlvl = (std::max)( blas_int(0), blas_int(1) + blas_int( std::log2( double(min_mn)/double(smlsiz_p1) ) ) ); blas_int lrwork = (m >= n) ? blas_int(10)*n + blas_int(2)*n*smlsiz + blas_int(8)*n*nlvl + blas_int(3)*smlsiz*nrhs + (std::max)( (smlsiz_p1)*(smlsiz_p1), n*(blas_int(1)+nrhs) + blas_int(2)*nrhs ) @@ -4428,20 +5068,23 @@ auxlib::solve_approx_svd(Mat< std::complex >& out, Mat< s podarray rwork( static_cast(lrwork) ); podarray iwork( static_cast(liwork) ); - eT work_query[2]; - blas_int lwork_query = blas_int(-1); + blas_int lwork_min = 2*min_mn + min_mn*nrhs; + + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); arma_extra_debug_print("lapack::cx_gelsd()"); lapack::cx_gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, &work_query[0], &lwork_query, rwork.memptr(), iwork.memptr(), &info); if(info != 0) { return false; } - blas_int lwork = static_cast( access::tmp_real( work_query[0]) ); + blas_int lwork_proposed = static_cast( access::tmp_real( work_query[0]) ); + blas_int lwork_final = (std::max)(lwork_proposed, lwork_min); - podarray work( static_cast(lwork) ); + podarray work( static_cast(lwork_final) ); arma_extra_debug_print("lapack::cx_gelsd()"); - lapack::cx_gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork, rwork.memptr(), iwork.memptr(), &info); + lapack::cx_gelsd(&m, &n, &nrhs, A.memptr(), &lda, tmp.memptr(), &ldb, S.memptr(), &rcond, &rank, work.memptr(), &lwork_final, rwork.memptr(), iwork.memptr(), &info); if(info != 0) { return false; } @@ -4483,13 +5126,9 @@ auxlib::solve_trimat_fast(Mat& out, const Mat& out, const Mat inline bool -auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout, const bool allow_ugly) +auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const Base& B_expr, const uword layout) { arma_extra_debug_sigprint(); @@ -4537,13 +5176,9 @@ auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_ty const uword B_n_rows = out.n_rows; const uword B_n_cols = out.n_cols; - arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); - if(A.is_empty() || out.is_empty()) - { - out.zeros(A.n_cols, B_n_cols); - return true; - } + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_cols, B_n_cols); return true; } arma_debug_assert_blas_size(A,out); @@ -4562,8 +5197,6 @@ auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_ty // determine quality of solution out_rcond = auxlib::rcond_trimat(A, layout); - if( (allow_ugly == false) && (out_rcond < auxlib::epsilon_lapack(A)) ) { return false; } - return true; } #else @@ -4573,7 +5206,6 @@ auxlib::solve_trimat_rcond(Mat& out, typename T1::pod_ty arma_ignore(A); arma_ignore(B_expr); arma_ignore(layout); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -4638,13 +5270,9 @@ auxlib::solve_band_fast_common(Mat& out, const Mat& out, const Mat inline bool -auxlib::solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool allow_ugly) +auxlib::solve_band_rcond(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr) { arma_extra_debug_sigprint(); - return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr, allow_ugly); + return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr); } @@ -4704,7 +5332,7 @@ auxlib::solve_band_rcond(Mat& out, typename T1::pod_type& template inline bool -auxlib::solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr, const bool allow_ugly) +auxlib::solve_band_rcond(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base< std::complex,T1>& B_expr) { arma_extra_debug_sigprint(); @@ -4715,11 +5343,11 @@ auxlib::solve_band_rcond(Mat< std::complex >& out, typena arma_ignore(KL); arma_ignore(KU); - return auxlib::solve_square_rcond(out, out_rcond, A, B_expr, allow_ugly); + return auxlib::solve_square_rcond(out, out_rcond, A, B_expr); } #else { - return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr, allow_ugly); + return auxlib::solve_band_rcond_common(out, out_rcond, A, KL, KU, B_expr); } #endif } @@ -4730,7 +5358,7 @@ auxlib::solve_band_rcond(Mat< std::complex >& out, typena template inline bool -auxlib::solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool allow_ugly) +auxlib::solve_band_rcond_common(Mat& out, typename T1::pod_type& out_rcond, const Mat& A, const uword KL, const uword KU, const Base& B_expr) { arma_extra_debug_sigprint(); @@ -4746,13 +5374,9 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p const uword B_n_rows = out.n_rows; const uword B_n_cols = out.n_cols; - arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B_n_rows), "solve(): number of rows in given matrices must be the same", [&](){ out.soft_reset(); } ); - if(A.is_empty() || out.is_empty()) - { - out.zeros(A.n_rows, B_n_cols); - return true; - } + if(A.is_empty() || out.is_empty()) { out.zeros(A.n_rows, B_n_cols); return true; } // for gbtrf, matrix AB size: 2*KL+KU+1 x N; band representation of A stored in rows KL+1 to 2*KL+KU+1 (note: fortran counts from 1) @@ -4763,7 +5387,7 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p arma_debug_assert_blas_size(AB,out); - char norm_id = '1'; + //char norm_id = '1'; char trans = 'N'; blas_int n = blas_int(N); // assuming square matrix blas_int kl = blas_int(KL); @@ -4774,11 +5398,14 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p blas_int info = blas_int(0); T norm_val = T(0); - podarray junk(1); + //podarray junk(1); podarray ipiv(N + 2); // +2 for paranoia - arma_extra_debug_print("lapack::langb()"); - norm_val = lapack::langb(&norm_id, &n, &kl, &ku, AB.memptr(), &ldab, junk.memptr()); + // // NOTE: lapack::langb() and lapack::gbtrf() use incompatible storage formats for the band matrix + // arma_extra_debug_print("lapack::langb()"); + // norm_val = lapack::langb(&norm_id, &n, &kl, &ku, AB.memptr(), &ldab, junk.memptr()); + + norm_val = auxlib::norm1_band(A,KL,KU); arma_extra_debug_print("lapack::gbtrf()"); lapack::gbtrf(&n, &n, &kl, &ku, AB.memptr(), &ldab, ipiv.memptr(), &info); @@ -4792,8 +5419,6 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p out_rcond = auxlib::lu_rcond_band(AB, KL, KU, ipiv, norm_val); - if( (allow_ugly == false) && (out_rcond < auxlib::epsilon_lapack(AB)) ) { return false; } - return true; } #else @@ -4804,7 +5429,6 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p arma_ignore(KL); arma_ignore(KU); arma_ignore(B_expr); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -4817,7 +5441,7 @@ auxlib::solve_band_rcond_common(Mat& out, typename T1::p template inline bool -auxlib::solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_band_refine(Mat& out, typename T1::pod_type& out_rcond, Mat& A, const uword KL, const uword KU, const Base& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -4827,13 +5451,9 @@ auxlib::solve_band_refine(Mat& out, typename T1::pod_type Mat B = B_expr.get_ref(); // B is overwritten - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } // for gbsvx, matrix AB size: KL+KU+1 x N; band representation of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1) @@ -4846,7 +5466,7 @@ auxlib::solve_band_refine(Mat& out, typename T1::pod_type out.set_size(N, B.n_cols); - Mat AFB(2*KL+KU+1, N); + Mat AFB(2*KL+KU+1, N, arma_nozeros_indicator()); char fact = (equilibrate) ? 'E' : 'N'; char trans = 'N'; @@ -4892,7 +5512,7 @@ auxlib::solve_band_refine(Mat& out, typename T1::pod_type out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -4903,7 +5523,6 @@ auxlib::solve_band_refine(Mat& out, typename T1::pod_type arma_ignore(KU); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -4916,7 +5535,7 @@ auxlib::solve_band_refine(Mat& out, typename T1::pod_type template inline bool -auxlib::solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate, const bool allow_ugly) +auxlib::solve_band_refine(Mat< std::complex >& out, typename T1::pod_type& out_rcond, Mat< std::complex >& A, const uword KL, const uword KU, const Base,T1>& B_expr, const bool equilibrate) { arma_extra_debug_sigprint(); @@ -4927,7 +5546,7 @@ auxlib::solve_band_refine(Mat< std::complex >& out, typen arma_ignore(KL); arma_ignore(KU); - return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate, allow_ugly); + return auxlib::solve_square_refine(out, out_rcond, A, B_expr, equilibrate); } #elif defined(ARMA_USE_LAPACK) { @@ -4936,13 +5555,9 @@ auxlib::solve_band_refine(Mat< std::complex >& out, typen Mat B = B_expr.get_ref(); // B is overwritten - arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in the given matrices must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "solve(): number of rows in given matrices must be the same" ); - if(A.is_empty() || B.is_empty()) - { - out.zeros(A.n_rows, B.n_cols); - return true; - } + if(A.is_empty() || B.is_empty()) { out.zeros(A.n_rows, B.n_cols); return true; } // for gbsvx, matrix AB size: KL+KU+1 x N; band representation of A stored in rows 1 to KL+KU+1 (note: fortran counts from 1) @@ -4955,7 +5570,7 @@ auxlib::solve_band_refine(Mat< std::complex >& out, typen out.set_size(N, B.n_cols); - Mat AFB(2*KL+KU+1, N); + Mat AFB(2*KL+KU+1, N, arma_nozeros_indicator()); char fact = (equilibrate) ? 'E' : 'N'; char trans = 'N'; @@ -5001,7 +5616,7 @@ auxlib::solve_band_refine(Mat< std::complex >& out, typen out_rcond = rcond; - return (allow_ugly) ? ((info == 0) || (info == (n+1))) : (info == 0); + return ((info == 0) || (info == (n+1))); } #else { @@ -5012,7 +5627,6 @@ auxlib::solve_band_refine(Mat< std::complex >& out, typen arma_ignore(KU); arma_ignore(B_expr); arma_ignore(equilibrate); - arma_ignore(allow_ugly); arma_stop_logic_error("solve(): use of LAPACK must be enabled"); return false; } @@ -5074,13 +5688,9 @@ auxlib::solve_tridiag_fast_common(Mat& out, const Mat tridiag; band_helper::extract_tridiag(tridiag, A); @@ -5126,12 +5736,7 @@ auxlib::schur(Mat& U, Mat& S, const Base& X, const bool calc_U) arma_debug_check( (S.is_square() == false), "schur(): given matrix must be square sized" ); - if(S.is_empty()) - { - U.reset(); - S.reset(); - return true; - } + if(S.is_empty()) { U.reset(); S.reset(); return true; } arma_debug_assert_blas_size(S); @@ -5145,7 +5750,7 @@ auxlib::schur(Mat& U, Mat& S, const Base& X, const bool calc_U) blas_int n = blas_int(S_n_rows); blas_int sdim = 0; blas_int ldvs = calc_U ? n : blas_int(1); - blas_int lwork = 3 * ((std::max)(blas_int(1), 3*n)); + blas_int lwork = 64*n; // lwork_min = (std::max)(blas_int(1), 3*n) blas_int info = 0; podarray wr(S_n_rows); @@ -5176,7 +5781,7 @@ auxlib::schur(Mat& U, Mat& S, const Base& X, const bool calc_U) template inline bool -auxlib::schur(Mat >& U, Mat >& S, const Base,T1>& X, const bool calc_U) +auxlib::schur(Mat< std::complex >& U, Mat< std::complex >& S, const Base,T1>& X, const bool calc_U) { arma_extra_debug_sigprint(); @@ -5192,7 +5797,7 @@ auxlib::schur(Mat >& U, Mat >& S, const Base inline bool -auxlib::schur(Mat >& U, Mat >& S, const bool calc_U) +auxlib::schur(Mat< std::complex >& U, Mat< std::complex >& S, const bool calc_U) { arma_extra_debug_sigprint(); @@ -5200,12 +5805,7 @@ auxlib::schur(Mat >& U, Mat >& S, const bool cal { typedef std::complex eT; - if(S.is_empty()) - { - U.reset(); - S.reset(); - return true; - } + if(S.is_empty()) { U.reset(); S.reset(); return true; } arma_debug_assert_blas_size(S); @@ -5219,7 +5819,7 @@ auxlib::schur(Mat >& U, Mat >& S, const bool cal blas_int n = blas_int(S_n_rows); blas_int sdim = 0; blas_int ldvs = calc_U ? n : blas_int(1); - blas_int lwork = 3 * ((std::max)(blas_int(1), 2*n)); + blas_int lwork = 64*n; // lwork_min = (std::max)(blas_int(1), 2*n) blas_int info = 0; podarray w(S_n_rows); @@ -5246,7 +5846,7 @@ auxlib::schur(Mat >& U, Mat >& S, const bool cal // -// syl (solution of the Sylvester equation AX + XB = C) +// solve the Sylvester equation AX + XB = C template inline @@ -5261,21 +5861,14 @@ auxlib::syl(Mat& X, const Mat& A, const Mat& B, const Mat& C) arma_debug_check( (C.n_rows != A.n_rows) || (C.n_cols != B.n_cols), "syl(): matrices are not conformant" ); - if(A.is_empty() || B.is_empty() || C.is_empty()) - { - X.reset(); - return true; - } - + if(A.is_empty() || B.is_empty() || C.is_empty()) { X.reset(); return true; } + Mat Z1, Z2, T1, T2; const bool status_sd1 = auxlib::schur(Z1, T1, A); const bool status_sd2 = auxlib::schur(Z2, T2, B); - if( (status_sd1 == false) || (status_sd2 == false) ) - { - return false; - } + if( (status_sd1 == false) || (status_sd2 == false) ) { return false; } char trana = 'N'; char tranb = 'N'; @@ -5329,18 +5922,14 @@ auxlib::qz(Mat& A, Mat& B, Mat& vsl, Mat& vsr, const Base& X_e A = X_expr.get_ref(); B = Y_expr.get_ref(); - arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized" ); + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized", [&](){ A.soft_reset(); B.soft_reset(); } ); arma_debug_check( (A.n_rows != B.n_rows), "qz(): given matrices must have the same size" ); - if(A.is_empty()) - { - A.reset(); - B.reset(); - vsl.reset(); - vsr.reset(); - return true; - } + if(A.is_empty()) { A.reset(); B.reset(); vsl.reset(); vsr.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -5353,7 +5942,7 @@ auxlib::qz(Mat& A, Mat& B, Mat& vsl, Mat& vsr, const Base& X_e void* selctg = 0; blas_int N = blas_int(A.n_rows); blas_int sdim = 0; - blas_int lwork = 3 * ((std::max)(blas_int(1),8*N+16)); + blas_int lwork = 64*N+16; // lwork_min = (std::max)(blas_int(1),8*N+16) blas_int info = 0; if(mode == 'l') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::select_lhp)); } @@ -5420,18 +6009,14 @@ auxlib::qz(Mat< std::complex >& A, Mat< std::complex >& B, Mat< std::compl A = X_expr.get_ref(); B = Y_expr.get_ref(); - arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized" ); + arma_debug_check( ((A.is_square() == false) || (B.is_square() == false)), "qz(): given matrices must be square sized", [&](){ A.soft_reset(); B.soft_reset(); } ); arma_debug_check( (A.n_rows != B.n_rows), "qz(): given matrices must have the same size" ); - if(A.is_empty()) - { - A.reset(); - B.reset(); - vsl.reset(); - vsr.reset(); - return true; - } + if(A.is_empty()) { A.reset(); B.reset(); vsl.reset(); vsr.reset(); return true; } + + if(arma_config::check_nonfinite && A.internal_has_nonfinite()) { return false; } + if(arma_config::check_nonfinite && B.internal_has_nonfinite()) { return false; } arma_debug_assert_blas_size(A); @@ -5444,7 +6029,7 @@ auxlib::qz(Mat< std::complex >& A, Mat< std::complex >& B, Mat< std::compl void* selctg = 0; blas_int N = blas_int(A.n_rows); blas_int sdim = 0; - blas_int lwork = 3 * ((std::max)(blas_int(1),2*N)); + blas_int lwork = 64*N; // lwork_min = (std::max)(blas_int(1),2*N) blas_int info = 0; if(mode == 'l') { eigsort = 'S'; selctg = qz_helper::ptr_cast(&(qz_helper::cx_select_lhp)); } @@ -5516,7 +6101,7 @@ auxlib::rcond(Mat& A) podarray ipiv( (std::min)(A.n_rows, A.n_cols) ); arma_extra_debug_print("lapack::lange()"); - norm_val = lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, work.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, work.memptr()); arma_extra_debug_print("lapack::getrf()"); lapack::getrf(&m, &n, A.memptr(), &lda, ipiv.memptr(), &info); @@ -5566,7 +6151,7 @@ auxlib::rcond(Mat< std::complex >& A) podarray ipiv( (std::min)(A.n_rows, A.n_cols) ); arma_extra_debug_print("lapack::lange()"); - norm_val = lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, junk.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_gen(A) : lapack::lange(&norm_id, &m, &n, A.memptr(), &lda, junk.memptr()); arma_extra_debug_print("lapack::getrf()"); lapack::getrf(&m, &n, A.memptr(), &lda, ipiv.memptr(), &info); @@ -5614,7 +6199,7 @@ auxlib::rcond_sympd(Mat& A, bool& calc_ok) podarray iwork( A.n_rows); arma_extra_debug_print("lapack::lansy()"); - norm_val = lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &lda, work.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lansy(&norm_id, &uplo, &n, A.memptr(), &lda, work.memptr()); arma_extra_debug_print("lapack::potrf()"); lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); @@ -5675,7 +6260,7 @@ auxlib::rcond_sympd(Mat< std::complex >& A, bool& calc_ok) podarray< T> rwork( A.n_rows); arma_extra_debug_print("lapack::lanhe()"); - norm_val = lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, rwork.memptr()); + norm_val = (has_blas_float_bug::value) ? auxlib::norm1_sym(A) : lapack::lanhe(&norm_id, &uplo, &n, A.memptr(), &lda, rwork.memptr()); arma_extra_debug_print("lapack::potrf()"); lapack::potrf(&uplo, &n, A.memptr(), &lda, &info); @@ -6038,58 +6623,6 @@ auxlib::crippled_lapack(const Base&) -template -inline -typename T1::pod_type -auxlib::epsilon_lapack(const Base&) - { - typedef typename T1::pod_type T; - - return T(0.5)*std::numeric_limits::epsilon(); - - // value reverse engineered from dgesvx.f and dlamch.f - // http://www.netlib.org/lapack/explore-html/da/d21/dgesvx_8f.html - // http://www.netlib.org/lapack/explore-html/d5/dd4/dlamch_8f.html - // - // Fortran epsilon(X) function: - // https://gcc.gnu.org/onlinedocs/gfortran/EPSILON.html - // "EPSILON(X) returns the smallest number E of the same kind as X such that 1 + E > 1" - // - // C++ std::numeric_limits::epsilon() function: - // https://en.cppreference.com/w/cpp/types/numeric_limits/epsilon - // "the difference between 1.0 and the next value representable by the floating-point type T" - // - // extract from dgesvx.f: - // - // IF( rcond.LT.dlamch( 'Epsilon' ) ) - // info = n + 1 - // RETURN - // - // extract from dlamch.f: - // - // * rnd = 1.0 when rounding occurs in addition, 0.0 otherwise - // ... - // * Assume rounding, not chopping. Always - // - // rnd = one - // - // IF( one.EQ.rnd ) THEN - // eps = epsilon(zero) * 0.5 - // ELSE - // eps = epsilon(zero) - // END IF - // ... - // IF( lsame( cmach, 'E' ) ) THEN - // rmach = eps - // ... - // END IF - // ... - // dlamch = rmach - // RETURN - } - - - template inline bool @@ -6149,6 +6682,11 @@ auxlib::rudimentary_sym_check(const Mat< std::complex >& X) const eT* X_mem = X.memptr(); + const T tol = T(10000)*std::numeric_limits::epsilon(); // allow some leeway + + if(std::abs(X_mem[0 ].imag()) > tol) { return false; } // check top-left + if(std::abs(X_mem[X.n_elem-1].imag()) > tol) { return false; } // check bottom-right + const eT& A = X_mem[Nm1 ]; // bottom-left corner (ie. last value in first column) const eT& B = X_mem[Nm1*N]; // top-right corner (ie. first value in last column) @@ -6158,8 +6696,6 @@ auxlib::rudimentary_sym_check(const Mat< std::complex >& X) const T delta_real = std::abs(A.real() - B.real()); const T delta_imag = std::abs(A.imag() + B.imag()); // take into account the conjugate - const T tol = T(10000)*std::numeric_limits::epsilon(); // allow some leeway - const bool okay_real = ( (delta_real <= tol) || (delta_real <= (C_real * tol)) ); const bool okay_imag = ( (delta_imag <= tol) || (delta_imag <= (C_imag * tol)) ); @@ -6168,6 +6704,105 @@ auxlib::rudimentary_sym_check(const Mat< std::complex >& X) +template +inline +typename get_pod_type::result +auxlib::norm1_gen(const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + T max_val = T(0); + + for(uword c=0; c < n_cols; ++c) + { + const eT* colmem = A.colptr(c); + T acc_val = T(0); + + for(uword r=0; r < n_rows; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + +template +inline +typename get_pod_type::result +auxlib::norm1_sym(const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword N = (std::min)(A.n_rows, A.n_cols); + + T max_val = T(0); + + for(uword col=0; col < N; ++col) + { + const eT* colmem = A.colptr(col); + T acc_val = T(0); + + for(uword c=0; c < col; ++c) { acc_val += std::abs(A.at(col,c)); } + + for(uword r=col; r < N; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + +template +inline +typename get_pod_type::result +auxlib::norm1_band(const Mat& A, const uword KL, const uword KU) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.n_elem == 0) { return T(0); } + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + T max_val = T(0); + + for(uword c=0; c < n_cols; ++c) + { + const eT* colmem = A.colptr(c); + T acc_val = T(0); + + // use values only from main diagonal + KU upper diagonals + KL lower diagonals + + const uword start = ( c > KU ) ? (c - KU) : 0; + const uword end = ((c + KL) < n_rows) ? (c + KL) : (n_rows-1); + + for(uword r=start; r <= end; ++r) { acc_val += std::abs(colmem[r]); } + + max_val = (acc_val > max_val) ? acc_val : max_val; + } + + return max_val; + } + + + // diff --git a/src/armadillo_bits/band_helper.hpp b/src/armadillo_bits/band_helper.hpp index 6cdba739..4493c704 100644 --- a/src/armadillo_bits/band_helper.hpp +++ b/src/armadillo_bits/band_helper.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -37,7 +39,7 @@ is_band(uword& out_KL, uword& out_KU, const Mat& A, const uword N_min) if(N < N_min) { return false; } - // first, quickly check bottom-right and top-left corners + // first, quickly check bottom-left and top-right corners const eT eT_zero = eT(0); @@ -115,7 +117,7 @@ is_band_lower(uword& out_KD, const Mat& A, const uword N_min) if(N < N_min) { return false; } - // first, quickly check bottom-right corner + // first, quickly check bottom-left corner const eT eT_zero = eT(0); @@ -178,7 +180,7 @@ is_band_upper(uword& out_KD, const Mat& A, const uword N_min) if(N < N_min) { return false; } - // first, quickly check top-left corner + // first, quickly check top-right corner const eT eT_zero = eT(0); diff --git a/src/armadillo_bits/compiler_check.hpp b/src/armadillo_bits/compiler_check.hpp new file mode 100644 index 00000000..8a653d24 --- /dev/null +++ b/src/armadillo_bits/compiler_check.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#undef ARMA_HAVE_CXX11 +#undef ARMA_HAVE_CXX14 +#undef ARMA_HAVE_CXX17 +#undef ARMA_HAVE_CXX20 + +#if (__cplusplus >= 201103L) + #define ARMA_HAVE_CXX11 +#endif + +#if (__cplusplus >= 201402L) + #define ARMA_HAVE_CXX14 +#endif + +#if (__cplusplus >= 201703L) + #define ARMA_HAVE_CXX17 +#endif + +#if (__cplusplus >= 202002L) + #define ARMA_HAVE_CXX20 +#endif + + +// MS really can't get its proverbial shit together +#if defined(_MSVC_LANG) + + #if (_MSVC_LANG >= 201402L) + #undef ARMA_HAVE_CXX11 + #define ARMA_HAVE_CXX11 + + #undef ARMA_HAVE_CXX14 + #define ARMA_HAVE_CXX14 + #endif + + #if (_MSVC_LANG >= 201703L) + #undef ARMA_HAVE_CXX17 + #define ARMA_HAVE_CXX17 + #endif + + #if (_MSVC_LANG >= 202002L) + #undef ARMA_HAVE_CXX20 + #define ARMA_HAVE_CXX20 + #endif + +#endif + + +// warn about ignored option used in old versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11) + #pragma message ("WARNING: option ARMA_DONT_USE_CXX11 ignored") +#endif + + +#if !defined(ARMA_HAVE_CXX11) + #error "*** C++11 compiler required; enable C++11 mode in your compiler, or use an earlier version of Armadillo" +#endif + + +// for compatibility with earlier versions of Armadillo +#undef ARMA_USE_CXX11 +#define ARMA_USE_CXX11 diff --git a/src/armadillo_bits/compiler_setup.hpp b/src/armadillo_bits/compiler_setup.hpp index 5aad533f..775b82a9 100644 --- a/src/armadillo_bits/compiler_setup.hpp +++ b/src/armadillo_bits/compiler_setup.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,6 +23,7 @@ #undef arma_align_mem #undef arma_warn_unused #undef arma_deprecated +#undef arma_frown #undef arma_malloc #undef arma_inline #undef arma_noinline @@ -32,6 +35,7 @@ #define arma_align_mem #define arma_warn_unused #define arma_deprecated +#define arma_frown(msg) #define arma_malloc #define arma_inline inline #define arma_noinline @@ -78,29 +82,17 @@ #undef ARMA_INCFILE_WRAP #define ARMA_INCFILE_WRAP(x) -#if defined(ARMA_USE_CXX11) - - #undef ARMA_USE_U64S64 - #define ARMA_USE_U64S64 - - #if !defined(ARMA_32BIT_WORD) - #undef ARMA_64BIT_WORD - #define ARMA_64BIT_WORD - #endif - - #if defined(ARMA_64BIT_WORD) && defined(SIZE_MAX) - #if (SIZE_MAX < 0xFFFFFFFFFFFFFFFFull) - // #pragma message ("WARNING: disabled use of 64 bit integers, as std::size_t is smaller than 64 bits") - #undef ARMA_64BIT_WORD - #endif - #endif - -#endif +#if !defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD + #define ARMA_64BIT_WORD +#endif -#if defined(ARMA_64BIT_WORD) - #undef ARMA_USE_U64S64 - #define ARMA_USE_U64S64 +#if defined(ARMA_64BIT_WORD) && defined(SIZE_MAX) + #if (SIZE_MAX < 0xFFFFFFFFFFFFFFFFull) + // #pragma message ("WARNING: disabled use of 64 bit integers, as std::size_t is smaller than 64 bits") + #undef ARMA_64BIT_WORD + #endif #endif @@ -111,20 +103,6 @@ #undef ARMA_GOOD_COMPILER -#undef ARMA_HAVE_TR1 -#undef ARMA_HAVE_GETTIMEOFDAY -#undef ARMA_HAVE_SNPRINTF -#undef ARMA_HAVE_ISFINITE -#undef ARMA_HAVE_LOG1P -#undef ARMA_HAVE_ISINF -#undef ARMA_HAVE_ISNAN - - -#if (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - #define ARMA_HAVE_GETTIMEOFDAY -#endif - - // posix_memalign() is part of IEEE standard 1003.1 // http://pubs.opengroup.org/onlinepubs/009696899/functions/posix_memalign.html // http://pubs.opengroup.org/onlinepubs/9699919799/basedefs/unistd.h.html @@ -136,15 +114,13 @@ #if defined(__APPLE__) || defined(__apple_build_version__) - #undef ARMA_BLAS_SDOT_BUG - #define ARMA_BLAS_SDOT_BUG + // NOTE: Apple accelerate framework has broken implementations of functions that return a float value, + // NOTE: such as sdot(), slange(), clange(), slansy(), clanhe(), slangb() + #undef ARMA_BLAS_FLOAT_BUG + #define ARMA_BLAS_FLOAT_BUG // #undef ARMA_HAVE_POSIX_MEMALIGN // NOTE: posix_memalign() is available since macOS 10.6 (late 2009 onwards) - - #undef ARMA_USE_EXTERN_CXX11_RNG - // TODO: thread_local seems to work in Apple clang since Xcode 8 (mid 2016 onwards) - // NOTE: https://stackoverflow.com/questions/28094794/why-does-apple-clang-disallow-c11-thread-local-when-official-clang-supports #endif @@ -161,10 +137,8 @@ #define ARMA_FNSIG __FUNCSIG__ #elif defined(__INTEL_COMPILER) #define ARMA_FNSIG __FUNCTION__ -#elif defined(ARMA_USE_CXX11) - #define ARMA_FNSIG __func__ #else - #define ARMA_FNSIG "(unknown)" + #define ARMA_FNSIG __func__ #endif @@ -187,19 +161,13 @@ #undef ARMA_GCC_VERSION #define ARMA_GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) - #if (ARMA_GCC_VERSION < 40400) - #error "*** newer compiler required ***" - #endif - - #if (ARMA_GCC_VERSION < 40800) - #undef ARMA_PRINT_CXX98_WARNING - #define ARMA_PRINT_CXX98_WARNING + #if (ARMA_GCC_VERSION < 40803) + #error "*** newer compiler required; need gcc 4.8.3 or newer ***" #endif - #if ( (ARMA_GCC_VERSION >= 40700) && (ARMA_GCC_VERSION <= 40701) ) - #error "gcc versions 4.7.0 and 4.7.1 are unsupported; use 4.7.2 or later" - // due to http://gcc.gnu.org/bugzilla/show_bug.cgi?id=53549 - #endif + // #if (ARMA_GCC_VERSION < 60100) + // #pragma message ("WARNING: support for gcc versions older than 6.1 is deprecated") + // #endif #define ARMA_GOOD_COMPILER @@ -209,39 +177,27 @@ #undef arma_align_mem #undef arma_warn_unused #undef arma_deprecated + #undef arma_frown #undef arma_malloc #undef arma_inline #undef arma_noinline - #define arma_hot __attribute__((__hot__)) - #define arma_cold __attribute__((__cold__)) - #define arma_aligned __attribute__((__aligned__)) - #define arma_align_mem __attribute__((__aligned__(16))) - #define arma_warn_unused __attribute__((__warn_unused_result__)) - #define arma_deprecated __attribute__((__deprecated__)) - #define arma_malloc __attribute__((__malloc__)) - #define arma_inline inline __attribute__((__always_inline__)) - #define arma_noinline __attribute__((__noinline__)) + #define arma_hot __attribute__((__hot__)) + #define arma_cold __attribute__((__cold__)) + #define arma_aligned __attribute__((__aligned__)) + #define arma_align_mem __attribute__((__aligned__(16))) + #define arma_warn_unused __attribute__((__warn_unused_result__)) + #define arma_deprecated __attribute__((__deprecated__)) + #define arma_frown(msg) __attribute__((__deprecated__(msg))) + #define arma_malloc __attribute__((__malloc__)) + #define arma_inline __attribute__((__always_inline__)) inline + #define arma_noinline __attribute__((__noinline__)) #undef ARMA_HAVE_ALIGNED_ATTRIBUTE #define ARMA_HAVE_ALIGNED_ATTRIBUTE - #if defined(ARMA_USE_CXX11) - #if (ARMA_GCC_VERSION < 40800) - #undef ARMA_PRINT_CXX11_WARNING - #define ARMA_PRINT_CXX11_WARNING - #endif - #endif - - #if !defined(ARMA_USE_CXX11) && !defined(__GXX_EXPERIMENTAL_CXX0X__) && (__cplusplus < 201103L) && !defined(ARMA_DONT_USE_TR1) - #if defined(_GLIBCXX_USE_C99_MATH_TR1) && defined(_GLIBCXX_USE_C99_COMPLEX_TR1) - #define ARMA_HAVE_TR1 - #endif - #endif - - #if (ARMA_GCC_VERSION >= 40700) - #define ARMA_HAVE_GCC_ASSUME_ALIGNED - #endif + #undef ARMA_HAVE_GCC_ASSUME_ALIGNED + #define ARMA_HAVE_GCC_ASSUME_ALIGNED // gcc's vectoriser can handle elaborate loops #undef ARMA_SIMPLE_LOOPS @@ -250,17 +206,10 @@ #define ARMA_SIMPLE_LOOPS #endif - #if !defined(ARMA_USE_CXX11) && (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - #define ARMA_HAVE_SNPRINTF - #define ARMA_HAVE_ISFINITE - #define ARMA_HAVE_LOG1P - #define ARMA_HAVE_ISINF - #define ARMA_HAVE_ISNAN - #endif - #endif +// TODO: __INTEL_CLANG_COMPILER indicates the clang based intel compiler, distinct from the classic intel compiler #if !defined(ARMA_ALLOW_FAKE_CLANG) #if defined(__clang__) && (defined(__INTEL_COMPILER) || defined(__NVCC__) || defined(__CUDACC__) || defined(__PGI) || defined(__PATHSCALE__) || defined(__ARMCC_VERSION) || defined(__IBMCPP__)) #undef ARMA_DETECTED_FAKE_CLANG @@ -304,6 +253,11 @@ #define arma_deprecated __attribute__((__deprecated__)) #endif + #if __has_attribute(__deprecated__) + #undef arma_frown + #define arma_frown(msg) __attribute__((__deprecated__(msg))) + #endif + #if __has_attribute(__malloc__) #undef arma_malloc #define arma_malloc __attribute__((__malloc__)) @@ -311,7 +265,7 @@ #if __has_attribute(__always_inline__) #undef arma_inline - #define arma_inline inline __attribute__((__always_inline__)) + #define arma_inline __attribute__((__always_inline__)) inline #endif #if __has_attribute(__noinline__) @@ -324,12 +278,12 @@ #define arma_hot __attribute__((__hot__)) #endif - #if __has_attribute(__minsize__) - #undef arma_cold - #define arma_cold __attribute__((__minsize__)) - #elif __has_attribute(__cold__) + #if __has_attribute(__cold__) #undef arma_cold #define arma_cold __attribute__((__cold__)) + #elif __has_attribute(__minsize__) + #undef arma_cold + #define arma_cold __attribute__((__minsize__)) #endif #if defined(__has_builtin) && __has_builtin(__builtin_assume_aligned) @@ -337,14 +291,6 @@ #define ARMA_HAVE_GCC_ASSUME_ALIGNED #endif - #if !defined(ARMA_USE_CXX11) && (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - #define ARMA_HAVE_SNPRINTF - #define ARMA_HAVE_ISFINITE - #define ARMA_HAVE_LOG1P - #define ARMA_HAVE_ISINF - #define ARMA_HAVE_ISNAN - #endif - #endif @@ -354,7 +300,7 @@ #error "*** newer compiler required ***" #endif - #if (__INTEL_COMPILER < 1300) + #if (__INTEL_COMPILER < 1500) #error "*** newer compiler required ***" #endif @@ -362,49 +308,31 @@ #undef ARMA_HAVE_ICC_ASSUME_ALIGNED #define ARMA_HAVE_ICC_ASSUME_ALIGNED - #if defined(ARMA_USE_CXX11) - #if (__INTEL_COMPILER < 1500) - #undef ARMA_PRINT_CXX11_WARNING - #define ARMA_PRINT_CXX11_WARNING - #endif - #endif - #endif #if defined(_MSC_VER) - #if (_MSC_VER < 1700) + #if (_MSC_VER < 1900) #error "*** newer compiler required ***" #endif - #if (_MSC_VER < 1800) - #undef ARMA_PRINT_CXX98_WARNING - #define ARMA_PRINT_CXX98_WARNING - #endif - - #if defined(ARMA_USE_CXX11) - #if (_MSC_VER < 1900) - #undef ARMA_PRINT_CXX11_WARNING - #define ARMA_PRINT_CXX11_WARNING - #endif - #endif - #undef arma_deprecated #define arma_deprecated __declspec(deprecated) // #undef arma_inline - // #define arma_inline inline __forceinline + // #define arma_inline __forceinline inline #pragma warning(push) #pragma warning(disable: 4127) // conditional expression is constant #pragma warning(disable: 4180) // qualifier has no meaning - #pragma warning(disable: 4244) // possible loss of data when converting types + #pragma warning(disable: 4244) // possible loss of data when converting types (see also 4305) #pragma warning(disable: 4510) // default constructor could not be generated #pragma warning(disable: 4511) // copy constructor can't be generated #pragma warning(disable: 4512) // assignment operator can't be generated #pragma warning(disable: 4513) // destructor can't be generated #pragma warning(disable: 4514) // unreferenced inline function has been removed + #pragma warning(disable: 4519) // default template args are only allowed on a class template (C++11) #pragma warning(disable: 4522) // multiple assignment operators specified #pragma warning(disable: 4623) // default constructor can't be generated #pragma warning(disable: 4624) // destructor can't be generated @@ -416,10 +344,12 @@ #pragma warning(disable: 4714) // __forceinline can't be inlined #pragma warning(disable: 4800) // value forced to bool - #if defined(ARMA_USE_CXX11) - #pragma warning(disable: 4519) // default template args are only allowed on a class template - #endif + // NOTE: also possible to disable 4146 (unary minus operator applied to unsigned type, result still unsigned) + #if defined(ARMA_HAVE_CXX17) + #pragma warning(disable: 26812) // unscoped enum + #pragma warning(disable: 26819) // unannotated fallthrough + #endif // #if (_MANAGED == 1) || (_M_CEE == 1) // @@ -451,49 +381,39 @@ // http://www.oracle.com/technetwork/server-storage/solarisstudio/training/index-jsp-141991.html // http://www.oracle.com/technetwork/server-storage/solarisstudio/documentation/cplusplus-faq-355066.html - #if (__SUNPRO_CC < 0x5100) + #if (__SUNPRO_CC < 0x5140) #error "*** newer compiler required ***" #endif - #if defined(ARMA_USE_CXX11) - #if (__SUNPRO_CC < 0x5130) - #undef ARMA_PRINT_CXX11_WARNING - #define ARMA_PRINT_CXX11_WARNING - #endif - #endif - #endif -#if defined(ARMA_USE_CXX11) && defined(__CYGWIN__) && !defined(ARMA_DONT_PRINT_CXX11_WARNING) - #pragma message ("WARNING: Cygwin may have incomplete support for C++11 features.") -#endif - +#if defined(ARMA_HAVE_CXX14) + #undef arma_deprecated + #define arma_deprecated [[deprecated]] -#if defined(ARMA_USE_CXX11) && (__cplusplus < 201103L) - #undef ARMA_PRINT_CXX11_WARNING - #define ARMA_PRINT_CXX11_WARNING + #undef arma_frown + #define arma_frown(msg) [[deprecated(msg)]] #endif -#if defined(ARMA_PRINT_CXX98_WARNING) && !defined(ARMA_DONT_PRINT_CXX98_WARNING) - #pragma message ("WARNING: this compiler is OUTDATED and has INCOMPLETE support for the C++ standard;") - #pragma message ("WARNING: if something breaks, you get to keep all the pieces.") +#if defined(ARMA_HAVE_CXX17) + #undef arma_warn_unused + #define arma_warn_unused [[nodiscard]] #endif -#if defined(ARMA_PRINT_CXX11_WARNING) && !defined(ARMA_DONT_PRINT_CXX11_WARNING) - #pragma message ("WARNING: use of C++11 features has been enabled,") - #pragma message ("WARNING: but this compiler has INCOMPLETE support for C++11;") - #pragma message ("WARNING: if something breaks, you get to keep all the pieces.") - #pragma message ("WARNING: to forcefully prevent Armadillo from using C++11 features,") - #pragma message ("WARNING: #define ARMA_DONT_USE_CXX11 before #include ") +#if !defined(ARMA_DONT_USE_OPENMP) + #if (defined(_OPENMP) && (_OPENMP >= 201107)) + #undef ARMA_USE_OPENMP + #define ARMA_USE_OPENMP + #endif #endif #if ( defined(ARMA_USE_OPENMP) && (!defined(_OPENMP) || (defined(_OPENMP) && (_OPENMP < 201107))) ) - // OpenMP 3.1 required for atomic read and atomic write // OpenMP 3.0 required for parallelisation of loops with unsigned integers + // OpenMP 3.1 required for atomic read and atomic write #undef ARMA_USE_OPENMP #undef ARMA_PRINT_OPENMP_WARNING #define ARMA_PRINT_OPENMP_WARNING @@ -504,26 +424,13 @@ #pragma message ("WARNING: use of OpenMP disabled; compiler support for OpenMP 3.1+ not detected") #if (defined(_OPENMP) && (_OPENMP < 201107)) - #pragma message ("NOTE: your compiler appears to have an ancient version of OpenMP") + #pragma message ("NOTE: your compiler has an outdated version of OpenMP") #pragma message ("NOTE: consider upgrading to a better compiler") #endif #endif -#if defined(ARMA_USE_OPENMP) && !defined(ARMA_USE_CXX11) - #if (defined(ARMA_GCC_VERSION) && (ARMA_GCC_VERSION >= 50400)) || (defined(__clang__) && !defined(ARMA_FAKE_CLANG)) - #undef ARMA_PRINT_OPENMP_CXX11_WARNING - #define ARMA_PRINT_OPENMP_CXX11_WARNING - #endif -#endif - - -#if defined(ARMA_PRINT_OPENMP_CXX11_WARNING) && !defined(ARMA_DONT_PRINT_OPENMP_WARNING) - #pragma message ("WARNING: support for OpenMP requires C++11/C++14; add -std=c++11 or -std=c++14 to compiler flags") -#endif - - -#if defined(ARMA_USE_OPENMP) && defined(ARMA_USE_CXX11) +#if defined(ARMA_USE_OPENMP) #if (defined(ARMA_GCC_VERSION) && (ARMA_GCC_VERSION < 50400)) // due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=57580 #undef ARMA_USE_OPENMP @@ -534,48 +441,71 @@ #endif -#if defined(ARMA_GCC_VERSION) && (ARMA_GCC_VERSION >= 50400) && !defined(ARMA_USE_CXX11) - #if !defined(ARMA_PRINT_CXX11_WARNING) && !defined(ARMA_PRINT_OPENMP_CXX11_WARNING) && !defined(ARMA_DONT_PRINT_CXX11_WARNING) - #pragma message ("NOTE: suggest to enable C++14 mode for faster code; add -std=c++14 to compiler flags") - #endif +#if (defined(__FAST_MATH__) || (defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)) || defined(_M_FP_FAST)) + #undef ARMA_FAST_MATH + #define ARMA_FAST_MATH +#endif + + +#if defined(ARMA_FAST_MATH) && !defined(ARMA_DONT_PRINT_FAST_MATH_WARNING) + #pragma message ("WARNING: compiler is in fast math mode; some functions may be unreliable.") + #pragma message ("WARNING: to suppress this warning and related warnings,") + #pragma message ("WARNING: #define ARMA_DONT_PRINT_FAST_MATH_WARNING before #include ") +#endif + + +#if ( (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) && (!defined(__MINGW32__) && !defined(__MINGW64__)) ) + #undef ARMA_PRINT_EXCEPTIONS_INTERNAL + #define ARMA_PRINT_EXCEPTIONS_INTERNAL +#endif + + +#if (defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) && !defined(ARMA_ALIEN_MEM_FREE_FUNCTION)) || (!defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) && defined(ARMA_ALIEN_MEM_FREE_FUNCTION)) + #error "*** both ARMA_ALIEN_MEM_ALLOC_FUNCTION and ARMA_ALIEN_MEM_FREE_FUNCTION must be defined ***" #endif // cleanup -#undef ARMA_FAKE_GCC -#undef ARMA_FAKE_CLANG +#undef ARMA_DETECTED_FAKE_GCC +#undef ARMA_DETECTED_FAKE_CLANG #undef ARMA_GCC_VERSION -#undef ARMA_PRINT_CXX98_WARNING -#undef ARMA_PRINT_CXX11_WARNING #undef ARMA_PRINT_OPENMP_WARNING -#undef ARMA_PRINT_OPENMP_CXX11_WARNING +// undefine conflicting macros + #if defined(log2) #undef log2 - #pragma message ("WARNING: detected 'log2' macro and undefined it") + #pragma message ("WARNING: undefined conflicting 'log2' macro") #endif - - -// -// whoever defined macros with the names "min" and "max" should be permanently removed from the gene pool +#if defined(check) + #undef check + #pragma message ("WARNING: undefined conflicting 'check' macro") +#endif #if defined(min) || defined(max) #undef min #undef max - #pragma message ("WARNING: detected 'min' and/or 'max' macros and undefined them;") - #pragma message ("WARNING: you may wish to define NOMINMAX before including any windows header") + #pragma message ("WARNING: undefined conflicting 'min' and/or 'max' macros") #endif - - -// -// handle more stupid macros // https://sourceware.org/bugzilla/show_bug.cgi?id=19239 - #undef minor #undef major + + +// optionally allow disabling of compile-time deprecation messages (not recommended) +// NOTE: option 'ARMA_IGNORE_DEPRECATED_MARKER' will be removed +// NOTE: disabling deprecation messages is counter-productive + +#if defined(ARMA_IGNORE_DEPRECATED_MARKER) && (!defined(ARMA_DONT_IGNORE_DEPRECATED_MARKER)) && (!defined(ARMA_EXTRA_DEBUG)) + #undef arma_deprecated + #define arma_deprecated + + #undef arma_frown + #define arma_frown(msg) +#endif diff --git a/src/armadillo_bits/compiler_setup_post.hpp b/src/armadillo_bits/compiler_setup_post.hpp index c30341cf..6274b7ee 100644 --- a/src/armadillo_bits/compiler_setup_post.hpp +++ b/src/armadillo_bits/compiler_setup_post.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/cond_rel_bones.hpp b/src/armadillo_bits/cond_rel_bones.hpp index c56d2f11..a160d26d 100644 --- a/src/armadillo_bits/cond_rel_bones.hpp +++ b/src/armadillo_bits/cond_rel_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/cond_rel_meat.hpp b/src/armadillo_bits/cond_rel_meat.hpp index c4742a1a..a285774d 100644 --- a/src/armadillo_bits/cond_rel_meat.hpp +++ b/src/armadillo_bits/cond_rel_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/config.hpp b/src/armadillo_bits/config.hpp index 4d8b7ff9..6d7874ac 100644 --- a/src/armadillo_bits/config.hpp +++ b/src/armadillo_bits/config.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,17 +17,32 @@ +#if !defined(ARMA_WARN_LEVEL) + #define ARMA_WARN_LEVEL 2 +#endif +//// The level of warning messages printed to ARMA_CERR_STREAM. +//// Must be an integer >= 0. The default value is 2. +//// 0 = no warnings; generally not recommended +//// 1 = only critical warnings about arguments and/or data which are likely to lead to incorrect results +//// 2 = as per level 1, and warnings about poorly conditioned systems (low rcond) detected by solve(), spsolve(), etc +//// 3 = as per level 2, and warnings about failed decompositions, failed saving/loading, etc + +// #define ARMA_USE_WRAPPER +//// Comment out the above line if you prefer to directly link with BLAS, LAPACK, etc +//// instead of the Armadillo runtime library. +//// You will need to link your programs directly with -lopenblas -llapack instead of -larmadillo + #if !defined(ARMA_USE_LAPACK) #define ARMA_USE_LAPACK //// Comment out the above line if you don't have LAPACK or a high-speed replacement for LAPACK, -//// such as Intel MKL, AMD ACML, or the Accelerate framework. +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. //// LAPACK is required for matrix decompositions (eg. SVD) and matrix inverse. #endif #if !defined(ARMA_USE_BLAS) #define ARMA_USE_BLAS //// Comment out the above line if you don't have BLAS or a high-speed replacement for BLAS, -//// such as OpenBLAS, GotoBLAS, Intel MKL, AMD ACML, or the Accelerate framework. +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. //// BLAS is used for matrix multiplication. //// Without BLAS, matrix multiplication will still work, but might be slower. #endif @@ -56,13 +73,30 @@ //// Make sure the directory has a trailing / #endif -// #define ARMA_USE_WRAPPER -//// Comment out the above line if you're getting linking errors when compiling your programs, -//// or if you prefer to directly link with LAPACK, BLAS + etc instead of the Armadillo runtime library. -//// You will then need to link your programs directly with -llapack -lblas instead of -larmadillo +#if !defined(ARMA_USE_ATLAS) +// #define ARMA_USE_ATLAS +//// NOTE: support for ATLAS is deprecated and will be removed. +#endif + +#if !defined(ARMA_USE_HDF5) +// #define ARMA_USE_HDF5 +//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; +//// the hdf5.h header file must be available on your system, +//// and you will need to link with the hdf5 library (eg. -lhdf5) +#endif + +#if !defined(ARMA_USE_FFTW3) +// #define ARMA_USE_FFTW3 +//// Uncomment the above line to allow the use of the FFTW3 library by fft() and ifft() functions; +//// you will need to link with the FFTW3 library (eg. -lfftw3) +#endif + +#if defined(ARMA_USE_FFTW) + #error "use ARMA_USE_FFTW3 instead of ARMA_USE_FFTW" +#endif // #define ARMA_BLAS_CAPITALS -//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names (eg. ACML on 64-bit Windows) +//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names #define ARMA_BLAS_UNDERSCORE //// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. @@ -74,6 +108,12 @@ // #define ARMA_BLAS_LONG_LONG //// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" +// #define ARMA_BLAS_NOEXCEPT +//// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification + +// #define ARMA_LAPACK_NOEXCEPT +//// Uncomment the above line if you require LAPACK functions to have the 'noexcept' specification + #define ARMA_USE_FORTRAN_HIDDEN_ARGS //// Comment out the above line to call BLAS and LAPACK functions without using so-called "hidden" arguments. //// Fortran functions (compiled without a BIND(C) declaration) that have char arguments @@ -82,29 +122,17 @@ //// These "hidden" arguments are typically tacked onto the end of function definitions. // #define ARMA_USE_TBB_ALLOC -//// Uncomment the above line if you want to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() +//// Uncomment the above line to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() // #define ARMA_USE_MKL_ALLOC -//// Uncomment the above line if you want to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() +//// Uncomment the above line to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() // #define ARMA_USE_MKL_TYPES -//// Uncomment the above line if you want to use Intel MKL types for complex numbers. +//// Uncomment the above line to use Intel MKL types for complex numbers. //// You will need to include appropriate MKL headers before the Armadillo header. //// You may also need to enable or disable the following options: //// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS -// #define ARMA_USE_ATLAS -// #define ARMA_ATLAS_INCLUDE_DIR /usr/include/ -//// If you're using ATLAS and the compiler can't find cblas.h and/or clapack.h -//// uncomment the above define and specify the appropriate include directory. -//// Make sure the directory has a trailing / - -#if !defined(ARMA_USE_CXX11) -// #define ARMA_USE_CXX11 -//// Uncomment the above line to forcefully enable use of C++11 features (eg. initialiser lists). -//// Note that ARMA_USE_CXX11 is automatically enabled when a C++11 compiler is detected. -#endif - #if !defined(ARMA_USE_OPENMP) // #define ARMA_USE_OPENMP //// Uncomment the above line to forcefully enable use of OpenMP for parallelisation. @@ -114,33 +142,31 @@ #if !defined(ARMA_64BIT_WORD) // #define ARMA_64BIT_WORD //// Uncomment the above line if you require matrices/vectors capable of holding more than 4 billion elements. -//// Your machine and compiler must have support for 64 bit integers (eg. via "long" or "long long"). -//// Note that ARMA_64BIT_WORD is automatically enabled when a C++11 compiler is detected. +//// Note that ARMA_64BIT_WORD is automatically enabled when std::size_t has 64 bits and ARMA_32BIT_WORD is not defined. #endif -#if !defined(ARMA_USE_HDF5) -// #define ARMA_USE_HDF5 -//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; -//// the hdf5.h header file must be available on your system, -//// and you will need to link with the hdf5 library (eg. -lhdf5) +#if !defined(ARMA_OPTIMISE_BAND) + #define ARMA_OPTIMISE_BAND + //// Comment out the above line to disable optimised handling + //// of band matrices by solve() and chol() #endif -#if !defined(ARMA_OPTIMISE_SOLVE_BAND) - #define ARMA_OPTIMISE_SOLVE_BAND - //// Comment out the above line if you don't want optimised handling of band matrices by solve() +#if !defined(ARMA_OPTIMISE_SYM) + #define ARMA_OPTIMISE_SYM + //// Comment out the above line to disable optimised handling + //// of symmetric/hermitian matrices by various functions: + //// solve(), inv(), pinv(), expmat(), logmat(), sqrtmat(), rcond(), rank() #endif -#if !defined(ARMA_OPTIMISE_SOLVE_SYMPD) - #define ARMA_OPTIMISE_SOLVE_SYMPD - //// Comment out the above line if you don't want optimised handling of symmetric/hermitian positive definite matrices by solve() +#if !defined(ARMA_OPTIMISE_INVEXPR) + #define ARMA_OPTIMISE_INVEXPR + //// Comment out the above line to disable optimised handling + //// of inv() and inv_sympd() within compound expressions #endif -// #define ARMA_USE_HDF5_ALT -#if defined(ARMA_USE_HDF5_ALT) && defined(ARMA_USE_WRAPPER) - #undef ARMA_USE_HDF5 - #define ARMA_USE_HDF5 - - // #define ARMA_HDF5_INCLUDE_DIR /usr/include/ +#if !defined(ARMA_CHECK_NONFINITE) + #define ARMA_CHECK_NONFINITE + //// Comment out the above line to disable checking for nonfinite matrices #endif #if !defined(ARMA_MAT_PREALLOC) @@ -152,29 +178,34 @@ //// change the number to the size of your vectors. #if !defined(ARMA_OPENMP_THRESHOLD) - #define ARMA_OPENMP_THRESHOLD 240 + #define ARMA_OPENMP_THRESHOLD 320 #endif //// The minimum number of elements in a matrix to allow OpenMP based parallelisation; //// it must be an integer that is at least 1. #if !defined(ARMA_OPENMP_THREADS) - #define ARMA_OPENMP_THREADS 10 + #define ARMA_OPENMP_THREADS 8 #endif //// The maximum number of threads to use for OpenMP based parallelisation; //// it must be an integer that is at least 1. // #define ARMA_NO_DEBUG -//// Uncomment the above line if you want to disable all run-time checks. -//// This will result in faster code, but you first need to make sure that your code runs correctly! -//// We strongly recommend to have the run-time checks enabled during development, -//// as this greatly aids in finding mistakes in your code, and hence speeds up development. -//// We recommend that run-time checks be disabled _only_ for the shipped version of your program. +//// Uncomment the above line to disable all run-time checks. NOT RECOMMENDED. +//// It is strongly recommended that run-time checks are enabled during development, +//// as this greatly aids in finding mistakes in your code. // #define ARMA_EXTRA_DEBUG -//// Uncomment the above line if you want to see the function traces of how Armadillo evaluates expressions. +//// Uncomment the above line to see the function traces of how Armadillo evaluates expressions. //// This is mainly useful for debugging of the library. +#if defined(ARMA_EXTRA_DEBUG) + #undef ARMA_NO_DEBUG + #undef ARMA_WARN_LEVEL + #define ARMA_WARN_LEVEL 3 +#endif + + #if defined(ARMA_DEFAULT_OSTREAM) #pragma message ("WARNING: support for ARMA_DEFAULT_OSTREAM is deprecated and will be removed;") #pragma message ("WARNING: use ARMA_COUT_STREAM and ARMA_CERR_STREAM instead") @@ -200,13 +231,12 @@ #endif -#if !defined(ARMA_PRINT_ERRORS) -#define ARMA_PRINT_ERRORS -//// Comment out the above line if you don't want errors and warnings printed (eg. failed decompositions) -#endif - -#if !defined(ARMA_PRINT_HDF5_ERRORS) -// #define ARMA_PRINT_HDF5_ERRORS +#if !defined(ARMA_PRINT_EXCEPTIONS) + // #define ARMA_PRINT_EXCEPTIONS + #if defined(ARMA_PRINT_EXCEPTIONS_INTERNAL) + #undef ARMA_PRINT_EXCEPTIONS + #define ARMA_PRINT_EXCEPTIONS + #endif #endif #if defined(ARMA_DONT_USE_LAPACK) @@ -232,62 +262,84 @@ #if defined(ARMA_DONT_USE_ATLAS) #undef ARMA_USE_ATLAS - #undef ARMA_ATLAS_INCLUDE_DIR +#endif + +#if defined(ARMA_DONT_USE_HDF5) + #undef ARMA_USE_HDF5 +#endif + +#if defined(ARMA_DONT_USE_FFTW3) + #undef ARMA_USE_FFTW3 #endif #if defined(ARMA_DONT_USE_WRAPPER) #undef ARMA_USE_WRAPPER - #undef ARMA_USE_HDF5_ALT #endif #if defined(ARMA_DONT_USE_FORTRAN_HIDDEN_ARGS) #undef ARMA_USE_FORTRAN_HIDDEN_ARGS #endif -#if defined(ARMA_DONT_USE_CXX11) - #undef ARMA_USE_CXX11 - #undef ARMA_USE_EXTERN_CXX11_RNG +#if !defined(ARMA_DONT_USE_STD_MUTEX) + // #define ARMA_DONT_USE_STD_MUTEX + //// Uncomment the above line to disable use of std::mutex +#endif + +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11_MUTEX) + #pragma message ("WARNING: support for ARMA_DONT_USE_CXX11_MUTEX is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_DONT_USE_STD_MUTEX instead") + #undef ARMA_DONT_USE_STD_MUTEX + #define ARMA_DONT_USE_STD_MUTEX #endif #if defined(ARMA_DONT_USE_OPENMP) #undef ARMA_USE_OPENMP #endif -#if defined(ARMA_USE_WRAPPER) - #if defined(ARMA_USE_CXX11) - #if !defined(ARMA_USE_EXTERN_CXX11_RNG) - // #define ARMA_USE_EXTERN_CXX11_RNG - #endif - #endif +#if defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD #endif -#if defined(ARMA_DONT_USE_EXTERN_CXX11_RNG) - #undef ARMA_USE_EXTERN_CXX11_RNG +#if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) + #undef ARMA_OPTIMISE_BAND #endif -#if defined(ARMA_32BIT_WORD) - #undef ARMA_64BIT_WORD +#if defined(ARMA_DONT_OPTIMISE_SYM) || defined(ARMA_DONT_OPTIMISE_SYMPD) || defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) + #undef ARMA_OPTIMISE_SYM #endif -#if defined(ARMA_DONT_USE_HDF5) - #undef ARMA_USE_HDF5 - #undef ARMA_USE_HDF5_ALT +#if defined(ARMA_DONT_OPTIMISE_INVEXPR) + #undef ARMA_OPTIMISE_INVEXPR #endif -#if defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) - #undef ARMA_OPTIMISE_SOLVE_BAND +#if defined(ARMA_DONT_CHECK_NONFINITE) + #undef ARMA_CHECK_NONFINITE #endif -#if defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) - #undef ARMA_OPTIMISE_SOLVE_SYMPD +#if defined(ARMA_DONT_PRINT_ERRORS) + #pragma message ("INFO: support for ARMA_DONT_PRINT_ERRORS option has been removed") + + #if defined(ARMA_PRINT_EXCEPTIONS) + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL and ARMA_DONT_PRINT_EXCEPTIONS options instead") + #else + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL option instead") + #endif + + #pragma message ("INFO: see the documentation for details") #endif -#if defined(ARMA_DONT_PRINT_ERRORS) - #undef ARMA_PRINT_ERRORS +#if defined(ARMA_DONT_PRINT_EXCEPTIONS) + #undef ARMA_PRINT_EXCEPTIONS +#endif + +#if !defined(ARMA_DONT_ZERO_INIT) + // #define ARMA_DONT_ZERO_INIT + //// Uncomment the above line to disable initialising elements to zero during construction of dense matrices and cubes #endif -#if defined(ARMA_DONT_PRINT_HDF5_ERRORS) - #undef ARMA_PRINT_HDF5_ERRORS +#if defined(ARMA_NO_CRIPPLED_LAPACK) + #undef ARMA_CRIPPLED_LAPACK #endif diff --git a/src/armadillo_bits/config.hpp.cmake b/src/armadillo_bits/config.hpp.cmake index be422433..4ac633bb 100644 --- a/src/armadillo_bits/config.hpp.cmake +++ b/src/armadillo_bits/config.hpp.cmake @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,17 +17,32 @@ +#if !defined(ARMA_WARN_LEVEL) + #define ARMA_WARN_LEVEL 2 +#endif +//// The level of warning messages printed to ARMA_CERR_STREAM. +//// Must be an integer >= 0. The default value is 2. +//// 0 = no warnings; generally not recommended +//// 1 = only critical warnings about arguments and/or data which are likely to lead to incorrect results +//// 2 = as per level 1, and warnings about poorly conditioned systems (low rcond) detected by solve(), spsolve(), etc +//// 3 = as per level 2, and warnings about failed decompositions, failed saving/loading, etc + +#cmakedefine ARMA_USE_WRAPPER +//// Comment out the above line if you prefer to directly link with BLAS, LAPACK, etc +//// instead of the Armadillo runtime library. +//// You will need to link your programs directly with -lopenblas -llapack instead of -larmadillo + #if !defined(ARMA_USE_LAPACK) #cmakedefine ARMA_USE_LAPACK //// Comment out the above line if you don't have LAPACK or a high-speed replacement for LAPACK, -//// such as Intel MKL, AMD ACML, or the Accelerate framework. +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. //// LAPACK is required for matrix decompositions (eg. SVD) and matrix inverse. #endif #if !defined(ARMA_USE_BLAS) #cmakedefine ARMA_USE_BLAS //// Comment out the above line if you don't have BLAS or a high-speed replacement for BLAS, -//// such as OpenBLAS, GotoBLAS, Intel MKL, AMD ACML, or the Accelerate framework. +//// such as OpenBLAS, Intel MKL, or the Accelerate framework. //// BLAS is used for matrix multiplication. //// Without BLAS, matrix multiplication will still work, but might be slower. #endif @@ -56,13 +73,30 @@ //// Make sure the directory has a trailing / #endif -#cmakedefine ARMA_USE_WRAPPER -//// Comment out the above line if you're getting linking errors when compiling your programs, -//// or if you prefer to directly link with LAPACK, BLAS + etc instead of the Armadillo runtime library. -//// You will then need to link your programs directly with -llapack -lblas instead of -larmadillo +#if !defined(ARMA_USE_ATLAS) +#cmakedefine ARMA_USE_ATLAS +//// NOTE: support for ATLAS is deprecated and will be removed. +#endif + +#if !defined(ARMA_USE_HDF5) +// #define ARMA_USE_HDF5 +//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; +//// the hdf5.h header file must be available on your system, +//// and you will need to link with the hdf5 library (eg. -lhdf5) +#endif + +#if !defined(ARMA_USE_FFTW3) +// #define ARMA_USE_FFTW3 +//// Uncomment the above line to allow the use of the FFTW3 library by fft() and ifft() functions; +//// you will need to link with the FFTW3 library (eg. -lfftw3) +#endif + +#if defined(ARMA_USE_FFTW) + #error "use ARMA_USE_FFTW3 instead of ARMA_USE_FFTW" +#endif // #define ARMA_BLAS_CAPITALS -//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names (eg. ACML on 64-bit Windows) +//// Uncomment the above line if your BLAS and LAPACK libraries have capitalised function names #define ARMA_BLAS_UNDERSCORE //// Uncomment the above line if your BLAS and LAPACK libraries have function names with a trailing underscore. @@ -74,6 +108,12 @@ // #define ARMA_BLAS_LONG_LONG //// Uncomment the above line if your BLAS and LAPACK libraries use "long long" instead of "int" +// #define ARMA_BLAS_NOEXCEPT +//// Uncomment the above line if you require BLAS functions to have the 'noexcept' specification + +// #define ARMA_LAPACK_NOEXCEPT +//// Uncomment the above line if you require LAPACK functions to have the 'noexcept' specification + #define ARMA_USE_FORTRAN_HIDDEN_ARGS //// Comment out the above line to call BLAS and LAPACK functions without using so-called "hidden" arguments. //// Fortran functions (compiled without a BIND(C) declaration) that have char arguments @@ -82,29 +122,17 @@ //// These "hidden" arguments are typically tacked onto the end of function definitions. // #define ARMA_USE_TBB_ALLOC -//// Uncomment the above line if you want to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() +//// Uncomment the above line to use Intel TBB scalable_malloc() and scalable_free() instead of standard malloc() and free() // #define ARMA_USE_MKL_ALLOC -//// Uncomment the above line if you want to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() +//// Uncomment the above line to use Intel MKL mkl_malloc() and mkl_free() instead of standard malloc() and free() // #define ARMA_USE_MKL_TYPES -//// Uncomment the above line if you want to use Intel MKL types for complex numbers. +//// Uncomment the above line to use Intel MKL types for complex numbers. //// You will need to include appropriate MKL headers before the Armadillo header. //// You may also need to enable or disable the following options: //// ARMA_BLAS_LONG, ARMA_BLAS_LONG_LONG, ARMA_USE_FORTRAN_HIDDEN_ARGS -#cmakedefine ARMA_USE_ATLAS -#define ARMA_ATLAS_INCLUDE_DIR ${ARMA_ATLAS_INCLUDE_DIR}/ -//// If you're using ATLAS and the compiler can't find cblas.h and/or clapack.h -//// uncomment the above define and specify the appropriate include directory. -//// Make sure the directory has a trailing / - -#if !defined(ARMA_USE_CXX11) -// #define ARMA_USE_CXX11 -//// Uncomment the above line to forcefully enable use of C++11 features (eg. initialiser lists). -//// Note that ARMA_USE_CXX11 is automatically enabled when a C++11 compiler is detected. -#endif - #if !defined(ARMA_USE_OPENMP) // #define ARMA_USE_OPENMP //// Uncomment the above line to forcefully enable use of OpenMP for parallelisation. @@ -114,33 +142,31 @@ #if !defined(ARMA_64BIT_WORD) // #define ARMA_64BIT_WORD //// Uncomment the above line if you require matrices/vectors capable of holding more than 4 billion elements. -//// Your machine and compiler must have support for 64 bit integers (eg. via "long" or "long long"). -//// Note that ARMA_64BIT_WORD is automatically enabled when a C++11 compiler is detected. +//// Note that ARMA_64BIT_WORD is automatically enabled when std::size_t has 64 bits and ARMA_32BIT_WORD is not defined. #endif -#if !defined(ARMA_USE_HDF5) -// #define ARMA_USE_HDF5 -//// Uncomment the above line to allow the ability to save and load matrices stored in HDF5 format; -//// the hdf5.h header file must be available on your system, -//// and you will need to link with the hdf5 library (eg. -lhdf5) +#if !defined(ARMA_OPTIMISE_BAND) + #define ARMA_OPTIMISE_BAND + //// Comment out the above line to disable optimised handling + //// of band matrices by solve() and chol() #endif -#if !defined(ARMA_OPTIMISE_SOLVE_BAND) - #define ARMA_OPTIMISE_SOLVE_BAND - //// Comment out the above line if you don't want optimised handling of band matrices by solve() +#if !defined(ARMA_OPTIMISE_SYM) + #define ARMA_OPTIMISE_SYM + //// Comment out the above line to disable optimised handling + //// of symmetric/hermitian matrices by various functions: + //// solve(), inv(), pinv(), expmat(), logmat(), sqrtmat(), rcond(), rank() #endif -#if !defined(ARMA_OPTIMISE_SOLVE_SYMPD) - #define ARMA_OPTIMISE_SOLVE_SYMPD - //// Comment out the above line if you don't want optimised handling of symmetric/hermitian positive definite matrices by solve() +#if !defined(ARMA_OPTIMISE_INVEXPR) + #define ARMA_OPTIMISE_INVEXPR + //// Comment out the above line to disable optimised handling + //// of inv() and inv_sympd() within compound expressions #endif -#cmakedefine ARMA_USE_HDF5_ALT -#if defined(ARMA_USE_HDF5_ALT) && defined(ARMA_USE_WRAPPER) - #undef ARMA_USE_HDF5 - #define ARMA_USE_HDF5 - - #define ARMA_HDF5_INCLUDE_DIR ${ARMA_HDF5_INCLUDE_DIR}/ +#if !defined(ARMA_CHECK_NONFINITE) + #define ARMA_CHECK_NONFINITE + //// Comment out the above line to disable checking for nonfinite matrices #endif #if !defined(ARMA_MAT_PREALLOC) @@ -152,29 +178,34 @@ //// change the number to the size of your vectors. #if !defined(ARMA_OPENMP_THRESHOLD) - #define ARMA_OPENMP_THRESHOLD 240 + #define ARMA_OPENMP_THRESHOLD 320 #endif //// The minimum number of elements in a matrix to allow OpenMP based parallelisation; //// it must be an integer that is at least 1. #if !defined(ARMA_OPENMP_THREADS) - #define ARMA_OPENMP_THREADS 10 + #define ARMA_OPENMP_THREADS 8 #endif //// The maximum number of threads to use for OpenMP based parallelisation; //// it must be an integer that is at least 1. // #define ARMA_NO_DEBUG -//// Uncomment the above line if you want to disable all run-time checks. -//// This will result in faster code, but you first need to make sure that your code runs correctly! -//// We strongly recommend to have the run-time checks enabled during development, -//// as this greatly aids in finding mistakes in your code, and hence speeds up development. -//// We recommend that run-time checks be disabled _only_ for the shipped version of your program. +//// Uncomment the above line to disable all run-time checks. NOT RECOMMENDED. +//// It is strongly recommended that run-time checks are enabled during development, +//// as this greatly aids in finding mistakes in your code. // #define ARMA_EXTRA_DEBUG -//// Uncomment the above line if you want to see the function traces of how Armadillo evaluates expressions. +//// Uncomment the above line to see the function traces of how Armadillo evaluates expressions. //// This is mainly useful for debugging of the library. +#if defined(ARMA_EXTRA_DEBUG) + #undef ARMA_NO_DEBUG + #undef ARMA_WARN_LEVEL + #define ARMA_WARN_LEVEL 3 +#endif + + #if defined(ARMA_DEFAULT_OSTREAM) #pragma message ("WARNING: support for ARMA_DEFAULT_OSTREAM is deprecated and will be removed;") #pragma message ("WARNING: use ARMA_COUT_STREAM and ARMA_CERR_STREAM instead") @@ -200,13 +231,12 @@ #endif -#if !defined(ARMA_PRINT_ERRORS) -#define ARMA_PRINT_ERRORS -//// Comment out the above line if you don't want errors and warnings printed (eg. failed decompositions) -#endif - -#if !defined(ARMA_PRINT_HDF5_ERRORS) -// #define ARMA_PRINT_HDF5_ERRORS +#if !defined(ARMA_PRINT_EXCEPTIONS) + // #define ARMA_PRINT_EXCEPTIONS + #if defined(ARMA_PRINT_EXCEPTIONS_INTERNAL) + #undef ARMA_PRINT_EXCEPTIONS + #define ARMA_PRINT_EXCEPTIONS + #endif #endif #if defined(ARMA_DONT_USE_LAPACK) @@ -232,62 +262,84 @@ #if defined(ARMA_DONT_USE_ATLAS) #undef ARMA_USE_ATLAS - #undef ARMA_ATLAS_INCLUDE_DIR +#endif + +#if defined(ARMA_DONT_USE_HDF5) + #undef ARMA_USE_HDF5 +#endif + +#if defined(ARMA_DONT_USE_FFTW3) + #undef ARMA_USE_FFTW3 #endif #if defined(ARMA_DONT_USE_WRAPPER) #undef ARMA_USE_WRAPPER - #undef ARMA_USE_HDF5_ALT #endif #if defined(ARMA_DONT_USE_FORTRAN_HIDDEN_ARGS) #undef ARMA_USE_FORTRAN_HIDDEN_ARGS #endif -#if defined(ARMA_DONT_USE_CXX11) - #undef ARMA_USE_CXX11 - #undef ARMA_USE_EXTERN_CXX11_RNG +#if !defined(ARMA_DONT_USE_STD_MUTEX) + // #define ARMA_DONT_USE_STD_MUTEX + //// Uncomment the above line to disable use of std::mutex +#endif + +// for compatibility with earlier versions of Armadillo +#if defined(ARMA_DONT_USE_CXX11_MUTEX) + #pragma message ("WARNING: support for ARMA_DONT_USE_CXX11_MUTEX is deprecated and will be removed;") + #pragma message ("WARNING: use ARMA_DONT_USE_STD_MUTEX instead") + #undef ARMA_DONT_USE_STD_MUTEX + #define ARMA_DONT_USE_STD_MUTEX #endif #if defined(ARMA_DONT_USE_OPENMP) #undef ARMA_USE_OPENMP #endif -#if defined(ARMA_USE_WRAPPER) - #if defined(ARMA_USE_CXX11) - #if !defined(ARMA_USE_EXTERN_CXX11_RNG) - #cmakedefine ARMA_USE_EXTERN_CXX11_RNG - #endif - #endif +#if defined(ARMA_32BIT_WORD) + #undef ARMA_64BIT_WORD #endif -#if defined(ARMA_DONT_USE_EXTERN_CXX11_RNG) - #undef ARMA_USE_EXTERN_CXX11_RNG +#if defined(ARMA_DONT_OPTIMISE_BAND) || defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) + #undef ARMA_OPTIMISE_BAND #endif -#if defined(ARMA_32BIT_WORD) - #undef ARMA_64BIT_WORD +#if defined(ARMA_DONT_OPTIMISE_SYM) || defined(ARMA_DONT_OPTIMISE_SYMPD) || defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) + #undef ARMA_OPTIMISE_SYM #endif -#if defined(ARMA_DONT_USE_HDF5) - #undef ARMA_USE_HDF5 - #undef ARMA_USE_HDF5_ALT +#if defined(ARMA_DONT_OPTIMISE_INVEXPR) + #undef ARMA_OPTIMISE_INVEXPR #endif -#if defined(ARMA_DONT_OPTIMISE_SOLVE_BAND) - #undef ARMA_OPTIMISE_SOLVE_BAND +#if defined(ARMA_DONT_CHECK_NONFINITE) + #undef ARMA_CHECK_NONFINITE #endif -#if defined(ARMA_DONT_OPTIMISE_SOLVE_SYMPD) - #undef ARMA_OPTIMISE_SOLVE_SYMPD +#if defined(ARMA_DONT_PRINT_ERRORS) + #pragma message ("INFO: support for ARMA_DONT_PRINT_ERRORS option has been removed") + + #if defined(ARMA_PRINT_EXCEPTIONS) + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL and ARMA_DONT_PRINT_EXCEPTIONS options instead") + #else + #pragma message ("INFO: suggest to use ARMA_WARN_LEVEL option instead") + #endif + + #pragma message ("INFO: see the documentation for details") #endif -#if defined(ARMA_DONT_PRINT_ERRORS) - #undef ARMA_PRINT_ERRORS +#if defined(ARMA_DONT_PRINT_EXCEPTIONS) + #undef ARMA_PRINT_EXCEPTIONS +#endif + +#if !defined(ARMA_DONT_ZERO_INIT) + // #define ARMA_DONT_ZERO_INIT + //// Uncomment the above line to disable initialising elements to zero during construction of dense matrices and cubes #endif -#if defined(ARMA_DONT_PRINT_HDF5_ERRORS) - #undef ARMA_PRINT_HDF5_ERRORS +#if defined(ARMA_NO_CRIPPLED_LAPACK) + #undef ARMA_CRIPPLED_LAPACK #endif diff --git a/src/armadillo_bits/constants.hpp b/src/armadillo_bits/constants.hpp index 2b39444e..9adf9ea6 100644 --- a/src/armadillo_bits/constants.hpp +++ b/src/armadillo_bits/constants.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,25 +29,18 @@ namespace priv template static typename arma_real_only::result - nan(typename arma_real_only::result* junk = 0) + nan(typename arma_real_only::result* junk = nullptr) { arma_ignore(junk); - if(std::numeric_limits::has_quiet_NaN) - { - return std::numeric_limits::quiet_NaN(); - } - else - { - return eT(0); - } + return (std::numeric_limits::has_quiet_NaN) ? eT(std::numeric_limits::quiet_NaN()) : eT(0); } template static typename arma_cx_only::result - nan(typename arma_cx_only::result* junk = 0) + nan(typename arma_cx_only::result* junk = nullptr) { arma_ignore(junk); @@ -58,7 +53,7 @@ namespace priv template static typename arma_integral_only::result - nan(typename arma_integral_only::result* junk = 0) + nan(typename arma_integral_only::result* junk = nullptr) { arma_ignore(junk); @@ -69,25 +64,18 @@ namespace priv template static typename arma_real_only::result - inf(typename arma_real_only::result* junk = 0) + inf(typename arma_real_only::result* junk = nullptr) { arma_ignore(junk); - if(std::numeric_limits::has_infinity) - { - return std::numeric_limits::infinity(); - } - else - { - return std::numeric_limits::max(); - } + return (std::numeric_limits::has_infinity) ? eT(std::numeric_limits::infinity()) : eT(std::numeric_limits::max()); } template static typename arma_cx_only::result - inf(typename arma_cx_only::result* junk = 0) + inf(typename arma_cx_only::result* junk = nullptr) { arma_ignore(junk); @@ -96,24 +84,23 @@ namespace priv return eT( Datum_helper::inf(), Datum_helper::inf() ); } - + template static typename arma_integral_only::result - inf(typename arma_integral_only::result* junk = 0) + inf(typename arma_integral_only::result* junk = nullptr) { arma_ignore(junk); return std::numeric_limits::max(); } - }; } //! various constants. -//! Physical constants taken from NIST 2014 CODATA values, and some from WolframAlpha (values provided as of 2009-06-23) +//! Physical constants taken from NIST 2018 CODATA values, and some from WolframAlpha (values provided as of 2009-06-23) //! http://physics.nist.gov/cuu/Constants //! http://www.wolframalpha.com //! See also http://en.wikipedia.org/wiki/Physical_constant @@ -124,17 +111,19 @@ class Datum { public: - static const eT pi; //!< ratio of any circle's circumference to its diameter - static const eT e; //!< base of the natural logarithm - static const eT euler; //!< Euler's constant, aka Euler-Mascheroni constant - static const eT gratio; //!< golden ratio - static const eT sqrt2; //!< square root of 2 - static const eT sqrt2pi; //!< square root of 2*pi - static const eT eps; //!< the difference between 1 and the least value greater than 1 that is representable - static const eT log_min; //!< log of the minimum representable value - static const eT log_max; //!< log of the maximum representable value - static const eT nan; //!< "not a number" - static const eT inf; //!< infinity + static const eT pi; //!< ratio of any circle's circumference to its diameter + static const eT tau; //!< ratio of any circle's circumference to its radius (replacement of 2*pi) + static const eT e; //!< base of the natural logarithm + static const eT euler; //!< Euler's constant, aka Euler-Mascheroni constant + static const eT gratio; //!< golden ratio + static const eT sqrt2; //!< square root of 2 + static const eT sqrt2pi; //!< square root of 2*pi + static const eT log_sqrt2pi; //!< log of square root of 2*pi + static const eT eps; //!< the difference between 1 and the least value greater than 1 that is representable + static const eT log_min; //!< log of the minimum representable value + static const eT log_max; //!< log of the maximum representable value + static const eT nan; //!< "not a number" + static const eT inf; //!< infinity // @@ -173,52 +162,54 @@ class Datum // the long lengths of the constants are for future support of "long double" // and any smart compiler that does high-precision computation at compile-time -template const eT Datum::pi = eT(3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679); -template const eT Datum::e = eT(2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274); -template const eT Datum::euler = eT(0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495); -template const eT Datum::gratio = eT(1.6180339887498948482045868343656381177203091798057628621354486227052604628189024497072072041893911374); -template const eT Datum::sqrt2 = eT(1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727); -template const eT Datum::sqrt2pi = eT(2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749465958383780572661); -template const eT Datum::eps = std::numeric_limits::epsilon(); -template const eT Datum::log_min = std::log(std::numeric_limits::min()); -template const eT Datum::log_max = std::log(std::numeric_limits::max()); -template const eT Datum::nan = priv::Datum_helper::nan(); -template const eT Datum::inf = priv::Datum_helper::inf(); +template const eT Datum::pi = eT(3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679); +template const eT Datum::tau = eT(6.2831853071795864769252867665590057683943387987502116419498891846156328125724179972560696506842341359); +template const eT Datum::e = eT(2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274); +template const eT Datum::euler = eT(0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495); +template const eT Datum::gratio = eT(1.6180339887498948482045868343656381177203091798057628621354486227052604628189024497072072041893911374); +template const eT Datum::sqrt2 = eT(1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727); +template const eT Datum::sqrt2pi = eT(2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749465958383780572661); +template const eT Datum::log_sqrt2pi = eT(0.9189385332046727417803297364056176398613974736377834128171515404827656959272603976947432986359541976); +template const eT Datum::eps = std::numeric_limits::epsilon(); +template const eT Datum::log_min = std::log(std::numeric_limits::min()); +template const eT Datum::log_max = std::log(std::numeric_limits::max()); +template const eT Datum::nan = priv::Datum_helper::nan(); +template const eT Datum::inf = priv::Datum_helper::inf(); -template const eT Datum::m_u = eT(1.660539040e-27); -template const eT Datum::N_A = eT(6.022140857e23); -template const eT Datum::k = eT(1.38064852e-23); -template const eT Datum::k_evk = eT(8.6173303e-5); -template const eT Datum::a_0 = eT(0.52917721067e-10); -template const eT Datum::mu_B = eT(927.4009994e-26); -template const eT Datum::Z_0 = eT(376.730313461771); -template const eT Datum::G_0 = eT(7.7480917310e-5); -template const eT Datum::k_e = eT(8.9875517873681764e9); -template const eT Datum::eps_0 = eT(8.85418781762039e-12); -template const eT Datum::m_e = eT(9.10938356e-31); -template const eT Datum::eV = eT(1.6021766208e-19); -template const eT Datum::ec = eT(1.6021766208e-19); -template const eT Datum::F = eT(96485.33289); -template const eT Datum::alpha = eT(7.2973525664e-3); -template const eT Datum::alpha_inv = eT(137.035999139); -template const eT Datum::K_J = eT(483597.8525e9); -template const eT Datum::mu_0 = eT(1.25663706143592e-06); -template const eT Datum::phi_0 = eT(2.067833667e-15); -template const eT Datum::R = eT(8.3144598); -template const eT Datum::G = eT(6.67408e-11); -template const eT Datum::h = eT(6.626070040e-34); -template const eT Datum::h_bar = eT(1.054571800e-34); -template const eT Datum::m_p = eT(1.672621898e-27); -template const eT Datum::R_inf = eT(10973731.568508); +template const eT Datum::m_u = eT(1.66053906660e-27); +template const eT Datum::N_A = eT(6.02214076e23); +template const eT Datum::k = eT(1.380649e-23); +template const eT Datum::k_evk = eT(8.617333262e-5); +template const eT Datum::a_0 = eT(5.29177210903e-11); +template const eT Datum::mu_B = eT(9.2740100783e-24); +template const eT Datum::Z_0 = eT(376.730313668); +template const eT Datum::G_0 = eT(7.748091729e-5); +template const eT Datum::k_e = eT(8.9875517923e9); +template const eT Datum::eps_0 = eT(8.8541878128e-12); +template const eT Datum::m_e = eT(9.1093837015e-31); +template const eT Datum::eV = eT(1.602176634e-19); +template const eT Datum::ec = eT(1.602176634e-19); +template const eT Datum::F = eT(96485.33212); +template const eT Datum::alpha = eT(7.2973525693e-3); +template const eT Datum::alpha_inv = eT(137.035999084); +template const eT Datum::K_J = eT(483597.8484e9); +template const eT Datum::mu_0 = eT(1.25663706212e-6); +template const eT Datum::phi_0 = eT(2.067833848e-15); +template const eT Datum::R = eT(8.314462618); +template const eT Datum::G = eT(6.67430e-11); +template const eT Datum::h = eT(6.62607015e-34); +template const eT Datum::h_bar = eT(1.054571817e-34); +template const eT Datum::m_p = eT(1.67262192369e-27); +template const eT Datum::R_inf = eT(10973731.568160); template const eT Datum::c_0 = eT(299792458.0); -template const eT Datum::sigma = eT(5.670367e-8); -template const eT Datum::R_k = eT(25812.8074555); -template const eT Datum::b = eT(2.8977729e-3); +template const eT Datum::sigma = eT(5.670374419e-8); +template const eT Datum::R_k = eT(25812.80745); +template const eT Datum::b = eT(2.897771955e-3); -typedef Datum fdatum; -typedef Datum datum; +typedef Datum fdatum; +typedef Datum datum; @@ -228,62 +219,40 @@ namespace priv template static - arma_inline + constexpr typename arma_real_only::result - most_neg(typename arma_real_only::result* junk = 0) + most_neg() { - arma_ignore(junk); - - if(std::numeric_limits::has_infinity) - { - return -(std::numeric_limits::infinity()); - } - else - { - return -(std::numeric_limits::max()); - } + return (std::numeric_limits::has_infinity) ? -(std::numeric_limits::infinity()) : std::numeric_limits::lowest(); } template static - arma_inline + constexpr typename arma_integral_only::result - most_neg(typename arma_integral_only::result* junk = 0) + most_neg() { - arma_ignore(junk); - - return std::numeric_limits::min(); + return std::numeric_limits::lowest(); } template static - arma_inline + constexpr typename arma_real_only::result - most_pos(typename arma_real_only::result* junk = 0) + most_pos() { - arma_ignore(junk); - - if(std::numeric_limits::has_infinity) - { - return std::numeric_limits::infinity(); - } - else - { - return std::numeric_limits::max(); - } + return (std::numeric_limits::has_infinity) ? std::numeric_limits::infinity() : std::numeric_limits::max(); } template static - arma_inline + constexpr typename arma_integral_only::result - most_pos(typename arma_integral_only::result* junk = 0) + most_pos() { - arma_ignore(junk); - return std::numeric_limits::max(); } diff --git a/src/armadillo_bits/constants_old.hpp b/src/armadillo_bits/constants_old.hpp index 89571a09..a2bc0466 100644 --- a/src/armadillo_bits/constants_old.hpp +++ b/src/armadillo_bits/constants_old.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,137 +30,54 @@ class Math { public: - // the long lengths of the constants are for future support of "long double" - // and any smart compiler that does high-precision computation at compile-time - - //! ratio of any circle's circumference to its diameter - arma_deprecated static eT pi() { return eT(3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679); } // use datum::pi instead - - //! base of the natural logarithm - arma_deprecated static eT e() { return eT(2.7182818284590452353602874713526624977572470936999595749669676277240766303535475945713821785251664274); } // use datum::e instead - - //! Euler's constant, aka Euler-Mascheroni constant - arma_deprecated static eT euler() { return eT(0.5772156649015328606065120900824024310421593359399235988057672348848677267776646709369470632917467495); } // use datum::euler instead - - //! golden ratio - arma_deprecated static eT gratio() { return eT(1.6180339887498948482045868343656381177203091798057628621354486227052604628189024497072072041893911374); } // use datum::gratio instead - - //! square root of 2 - arma_deprecated static eT sqrt2() { return eT(1.4142135623730950488016887242096980785696718753769480731766797379907324784621070388503875343276415727); } // use datum::sqrt2 instead - - //! the difference between 1 and the least value greater than 1 that is representable - arma_deprecated static eT eps() { return std::numeric_limits::epsilon(); } // use datum::eps instead - - //! log of the minimum representable value - arma_deprecated static eT log_min() { static const eT out = std::log(std::numeric_limits::min()); return out; } // use datum::log_min instead - - //! log of the maximum representable value - arma_deprecated static eT log_max() { static const eT out = std::log(std::numeric_limits::max()); return out; } // use datum::log_max instead - - //! "not a number" - arma_deprecated static eT nan() { return priv::Datum_helper::nan(); } // use datum::nan instead - - //! infinity - arma_deprecated static eT inf() { return priv::Datum_helper::inf(); } // use datum::inf instead + arma_frown("use datum::pi instead") static eT pi() { return eT(Datum::pi); } + arma_frown("use datum::e instead") static eT e() { return eT(Datum::e); } + arma_frown("use datum::euler instead") static eT euler() { return eT(Datum::euler); } + arma_frown("use datum::gratio instead") static eT gratio() { return eT(Datum::gratio); } + arma_frown("use datum::sqrt2 instead") static eT sqrt2() { return eT(Datum::sqrt2); } + arma_frown("use datum::eps instead") static eT eps() { return eT(Datum::eps); } + arma_frown("use datum::log_min instead") static eT log_min() { return eT(Datum::log_min); } + arma_frown("use datum::log_max instead") static eT log_max() { return eT(Datum::log_max); } + arma_frown("use datum::nan instead") static eT nan() { return eT(Datum::nan); } + arma_frown("use datum::inf instead") static eT inf() { return eT(Datum::inf); } }; -//! Physical constants taken from NIST 2010 CODATA values, and some from WolframAlpha (values provided as of 2009-06-23) -//! http://physics.nist.gov/cuu/Constants -//! http://www.wolframalpha.com -//! See also http://en.wikipedia.org/wiki/Physical_constant template class Phy { public: - //! atomic mass constant (in kg) - arma_deprecated static eT m_u() { return eT(1.660539040e-27); } - - //! Avogadro constant - arma_deprecated static eT N_A() { return eT(6.022140857e23); } - - //! Boltzmann constant (in joules per kelvin) - arma_deprecated static eT k() { return eT(1.38064852e-23); } - - //! Boltzmann constant (in eV/K) - arma_deprecated static eT k_evk() { return eT(8.6173303e-5); } - - //! Bohr radius (in meters) - arma_deprecated static eT a_0() { return eT(0.52917721067e-10); } - - //! Bohr magneton - arma_deprecated static eT mu_B() { return eT(927.4009994e-26); } - - //! characteristic impedance of vacuum (in ohms) - arma_deprecated static eT Z_0() { return eT(376.730313461771); } - - //! conductance quantum (in siemens) - arma_deprecated static eT G_0() { return eT(7.7480917310e-5); } - - //! Coulomb's constant (in meters per farad) - arma_deprecated static eT k_e() { return eT(8.9875517873681764e9); } - - //! electric constant (in farads per meter) - arma_deprecated static eT eps_0() { return eT(8.85418781762039e-12); } - - //! electron mass (in kg) - arma_deprecated static eT m_e() { return eT(9.10938356e-31); } - - //! electron volt (in joules) - arma_deprecated static eT eV() { return eT(1.6021766208e-19); } - - //! elementary charge (in coulombs) - arma_deprecated static eT e() { return eT(1.6021766208e-19); } - - //! Faraday constant (in coulombs) - arma_deprecated static eT F() { return eT(96485.33289); } - - //! fine-structure constant - arma_deprecated static eT alpha() { return eT(7.2973525664e-3); } - - //! inverse fine-structure constant - arma_deprecated static eT alpha_inv() { return eT(137.035999139); } - - //! Josephson constant - arma_deprecated static eT K_J() { return eT(483597.8525e9); } - - //! magnetic constant (in henries per meter) - arma_deprecated static eT mu_0() { return eT(1.25663706143592e-06); } - - //! magnetic flux quantum (in webers) - arma_deprecated static eT phi_0() { return eT(2.067833667e-15); } - - //! molar gas constant (in joules per mole kelvin) - arma_deprecated static eT R() { return eT(8.3144598); } - - //! Newtonian constant of gravitation (in newton square meters per kilogram squared) - arma_deprecated static eT G() { return eT(6.67408e-11); } - - //! Planck constant (in joule seconds) - arma_deprecated static eT h() { return eT(6.626070040e-34); } - - //! Planck constant over 2 pi, aka reduced Planck constant (in joule seconds) - arma_deprecated static eT h_bar() { return eT(1.054571800e-34); } - - //! proton mass (in kg) - arma_deprecated static eT m_p() { return eT(1.672621898e-27); } - - //! Rydberg constant (in reciprocal meters) - arma_deprecated static eT R_inf() { return eT(10973731.568508); } - - //! speed of light in vacuum (in meters per second) - arma_deprecated static eT c_0() { return eT(299792458.0); } - - //! Stefan-Boltzmann constant - arma_deprecated static eT sigma() { return eT(5.670367e-8); } - - //! von Klitzing constant (in ohms) - arma_deprecated static eT R_k() { return eT(25812.8074555); } - - //! Wien wavelength displacement law constant - arma_deprecated static eT b() { return eT(2.8977729e-3); } + arma_deprecated static eT m_u() { return eT(Datum::m_u); } + arma_deprecated static eT N_A() { return eT(Datum::N_A); } + arma_deprecated static eT k() { return eT(Datum::k); } + arma_deprecated static eT k_evk() { return eT(Datum::k_evk); } + arma_deprecated static eT a_0() { return eT(Datum::a_0); } + arma_deprecated static eT mu_B() { return eT(Datum::mu_B); } + arma_deprecated static eT Z_0() { return eT(Datum::Z_0); } + arma_deprecated static eT G_0() { return eT(Datum::G_0); } + arma_deprecated static eT k_e() { return eT(Datum::k_e); } + arma_deprecated static eT eps_0() { return eT(Datum::eps_0); } + arma_deprecated static eT m_e() { return eT(Datum::m_e); } + arma_deprecated static eT eV() { return eT(Datum::eV); } + arma_deprecated static eT e() { return eT(Datum::ec); } + arma_deprecated static eT F() { return eT(Datum::F); } + arma_deprecated static eT alpha() { return eT(Datum::alpha); } + arma_deprecated static eT alpha_inv() { return eT(Datum::alpha_inv); } + arma_deprecated static eT K_J() { return eT(Datum::K_J); } + arma_deprecated static eT mu_0() { return eT(Datum::mu_0); } + arma_deprecated static eT phi_0() { return eT(Datum::phi_0); } + arma_deprecated static eT R() { return eT(Datum::R); } + arma_deprecated static eT G() { return eT(Datum::G); } + arma_deprecated static eT h() { return eT(Datum::h); } + arma_deprecated static eT h_bar() { return eT(Datum::h_bar); } + arma_deprecated static eT m_p() { return eT(Datum::m_p); } + arma_deprecated static eT R_inf() { return eT(Datum::R_inf); } + arma_deprecated static eT c_0() { return eT(Datum::c_0); } + arma_deprecated static eT sigma() { return eT(Datum::sigma); } + arma_deprecated static eT R_k() { return eT(Datum::R_k); } + arma_deprecated static eT b() { return eT(Datum::b); } }; diff --git a/src/armadillo_bits/csv_name.hpp b/src/armadillo_bits/csv_name.hpp new file mode 100644 index 00000000..c6a1df5d --- /dev/null +++ b/src/armadillo_bits/csv_name.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +namespace csv_opts + { + typedef unsigned int flag_type; + + struct opts + { + const flag_type flags; + + inline constexpr explicit opts(const flag_type in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const flag_type in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 0) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr flag_type flag_none = flag_type(0 ); + static constexpr flag_type flag_trans = flag_type(1u << 0); + static constexpr flag_type flag_no_header = flag_type(1u << 1); + static constexpr flag_type flag_with_header = flag_type(1u << 2); + static constexpr flag_type flag_semicolon = flag_type(1u << 3); + static constexpr flag_type flag_strict = flag_type(1u << 4); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_trans : public opts { inline constexpr opts_trans() : opts(flag_trans ) {} }; + struct opts_no_header : public opts { inline constexpr opts_no_header() : opts(flag_no_header ) {} }; + struct opts_with_header : public opts { inline constexpr opts_with_header() : opts(flag_with_header) {} }; + struct opts_semicolon : public opts { inline constexpr opts_semicolon() : opts(flag_semicolon ) {} }; + struct opts_strict : public opts { inline constexpr opts_strict() : opts(flag_strict ) {} }; + + static constexpr opts_none none; + static constexpr opts_trans trans; + static constexpr opts_no_header no_header; + static constexpr opts_with_header with_header; + static constexpr opts_semicolon semicolon; + static constexpr opts_strict strict; + } + + +struct csv_name + { + typedef field header_type; + + const std::string filename; + const csv_opts::opts opts; + + header_type header_junk; + const header_type& header_ro; + header_type& header_rw; + + inline + csv_name(const std::string& in_filename) + : filename (in_filename ) + , opts (csv_opts::no_header) + , header_ro(header_junk ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::no_header + in_opts) + , header_ro(header_junk ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, field& in_header) + : filename (in_filename ) + , opts (csv_opts::with_header) + , header_ro(in_header ) + , header_rw(in_header ) + {} + + inline + csv_name(const std::string& in_filename, const field& in_header) + : filename (in_filename ) + , opts (csv_opts::with_header) + , header_ro(in_header ) + , header_rw(header_junk ) + {} + + inline + csv_name(const std::string& in_filename, field& in_header, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::with_header + in_opts) + , header_ro(in_header ) + , header_rw(in_header ) + {} + + inline + csv_name(const std::string& in_filename, const field& in_header, const csv_opts::opts& in_opts) + : filename (in_filename ) + , opts (csv_opts::with_header + in_opts) + , header_ro(in_header ) + , header_rw(header_junk ) + {} + }; + + +//! @} diff --git a/src/armadillo_bits/debug.hpp b/src/armadillo_bits/debug.hpp index 4c421430..7a0b95c2 100644 --- a/src/armadillo_bits/debug.hpp +++ b/src/armadillo_bits/debug.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,110 +21,98 @@ -template inline std::ostream& -arma_cout_stream(std::ostream* user_stream) +get_cout_stream() { - static std::ostream* cout_stream = &(ARMA_COUT_STREAM); - - if(user_stream != NULL) { cout_stream = user_stream; } - - return (*cout_stream); + return (ARMA_COUT_STREAM); } -template inline std::ostream& -arma_cerr_stream(std::ostream* user_stream) +get_cerr_stream() { - static std::ostream* cerr_stream = &(ARMA_CERR_STREAM); - - if(user_stream != NULL) { cerr_stream = user_stream; } - - return (*cerr_stream); + return (ARMA_CERR_STREAM); } +arma_deprecated inline -void -set_cout_stream(std::ostream& user_stream) +std::ostream& +get_stream_err1() { - arma_cout_stream(&user_stream); + return get_cerr_stream(); } +arma_deprecated inline -void -set_cerr_stream(std::ostream& user_stream) +std::ostream& +get_stream_err2() { - arma_cerr_stream(&user_stream); + return get_cerr_stream(); } +arma_frown("this function does nothing; instead use ARMA_COUT_STREAM or ARMA_WARN_LEVEL; see documentation") inline -std::ostream& -get_cout_stream() +void +set_cout_stream(const std::ostream&) { - return arma_cout_stream(NULL); } +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") inline -std::ostream& -get_cerr_stream() +void +set_cerr_stream(const std::ostream&) { - return arma_cerr_stream(NULL); } -//! do not use this function - it's deprecated and will be removed +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") inline -arma_deprecated void -set_stream_err1(std::ostream& user_stream) +set_stream_err1(const std::ostream&) { - set_cerr_stream(user_stream); } -//! do not use this function - it's deprecated and will be removed +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") inline -arma_deprecated void -set_stream_err2(std::ostream& user_stream) +set_stream_err2(const std::ostream&) { - set_cerr_stream(user_stream); } -//! do not use this function - it's deprecated and will be removed +template +arma_frown("this function does nothing; instead use ARMA_COUT_STREAM or ARMA_WARN_LEVEL; see documentation") inline -arma_deprecated std::ostream& -get_stream_err1() +arma_cout_stream(std::ostream*) { - return get_cerr_stream(); + return (ARMA_COUT_STREAM); } -//! do not use this function - it's deprecated and will be removed +template +arma_frown("this function does nothing; instead use ARMA_CERR_STREAM or ARMA_WARN_LEVEL; see documentation") inline -arma_deprecated std::ostream& -get_stream_err2() +arma_cerr_stream(std::ostream*) { - return get_cerr_stream(); + return (ARMA_CERR_STREAM); } @@ -135,7 +125,7 @@ static void arma_stop_logic_error(const T1& x) { - #if defined(ARMA_PRINT_ERRORS) + #if defined(ARMA_PRINT_EXCEPTIONS) { get_cerr_stream() << "\nerror: " << x << std::endl; } @@ -146,6 +136,36 @@ arma_stop_logic_error(const T1& x) +arma_cold +arma_noinline +static +void +arma_stop_logic_error(const char* x, const char* y) + { + arma_stop_logic_error( std::string(x) + std::string(y) ); + } + + + +//! print a message to get_cerr_stream() and throw out_of_range exception +template +arma_cold +arma_noinline +static +void +arma_stop_bounds_error(const T1& x) + { + #if defined(ARMA_PRINT_EXCEPTIONS) + { + get_cerr_stream() << "\nerror: " << x << std::endl; + } + #endif + + throw std::out_of_range( std::string(x) ); + } + + + //! print a message to get_cerr_stream() and throw bad_alloc exception template arma_cold @@ -154,7 +174,7 @@ static void arma_stop_bad_alloc(const T1& x) { - #if defined(ARMA_PRINT_ERRORS) + #if defined(ARMA_PRINT_EXCEPTIONS) { get_cerr_stream() << "\nerror: " << x << std::endl; } @@ -177,7 +197,7 @@ static void arma_stop_runtime_error(const T1& x) { - #if defined(ARMA_PRINT_ERRORS) + #if defined(ARMA_PRINT_EXCEPTIONS) { get_cerr_stream() << "\nerror: " << x << std::endl; } @@ -243,7 +263,7 @@ arma_print(const T1& x, const T2& y, const T3& z) // // arma_sigprint -//! print a message the the log stream with a preceding @ character. +//! print a message to the log stream with a preceding @ character. //! by default the log stream is cout. //! used for printing the signature of a function //! (see the arma_extra_debug_sigprint macro) @@ -313,17 +333,9 @@ arma_cold arma_noinline static void -arma_warn(const T1& x) +arma_warn(const T1& arg1) { - #if defined(ARMA_PRINT_ERRORS) - { - get_cerr_stream() << "\nwarning: " << x << '\n'; - } - #else - { - arma_ignore(x); - } - #endif + get_cerr_stream() << "\nwarning: " << arg1 << std::endl; } @@ -332,18 +344,9 @@ arma_cold arma_noinline static void -arma_warn(const T1& x, const T2& y) +arma_warn(const T1& arg1, const T2& arg2) { - #if defined(ARMA_PRINT_ERRORS) - { - get_cerr_stream() << "\nwarning: " << x << y << '\n'; - } - #else - { - arma_ignore(x); - arma_ignore(y); - } - #endif + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << std::endl; } @@ -352,19 +355,69 @@ arma_cold arma_noinline static void -arma_warn(const T1& x, const T2& y, const T3& z) +arma_warn(const T1& arg1, const T2& arg2, const T3& arg3) { - #if defined(ARMA_PRINT_ERRORS) - { - get_cerr_stream() << "\nwarning: " << x << y << z << '\n'; - } - #else - { - arma_ignore(x); - arma_ignore(y); - arma_ignore(z); - } - #endif + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << arg3 << std::endl; + } + + +template +arma_cold +arma_noinline +static +void +arma_warn(const T1& arg1, const T2& arg2, const T3& arg3, const T4& arg4) + { + get_cerr_stream() << "\nwarning: " << arg1 << arg2 << arg3 << arg4 << std::endl; + } + + + +// +// arma_warn_level + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2, const T3& arg3) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2,arg3); } + } + + +template +inline +void +arma_warn_level(const uword level, const T1& arg1, const T2& arg2, const T3& arg3, const T4& arg4) + { + constexpr uword config_level = (sword(ARMA_WARN_LEVEL) > 0) ? uword(ARMA_WARN_LEVEL) : uword(0); + + if((config_level > 0) && (level <= config_level)) { arma_warn(arg1,arg2,arg3,arg4); } } @@ -383,13 +436,42 @@ arma_check(const bool state, const T1& x) } -template +template arma_hot inline void -arma_check(const bool state, const T1& x, const T2& y) +arma_check(const bool state, const char* x, const Functor& fn) { - if(state) { arma_stop_logic_error( std::string(x) + std::string(y) ); } + if(state) { fn(); arma_stop_logic_error(x); } + } + + +arma_hot +inline +void +arma_check(const bool state, const char* x, const char* y) + { + if(state) { arma_stop_logic_error(x,y); } + } + + +template +arma_hot +inline +void +arma_check(const bool state, const char* x, const char* y, const Functor& fn) + { + if(state) { fn(); arma_stop_logic_error(x,y); } + } + + +template +arma_hot +inline +void +arma_check_bounds(const bool state, const T1& x) + { + if(state) { arma_stop_bounds_error(arma_str::str_wrapper(x)); } } @@ -485,8 +567,8 @@ arma_incompat_size_string(const subview_cube& Q, const Mat& A, const cha -arma_inline arma_hot +arma_inline void arma_assert_same_size(const uword A_n_rows, const uword A_n_cols, const uword B_n_rows, const uword B_n_cols, const char* x) { @@ -780,6 +862,28 @@ arma_assert_same_size(const subview_cube& A, const subview_cube& B, co +template +arma_hot +inline +void +arma_assert_same_size(const subview_cube& A, const ProxyCube& B, const char* x) + { + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + const uword B_n_rows = B.get_n_rows(); + const uword B_n_cols = B.get_n_cols(); + const uword B_n_slices = B.get_n_slices(); + + if( (A_n_rows != B_n_rows) || (A_n_cols != B_n_cols) || (A_n_slices != B_n_slices) ) + { + arma_stop_logic_error( arma_incompat_size_string(A_n_rows, A_n_cols, A_n_slices, B_n_rows, B_n_cols, B_n_slices, x) ); + } + } + + + //! stop if given cube proxies have different sizes template arma_hot @@ -804,7 +908,7 @@ arma_assert_same_size(const ProxyCube& A, const ProxyCube& B, const ch // -// functions for checking whether a cube or subcube can be interpreted as a matrix (i.e. single slice) +// functions for checking whether a cube or subcube can be interpreted as a matrix (ie. single slice) @@ -1179,6 +1283,7 @@ arma_assert_blas_size(const T1& A, const T2& B) +// TODO: remove support for ATLAS in next major version template arma_hot inline @@ -1201,6 +1306,7 @@ arma_assert_atlas_size(const T1& A) +// TODO: remove support for ATLAS in next major version template arma_hot inline @@ -1236,11 +1342,11 @@ arma_assert_atlas_size(const T1& A, const T2& B) #if defined(ARMA_NO_DEBUG) - #undef ARMA_EXTRA_DEBUG - #define arma_debug_print true ? (void)0 : arma_print #define arma_debug_warn true ? (void)0 : arma_warn + #define arma_debug_warn_level true ? (void)0 : arma_warn_level #define arma_debug_check true ? (void)0 : arma_check + #define arma_debug_check_bounds true ? (void)0 : arma_check_bounds #define arma_debug_set_error true ? (void)0 : arma_set_error #define arma_debug_assert_same_size true ? (void)0 : arma_assert_same_size #define arma_debug_assert_mul_size true ? (void)0 : arma_assert_mul_size @@ -1253,7 +1359,9 @@ arma_assert_atlas_size(const T1& A, const T2& B) #define arma_debug_print arma_print #define arma_debug_warn arma_warn + #define arma_debug_warn_level arma_warn_level #define arma_debug_check arma_check + #define arma_debug_check_bounds arma_check_bounds #define arma_debug_set_error arma_set_error #define arma_debug_assert_same_size arma_assert_same_size #define arma_debug_assert_mul_size arma_assert_mul_size @@ -1271,17 +1379,13 @@ arma_assert_atlas_size(const T1& A, const T2& B) #define arma_extra_debug_sigprint arma_sigprint(ARMA_FNSIG); arma_bktprint #define arma_extra_debug_sigprint_this arma_sigprint(ARMA_FNSIG); arma_thisprint #define arma_extra_debug_print arma_print - #define arma_extra_debug_warn arma_warn - #define arma_extra_debug_check arma_check - + #else #define arma_extra_debug_sigprint true ? (void)0 : arma_bktprint #define arma_extra_debug_sigprint_this true ? (void)0 : arma_thisprint #define arma_extra_debug_print true ? (void)0 : arma_print - #define arma_extra_debug_warn true ? (void)0 : arma_warn - #define arma_extra_debug_check true ? (void)0 : arma_check - + #endif @@ -1316,23 +1420,32 @@ arma_assert_atlas_size(const T1& A, const T2& B) << arma_version::major << '.' << arma_version::minor << '.' << arma_version::patch << " (" << nickname << ")\n"; - out << "@ arma_config::wrapper = " << arma_config::wrapper << '\n'; - out << "@ arma_config::cxx11 = " << arma_config::cxx11 << '\n'; - out << "@ arma_config::posix = " << arma_config::posix << '\n'; - out << "@ arma_config::openmp = " << arma_config::openmp << '\n'; - out << "@ arma_config::lapack = " << arma_config::lapack << '\n'; - out << "@ arma_config::blas = " << arma_config::blas << '\n'; - out << "@ arma_config::newarp = " << arma_config::newarp << '\n'; - out << "@ arma_config::arpack = " << arma_config::arpack << '\n'; - out << "@ arma_config::superlu = " << arma_config::superlu << '\n'; - out << "@ arma_config::atlas = " << arma_config::atlas << '\n'; - out << "@ arma_config::hdf5 = " << arma_config::hdf5 << '\n'; - out << "@ arma_config::good_comp = " << arma_config::good_comp << '\n'; - out << "@ arma_config::extra_code = " << arma_config::extra_code << '\n'; - out << "@ arma_config::hidden_args = " << arma_config::hidden_args << '\n'; - out << "@ arma_config::mat_prealloc = " << arma_config::mat_prealloc << '\n'; - out << "@ arma_config::mp_threshold = " << arma_config::mp_threshold << '\n'; - out << "@ arma_config::mp_threads = " << arma_config::mp_threads << '\n'; + out << "@ arma_config::wrapper = " << arma_config::wrapper << '\n'; + out << "@ arma_config::cxx14 = " << arma_config::cxx14 << '\n'; + out << "@ arma_config::cxx17 = " << arma_config::cxx17 << '\n'; + out << "@ arma_config::cxx20 = " << arma_config::cxx20 << '\n'; + out << "@ arma_config::std_mutex = " << arma_config::std_mutex << '\n'; + out << "@ arma_config::posix = " << arma_config::posix << '\n'; + out << "@ arma_config::openmp = " << arma_config::openmp << '\n'; + out << "@ arma_config::lapack = " << arma_config::lapack << '\n'; + out << "@ arma_config::blas = " << arma_config::blas << '\n'; + out << "@ arma_config::newarp = " << arma_config::newarp << '\n'; + out << "@ arma_config::arpack = " << arma_config::arpack << '\n'; + out << "@ arma_config::superlu = " << arma_config::superlu << '\n'; + out << "@ arma_config::atlas = " << arma_config::atlas << '\n'; + out << "@ arma_config::hdf5 = " << arma_config::hdf5 << '\n'; + out << "@ arma_config::good_comp = " << arma_config::good_comp << '\n'; + out << "@ arma_config::extra_code = " << arma_config::extra_code << '\n'; + out << "@ arma_config::hidden_args = " << arma_config::hidden_args << '\n'; + out << "@ arma_config::mat_prealloc = " << arma_config::mat_prealloc << '\n'; + out << "@ arma_config::mp_threshold = " << arma_config::mp_threshold << '\n'; + out << "@ arma_config::mp_threads = " << arma_config::mp_threads << '\n'; + out << "@ arma_config::optimise_band = " << arma_config::optimise_band << '\n'; + out << "@ arma_config::optimise_sym = " << arma_config::optimise_sym << '\n'; + out << "@ arma_config::optimise_invexpr = " << arma_config::optimise_invexpr << '\n'; + out << "@ arma_config::check_nonfinite = " << arma_config::check_nonfinite << '\n'; + out << "@ arma_config::zero_init = " << arma_config::zero_init << '\n'; + out << "@ arma_config::fast_math = " << arma_config::fast_math << '\n'; out << "@ sizeof(void*) = " << sizeof(void*) << '\n'; out << "@ sizeof(int) = " << sizeof(int) << '\n'; out << "@ sizeof(long) = " << sizeof(long) << '\n'; diff --git a/src/armadillo_bits/def_arpack.hpp b/src/armadillo_bits/def_arpack.hpp index 0278fed5..5bbbb7f9 100644 --- a/src/armadillo_bits/def_arpack.hpp +++ b/src/armadillo_bits/def_arpack.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,7 +16,7 @@ // ------------------------------------------------------------------------ -#ifdef ARMA_USE_ARPACK +#if defined(ARMA_USE_ARPACK) // I'm not sure this is necessary. #if !defined(ARMA_BLAS_CAPITALS) diff --git a/src/armadillo_bits/def_atlas.hpp b/src/armadillo_bits/def_atlas.hpp index 67800182..e410d9b0 100644 --- a/src/armadillo_bits/def_atlas.hpp +++ b/src/armadillo_bits/def_atlas.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,120 +16,63 @@ // ------------------------------------------------------------------------ -#ifdef ARMA_USE_ATLAS +// TODO: remove support for ATLAS in next major version + +#if defined(ARMA_USE_ATLAS) + +typedef enum + { + atlas_CblasRowMajor = 101, + atlas_CblasColMajor = 102 + } + atlas_CBLAS_LAYOUT; -//! \namespace atlas namespace for ATLAS functions (imported from the global namespace) -namespace atlas +typedef enum { - using ::CblasColMajor; - using ::CblasNoTrans; - using ::CblasTrans; - using ::CblasConjTrans; - using ::CblasLower; - using ::CblasUpper; + atlas_CblasNoTrans = 111, + atlas_CblasTrans = 112, + atlas_CblasConjTrans = 113 + } + atlas_CBLAS_TRANS; - #if defined(ARMA_USE_WRAPPER) - extern "C" - { - float wrapper_cblas_sasum(const int N, const float *X, const int incX); - double wrapper_cblas_dasum(const int N, const double *X, const int incX); - - float wrapper_cblas_snrm2(const int N, const float *X, const int incX); - double wrapper_cblas_dnrm2(const int N, const double *X, const int incX); - - float wrapper_cblas_sdot(const int N, const float *X, const int incX, const float *Y, const int incY); - double wrapper_cblas_ddot(const int N, const double *X, const int incX, const double *Y, const int incY); - - void wrapper_cblas_cdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); - void wrapper_cblas_zdotu_sub(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); - - - void wrapper_cblas_sgemv(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const float alpha, - const float *A, const int lda, const float *X, const int incX, const float beta, float *Y, const int incY); - - void wrapper_cblas_dgemv(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const double alpha, - const double *A, const int lda, const double *X, const int incX, const double beta, double *Y, const int incY); - - void wrapper_cblas_cgemv(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const void *alpha, - const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); - - void wrapper_cblas_zgemv(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const int M, const int N, const void *alpha, - const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); - - - - void wrapper_cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, - const int M, const int N, const int K, const float alpha, - const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc); - - void wrapper_cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, - const int M, const int N, const int K, const double alpha, - const double *A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc); - - void wrapper_cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, - const int M, const int N, const int K, const void *alpha, - const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); - - void wrapper_cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_TRANSPOSE TransB, - const int M, const int N, const int K, const void *alpha, - const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); - - - - void wrapper_cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, - const int N, const int K, const float alpha, - const float *A, const int lda, const float beta, float *C, const int ldc); - - void wrapper_cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, - const int N, const int K, const double alpha, - const double *A, const int lda, const double beta, double *C, const int ldc); - - - - void wrapper_cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, - const int N, const int K, const float alpha, - const void *A, const int lda, const float beta, void *C, const int ldc); - - void wrapper_cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, - const int N, const int K, const double alpha, - const void *A, const int lda, const double beta, void *C, const int ldc); - - - - int wrapper_clapack_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N, float *A, const int lda, int *ipiv); - int wrapper_clapack_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N, double *A, const int lda, int *ipiv); - int wrapper_clapack_cgetrf(const enum CBLAS_ORDER Order, const int M, const int N, void *A, const int lda, int *ipiv); - int wrapper_clapack_zgetrf(const enum CBLAS_ORDER Order, const int M, const int N, void *A, const int lda, int *ipiv); - - int wrapper_clapack_sgetri(const enum CBLAS_ORDER Order, const int N, float *A, const int lda, const int *ipiv); - int wrapper_clapack_dgetri(const enum CBLAS_ORDER Order, const int N, double *A, const int lda, const int *ipiv); - int wrapper_clapack_cgetri(const enum CBLAS_ORDER Order, const int N, void *A, const int lda, const int *ipiv); - int wrapper_clapack_zgetri(const enum CBLAS_ORDER Order, const int N, void *A, const int lda, const int *ipiv); +typedef enum + { + atlas_CblasUpper = 121, + atlas_CblasLower = 122 + } + atlas_CBLAS_UPLO; + - int wrapper_clapack_sgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, float *A, const int lda, int *ipiv, float *B, const int ldb); - int wrapper_clapack_dgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, double *A, const int lda, int *ipiv, double *B, const int ldb); - int wrapper_clapack_cgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, void *A, const int lda, int *ipiv, void *B, const int ldb); - int wrapper_clapack_zgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS, void *A, const int lda, int *ipiv, void *B, const int ldb); - - - - int wrapper_clapack_spotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, float *A, const int lda); - int wrapper_clapack_dpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, double *A, const int lda); - int wrapper_clapack_cpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, void *A, const int lda); - int wrapper_clapack_zpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, void *A, const int lda); - - int wrapper_clapack_spotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, float *A, const int lda); - int wrapper_clapack_dpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, double *A, const int lda); - int wrapper_clapack_cpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, void *A, const int lda); - int wrapper_clapack_zpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, void *A, const int lda); - - int wrapper_clapack_sposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, float *A, const int lda, float *B, const int ldb); - int wrapper_clapack_dposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, double *A, const int lda, double *B, const int ldb); - int wrapper_clapack_cposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, void *A, const int lda, void *B, const int ldb); - int wrapper_clapack_zposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, void *A, const int lda, void *B, const int ldb); - } - #endif +extern "C" + { + float arma_wrapper(cblas_sasum)(const int N, const float *X, const int incX); + double arma_wrapper(cblas_dasum)(const int N, const double *X, const int incX); + + float arma_wrapper(cblas_snrm2)(const int N, const float *X, const int incX); + double arma_wrapper(cblas_dnrm2)(const int N, const double *X, const int incX); + + float arma_wrapper(cblas_sdot)(const int N, const float *X, const int incX, const float *Y, const int incY); + double arma_wrapper(cblas_ddot)(const int N, const double *X, const int incX, const double *Y, const int incY); + + void arma_wrapper(cblas_cdotu_sub)(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); + void arma_wrapper(cblas_zdotu_sub)(const int N, const void *X, const int incX, const void *Y, const int incY, void *dotu); + + void arma_wrapper(cblas_sgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const float alpha, const float *A, const int lda, const float *X, const int incX, const float beta, float *Y, const int incY); + void arma_wrapper(cblas_dgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const double alpha, const double *A, const int lda, const double *X, const int incX, const double beta, double *Y, const int incY); + void arma_wrapper(cblas_cgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const void *alpha, const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); + void arma_wrapper(cblas_zgemv)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const void *alpha, const void *A, const int lda, const void *X, const int incX, const void *beta, void *Y, const int incY); + + void arma_wrapper(cblas_sgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const float alpha, const float *A, const int lda, const float *B, const int ldb, const float beta, float *C, const int ldc); + void arma_wrapper(cblas_dgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const double alpha, const double *A, const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc); + void arma_wrapper(cblas_cgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const void *alpha, const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); + void arma_wrapper(cblas_zgemm)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const void *alpha, const void *A, const int lda, const void *B, const int ldb, const void *beta, void *C, const int ldc); + + void arma_wrapper(cblas_ssyrk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const float alpha, const float *A, const int lda, const float beta, float *C, const int ldc); + void arma_wrapper(cblas_dsyrk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const double alpha, const double *A, const int lda, const double beta, double *C, const int ldc); + + void arma_wrapper(cblas_cherk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const float alpha, const void *A, const int lda, const float beta, void *C, const int ldc); + void arma_wrapper(cblas_zherk)(const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const double alpha, const void *A, const int lda, const double beta, void *C, const int ldc); } diff --git a/src/armadillo_bits/def_blas.hpp b/src/armadillo_bits/def_blas.hpp index e20b0624..e27ca6c4 100644 --- a/src/armadillo_bits/def_blas.hpp +++ b/src/armadillo_bits/def_blas.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,13 +17,23 @@ -#ifdef ARMA_USE_BLAS +#if defined(ARMA_USE_BLAS) #if defined(dgemm) || defined(DGEMM) #pragma message ("WARNING: detected possible interference with definitions of BLAS functions;") #pragma message ("WARNING: include the armadillo header before any other header as a workaround") #endif + +#if defined(ARMA_BLAS_NOEXCEPT) + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT noexcept +#else + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT +#endif + + #if !defined(ARMA_BLAS_CAPITALS) #define arma_sasum sasum @@ -87,62 +99,63 @@ extern "C" { #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) - float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx); - double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx); + float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; - float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx); - double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx); + float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; - float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy); - double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy); + float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy) ARMA_NOEXCEPT; + double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy) ARMA_NOEXCEPT; - void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy, blas_len transA_len); - void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy, blas_len transA_len); - void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy, blas_len transA_len); - void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy, blas_len transA_len); + void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy, blas_len transA_len) ARMA_NOEXCEPT; - void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len); - void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len); - void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len); - void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len); + void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC, blas_len transA_len, blas_len transB_len) ARMA_NOEXCEPT; - void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len); - void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len); + void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; - void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len); - void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len); + void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC, blas_len uplo_len, blas_len transA_len) ARMA_NOEXCEPT; #else // prototypes without hidden arguments - float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx); - double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx); + float arma_fortran(arma_sasum)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dasum)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; - float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx); - double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx); + float arma_fortran(arma_snrm2)(const blas_int* n, const float* x, const blas_int* incx) ARMA_NOEXCEPT; + double arma_fortran(arma_dnrm2)(const blas_int* n, const double* x, const blas_int* incx) ARMA_NOEXCEPT; - float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy); - double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy); + float arma_fortran(arma_sdot)(const blas_int* n, const float* x, const blas_int* incx, const float* y, const blas_int* incy) ARMA_NOEXCEPT; + double arma_fortran(arma_ddot)(const blas_int* n, const double* x, const blas_int* incx, const double* y, const blas_int* incy) ARMA_NOEXCEPT; - void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy); - void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy); - void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy); - void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy); + void arma_fortran(arma_sgemv)(const char* transA, const blas_int* m, const blas_int* n, const float* alpha, const float* A, const blas_int* ldA, const float* x, const blas_int* incx, const float* beta, float* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemv)(const char* transA, const blas_int* m, const blas_int* n, const double* alpha, const double* A, const blas_int* ldA, const double* x, const blas_int* incx, const double* beta, double* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* x, const blas_int* incx, const blas_cxf* beta, blas_cxf* y, const blas_int* incy) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemv)(const char* transA, const blas_int* m, const blas_int* n, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* x, const blas_int* incx, const blas_cxd* beta, blas_cxd* y, const blas_int* incy) ARMA_NOEXCEPT; - void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC); - void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC); - void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC); - void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC); + void arma_fortran(arma_sgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* B, const blas_int* ldB, const float* beta, float* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_dgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* B, const blas_int* ldB, const double* beta, double* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_cgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxf* alpha, const blas_cxf* A, const blas_int* ldA, const blas_cxf* B, const blas_int* ldB, const blas_cxf* beta, blas_cxf* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_zgemm)(const char* transA, const char* transB, const blas_int* m, const blas_int* n, const blas_int* k, const blas_cxd* alpha, const blas_cxd* A, const blas_int* ldA, const blas_cxd* B, const blas_int* ldB, const blas_cxd* beta, blas_cxd* C, const blas_int* ldC) ARMA_NOEXCEPT; - void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC); - void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC); + void arma_fortran(arma_ssyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const float* A, const blas_int* ldA, const float* beta, float* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyrk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const double* A, const blas_int* ldA, const double* beta, double* C, const blas_int* ldC) ARMA_NOEXCEPT; - void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC); - void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC); + void arma_fortran(arma_cherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const float* alpha, const blas_cxf* A, const blas_int* ldA, const float* beta, blas_cxf* C, const blas_int* ldC) ARMA_NOEXCEPT; + void arma_fortran(arma_zherk)(const char* uplo, const char* transA, const blas_int* n, const blas_int* k, const double* alpha, const blas_cxd* A, const blas_int* ldA, const double* beta, blas_cxd* C, const blas_int* ldC) ARMA_NOEXCEPT; #endif } +#undef ARMA_NOEXCEPT #endif diff --git a/src/armadillo_bits/def_fftw3.hpp b/src/armadillo_bits/def_fftw3.hpp new file mode 100644 index 00000000..454d7524 --- /dev/null +++ b/src/armadillo_bits/def_fftw3.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_FFTW3) + + +extern "C" + { + // function prefix for single precision: fftwf_ + // function prefix for double precision: fftw_ + + + // single precision (float) + + void_ptr fftwf_plan_dft_1d(int N, void* input, void* output, int fftw3_sign, unsigned int fftw3_flags); + + void fftwf_execute(void_ptr plan); + void fftwf_destroy_plan(void_ptr plan); + + void fftwf_cleanup(); + + + // double precision (double) + + void_ptr fftw_plan_dft_1d(int N, void* input, void* output, int fftw3_sign, unsigned int fftw3_flags); + + void fftw_execute(void_ptr plan); + void fftw_destroy_plan(void_ptr plan); + + void fftw_cleanup(); + } + + +#endif diff --git a/src/armadillo_bits/def_hdf5.hpp b/src/armadillo_bits/def_hdf5.hpp deleted file mode 100644 index 70cdffe2..00000000 --- a/src/armadillo_bits/def_hdf5.hpp +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -#if defined(ARMA_USE_HDF5) - -#if !defined(ARMA_USE_HDF5_ALT) - - // macros needed if the wrapper run-time library is not being used - - #define arma_H5Tcopy H5Tcopy - #define arma_H5Tcreate H5Tcreate - #define arma_H5Tinsert H5Tinsert - #define arma_H5Tequal H5Tequal - #define arma_H5Tclose H5Tclose - - #define arma_H5Dopen H5Dopen - #define arma_H5Dget_type H5Dget_type - #define arma_H5Dclose H5Dclose - #define arma_H5Dwrite H5Dwrite - #define arma_H5Dget_space H5Dget_space - #define arma_H5Dread H5Dread - #define arma_H5Dcreate H5Dcreate - - #define arma_H5Sget_simple_extent_ndims H5Sget_simple_extent_ndims - #define arma_H5Sget_simple_extent_dims H5Sget_simple_extent_dims - #define arma_H5Sclose H5Sclose - #define arma_H5Screate_simple H5Screate_simple - - #define arma_H5Ovisit H5Ovisit - - #define arma_H5Eset_auto H5Eset_auto - #define arma_H5Eget_auto H5Eget_auto - - #define arma_H5Fopen H5Fopen - #define arma_H5Fcreate H5Fcreate - #define arma_H5Fclose H5Fclose - #define arma_H5Fis_hdf5 H5Fis_hdf5 - - #define arma_H5Gcreate H5Gcreate - #define arma_H5Gopen H5Gopen - #define arma_H5Gclose H5Gclose - - #define arma_H5Lexists H5Lexists - #define arma_H5Ldelete H5Ldelete - - #define arma_H5T_NATIVE_UCHAR H5T_NATIVE_UCHAR - #define arma_H5T_NATIVE_CHAR H5T_NATIVE_CHAR - #define arma_H5T_NATIVE_SHORT H5T_NATIVE_SHORT - #define arma_H5T_NATIVE_USHORT H5T_NATIVE_USHORT - #define arma_H5T_NATIVE_INT H5T_NATIVE_INT - #define arma_H5T_NATIVE_UINT H5T_NATIVE_UINT - #define arma_H5T_NATIVE_LONG H5T_NATIVE_LONG - #define arma_H5T_NATIVE_ULONG H5T_NATIVE_ULONG - #define arma_H5T_NATIVE_LLONG H5T_NATIVE_LLONG - #define arma_H5T_NATIVE_ULLONG H5T_NATIVE_ULLONG - #define arma_H5T_NATIVE_FLOAT H5T_NATIVE_FLOAT - #define arma_H5T_NATIVE_DOUBLE H5T_NATIVE_DOUBLE - -#else - -// prototypes for the wrapper functions defined in the wrapper run-time library (src/wrapper.cpp) - -extern "C" - { - // Wrapper functions for H5* functions. - hid_t arma_H5Tcopy(hid_t dtype_id); - hid_t arma_H5Tcreate(H5T_class_t cl, size_t size); - herr_t arma_H5Tinsert(hid_t dtype_id, const char* name, size_t offset, hid_t field_id); - htri_t arma_H5Tequal(hid_t dtype_id1, hid_t dtype_id2); - herr_t arma_H5Tclose(hid_t dtype_id); - - hid_t arma_H5Dopen(hid_t loc_id, const char* name, hid_t dapl_id); - hid_t arma_H5Dget_type(hid_t dataset_id); - herr_t arma_H5Dclose(hid_t dataset_id); - hid_t arma_H5Dcreate(hid_t loc_id, const char* name, hid_t dtype_id, hid_t space_id, hid_t lcpl_id, hid_t dcpl_id, hid_t dapl_id); - herr_t arma_H5Dwrite(hid_t dataset_id, hid_t mem_type_id, hid_t mem_space_id, hid_t file_space_id, hid_t xfer_plist_id, const void* buf); - hid_t arma_H5Dget_space(hid_t dataset_id); - herr_t arma_H5Dread(hid_t dataset_id, hid_t mem_type_id, hid_t mem_space_id, hid_t file_space_id, hid_t xfer_plist_id, void* buf); - - int arma_H5Sget_simple_extent_ndims(hid_t space_id); - int arma_H5Sget_simple_extent_dims(hid_t space_id, hsize_t* dims, hsize_t* maxdims); - herr_t arma_H5Sclose(hid_t space_id); - hid_t arma_H5Screate_simple(int rank, const hsize_t* current_dims, const hsize_t* maximum_dims); - - herr_t arma_H5Ovisit(hid_t object_id, H5_index_t index_type, H5_iter_order_t order, H5O_iterate_t op, void* op_data); - - herr_t arma_H5Eset_auto(hid_t estack_id, H5E_auto_t func, void* client_data); - herr_t arma_H5Eget_auto(hid_t estack_id, H5E_auto_t* func, void** client_data); - - hid_t arma_H5Fopen(const char* name, unsigned flags, hid_t fapl_id); - hid_t arma_H5Fcreate(const char* name, unsigned flags, hid_t fcpl_id, hid_t fapl_id); - herr_t arma_H5Fclose(hid_t file_id); - htri_t arma_H5Fis_hdf5(const char* name); - - hid_t arma_H5Gcreate(hid_t loc_id, const char* name, hid_t lcpl_id, hid_t gcpl_id, hid_t gapl_id); - hid_t arma_H5Gopen(hid_t loc_id, const char* name, hid_t gapl_id); - herr_t arma_H5Gclose(hid_t group_id); - - htri_t arma_H5Lexists(hid_t loc_id, const char* name, hid_t lapl_id); - herr_t arma_H5Ldelete(hid_t loc_id, const char* name, hid_t lapl_id); - - // Wrapper variables that represent the hid_t values for the H5T_NATIVE_* - // types. Note that H5T_NATIVE_UCHAR itself is a macro that resolves to about - // forty other macros, and we definitely don't want to hijack those, - // so this is the best way to go about wrapping these... - extern hid_t arma_H5T_NATIVE_UCHAR; - extern hid_t arma_H5T_NATIVE_CHAR; - extern hid_t arma_H5T_NATIVE_SHORT; - extern hid_t arma_H5T_NATIVE_USHORT; - extern hid_t arma_H5T_NATIVE_INT; - extern hid_t arma_H5T_NATIVE_UINT; - extern hid_t arma_H5T_NATIVE_LONG; - extern hid_t arma_H5T_NATIVE_ULONG; - extern hid_t arma_H5T_NATIVE_LLONG; - extern hid_t arma_H5T_NATIVE_ULLONG; - extern hid_t arma_H5T_NATIVE_FLOAT; - extern hid_t arma_H5T_NATIVE_DOUBLE; - - } - - // Lastly, we have to hijack H5open() and H5check_version(), which are called - // by some expanded macros of the other H5* functions. This means we can't - // create arma_H5open(), because we can't modify those macros. Instead, we'll - // create arma::H5open() and arma::H5check_version(), and then issue a using - // directive so that arma::H5open() and arma::H5check_version() are always - // called. - // - // There is potential danger in the use of a using directive like this, but in - // this case, I can't think of a better way to solve the problem, and I doubt - // this will cause problems in any situations that aren't truly bizarre. And - // if it does cause problems, the user can #define ARMA_DONT_USE_WRAPPER or - // #undef ARMA_USE_WRAPPER in their Armadillo configuration. - herr_t H5open(); - herr_t H5check_version(unsigned majnum, unsigned minnum, unsigned relnum); - - using arma::H5open; - using arma::H5check_version; - -#endif - -#endif diff --git a/src/armadillo_bits/def_lapack.hpp b/src/armadillo_bits/def_lapack.hpp index e5f80b95..00854ab0 100644 --- a/src/armadillo_bits/def_lapack.hpp +++ b/src/armadillo_bits/def_lapack.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,13 +17,23 @@ -#ifdef ARMA_USE_LAPACK +#if defined(ARMA_USE_LAPACK) #if defined(dgetrf) || defined(DGETRF) #pragma message ("WARNING: detected possible interference with definitions of LAPACK functions;") #pragma message ("WARNING: include the armadillo header before any other header as a workaround") #endif + +#if defined(ARMA_LAPACK_NOEXCEPT) + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT noexcept +#else + #undef ARMA_NOEXCEPT + #define ARMA_NOEXCEPT +#endif + + #if !defined(ARMA_BLAS_CAPITALS) #define arma_sgetrf sgetrf #define arma_dgetrf dgetrf @@ -96,6 +108,11 @@ #define arma_cgeqrf cgeqrf #define arma_zgeqrf zgeqrf + #define arma_sgeqp3 sgeqp3 + #define arma_dgeqp3 dgeqp3 + #define arma_cgeqp3 cgeqp3 + #define arma_zgeqp3 zgeqp3 + #define arma_sorgqr sorgqr #define arma_dorgqr dorgqr @@ -242,14 +259,16 @@ #define arma_strevc strevc #define arma_dtrevc dtrevc - #define arma_slarnv slarnv - #define arma_dlarnv dlarnv - #define arma_sgehrd sgehrd #define arma_dgehrd dgehrd #define arma_cgehrd cgehrd #define arma_zgehrd zgehrd + #define arma_spstrf spstrf + #define arma_dpstrf dpstrf + #define arma_cpstrf cpstrf + #define arma_zpstrf zpstrf + #else #define arma_sgetrf SGETRF @@ -325,6 +344,11 @@ #define arma_cgeqrf CGEQRF #define arma_zgeqrf ZGEQRF + #define arma_sgeqp3 SGEQP3 + #define arma_dgeqp3 DGEQP3 + #define arma_cgeqp3 CGEQP3 + #define arma_zgeqp3 ZGEQP3 + #define arma_sorgqr SORGQR #define arma_dorgqr DORGQR @@ -471,14 +495,16 @@ #define arma_strevc STREVC #define arma_dtrevc DTREVC - #define arma_slarnv SLARNV - #define arma_dlarnv DLARNV - #define arma_sgehrd SGEHRD #define arma_dgehrd DGEHRD #define arma_cgehrd CGEHRD #define arma_zgehrd ZGEHRD + #define arma_spstrf SPSTRF + #define arma_dpstrf DPSTRF + #define arma_cpstrf CPSTRF + #define arma_zpstrf ZPSTRF + #endif @@ -504,627 +530,649 @@ extern "C" #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) // LU decomposition - void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info); + void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations using pre-computed LU decomposition - void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, const blas_len trans_len); - void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, const blas_len trans_len); - void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, const blas_len trans_len); - void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, const blas_len trans_len); + void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, const blas_len trans_len) ARMA_NOEXCEPT; // matrix inversion (using pre-computed LU decomposition) - void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // matrix inversion (triangular matrices) - void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len); - void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len); - void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len); - void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len); + void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; // eigen decomposition of general matrix (real) - void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); - void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); + void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; // eigen decomposition of general matrix (complex) - void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); - void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); + void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; // eigen decomposition of general matrix (real; advanced form) - void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len); - void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len); + void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; // eigen decomposition of general matrix (complex; advanced form) - void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len); - void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len); + void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* info, blas_len balanc_len, blas_len jobvl_len, blas_len jobvr_len, blas_len sense_len) ARMA_NOEXCEPT; // eigen decomposition of symmetric real matrices - void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); - void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); + void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; // eigen decomposition of hermitian matrices (complex) - void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); - void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); + void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; // eigen decomposition of symmetric real matrices by divide and conquer - void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); - void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); + void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; // eigen decomposition of hermitian matrices (complex) by divide and conquer - void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); - void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len); + void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len jobz_len, blas_len uplo_len) ARMA_NOEXCEPT; // eigen decomposition of general real matrix pair - void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); - void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); + void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; // eigen decomposition of general complex matrix pair - void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); - void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len); + void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobvl_len, blas_len jobvr_len) ARMA_NOEXCEPT; // Cholesky decomposition - void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // solve system of linear equations using pre-computed Cholesky decomposition - void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // Cholesky decomposition (band matrices) - void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // matrix inversion (using pre-computed Cholesky decomposition) - void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // QR decomposition - void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (real matrices) + void arma_fortran(arma_sgeqp3)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* jpvt, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqp3)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* jpvt, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (complex matrices) + void arma_fortran(arma_cgeqp3)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* jpvt, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqp3)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* jpvt, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // Q matrix calculation from QR decomposition (real matrices) - void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // Q matrix calculation from QR decomposition (complex matrices) - void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // SVD (real matrices) - void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len); - void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len); + void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; // SVD (complex matrices) - void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len); - void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len); + void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info, blas_len jobu_len, blas_len jobvt_len) ARMA_NOEXCEPT; // SVD (real matrices) by divide and conquer - void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len); - void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len); + void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; // SVD (complex matrices) by divide and conquer - void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len); - void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len); + void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info, blas_len jobz_len) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix) - void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix, advanced form, real matrices) - void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); - void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); + void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix, advanced form, complex matrices) - void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); - void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); + void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; // solve system of linear equations (symmetric positive definite matrix) - void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // solve system of linear equations (symmetric positive definite matrix, advanced form, real matrices) - void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len); - void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len); + void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; // solve system of linear equations (hermitian positive definite matrix, advanced form, complex matrices) - void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len); - void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len); + void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len uplo_len, blas_len equed_len) ARMA_NOEXCEPT; // solve over/under-determined system of linear equations - void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info, blas_len trans_len); - void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info, blas_len trans_len); - void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info, blas_len trans_len); - void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info, blas_len trans_len); + void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; // approximately solve system of linear equations using svd (real) - void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // approximately solve system of linear equations using svd (complex) - void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (triangular matrix) - void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len); - void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len); - void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len); - void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len); + void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len uplo_len, blas_len trans_len, blas_len diag_len) ARMA_NOEXCEPT; // LU factorisation (general band matrix) - void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); + void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations using pre-computed LU decomposition (general band matrix) - void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, blas_len trans_len); - void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, blas_len trans_len); - void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len trans_len); - void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len trans_len); + void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const float* ab, const blas_int* ldab, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const double* ab, const blas_int* ldab, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info, blas_len trans_len) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix) - void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix, advanced form, real matrices) - void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); - void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); + void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix, advanced form, complex matrices) - void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); - void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len); + void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len, blas_len equed_len) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix) - void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix, advanced form, real matrices) - void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len); - void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len); + void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix, advanced form, complex matrices) - void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len); - void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len); + void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info, blas_len fact_len, blas_len trans_len) ARMA_NOEXCEPT; // Schur decomposition (real matrices) - void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len); - void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len); + void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; // Schur decomposition (complex matrices) - void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len); - void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len); + void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvs_len, blas_len sort_len) ARMA_NOEXCEPT; // solve a Sylvester equation ax + xb = c, with a and b assumed to be in Schur form - void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len); - void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len); - void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len); - void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len); + void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info, blas_len transa_len, blas_len transb_len) ARMA_NOEXCEPT; // QZ decomposition (real matrices) - void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len); - void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len); + void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; // QZ decomposition (complex matrices) - void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len); - void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len); + void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info, blas_len jobvsl_len, blas_len jobvsr_len, blas_len sort_len) ARMA_NOEXCEPT; // 1-norm (general matrix) - float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len); - double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len); - float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len); - double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len); + float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len) ARMA_NOEXCEPT; // 1-norm (real symmetric matrix) - float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len); - double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len); - float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len); - double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len); + float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; // 1-norm (complex hermitian matrix) - float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len); - double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len); + float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work, blas_len norm_len, blas_len uplo_len) ARMA_NOEXCEPT; // 1-norm (band matrix) - float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work, blas_len norm_len); - double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work, blas_len norm_len); - float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work, blas_len norm_len); - double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work, blas_len norm_len); + float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work, blas_len norm_len) ARMA_NOEXCEPT; + float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work, blas_len norm_len) ARMA_NOEXCEPT; + double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work, blas_len norm_len) ARMA_NOEXCEPT; // reciprocal of condition number (real, generic matrix) - void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len); - void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len); + void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; // reciprocal of condition number (complex, generic matrix) - void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len); - void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len); + void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; // reciprocal of condition number (real, symmetric positive definite matrix) - void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // reciprocal of condition number (complex, hermitian positive definite matrix) - void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len uplo_len); - void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len uplo_len); + void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; // reciprocal of condition number (real, triangular matrix) - void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len); - void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len); + void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; // reciprocal of condition number (complex, triangular matrix) - void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len); - void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len); + void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len, blas_len uplo_len, blas_len diag_len) ARMA_NOEXCEPT; // reciprocal of condition number (real, band matrix) - void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len); - void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len); + void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; // reciprocal of condition number (complex, band matrix) - void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len); - void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len); + void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info, blas_len norm_len) ARMA_NOEXCEPT; // obtain parameters according to the local configuration of lapack - blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4, blas_len name_len, blas_len opts_len); + blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4, blas_len name_len, blas_len opts_len) ARMA_NOEXCEPT; // calculate eigenvalues of an upper Hessenberg matrix - void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info); - void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info); + void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; // calculate eigenvalues of a symmetric tridiagonal matrix - void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len); - void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len); + void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info, blas_len compz_len) ARMA_NOEXCEPT; // calculate eigenvectors of a Schur form matrix - void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info, blas_len side_len, blas_len howmny_len); - void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info, blas_len side_len, blas_len howmny_len); - - // generate a vector of random numbers - void arma_fortran(arma_slarnv)(const blas_int* idist, blas_int* iseed, const blas_int* n, float* x); - void arma_fortran(arma_dlarnv)(const blas_int* idist, blas_int* iseed, const blas_int* n, double* x); + void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info, blas_len side_len, blas_len howmny_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info, blas_len side_len, blas_len howmny_len) ARMA_NOEXCEPT; // hessenberg decomposition - void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // pivoted cholesky + void arma_fortran(arma_spstrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_dpstrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; + void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info, blas_len uplo_len) ARMA_NOEXCEPT; #else // prototypes without hidden arguments // LU decomposition - void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info); + void arma_fortran(arma_sgetrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations using pre-computed LU decomposition - void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetrs)(const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // matrix inversion (using pre-computed LU decomposition) - void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgetri)(const blas_int* n, float* a, const blas_int* lda, const blas_int* ipiv, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgetri)(const blas_int* n, double* a, const blas_int* lda, const blas_int* ipiv, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgetri)(const blas_int* n, blas_cxf* a, const blas_int* lda, const blas_int* ipiv, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgetri)(const blas_int* n, blas_cxd* a, const blas_int* lda, const blas_int* ipiv, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // matrix inversion (triangular matrices) - void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info); + void arma_fortran(arma_strtri)(const char* uplo, const char* diag, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtri)(const char* uplo, const char* diag, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtri)(const char* uplo, const char* diag, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general matrix (real) - void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgeev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general matrix (complex) - void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info); - void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info); + void arma_fortran(arma_cgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general matrix (real; advanced form) - void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, float* a, const blas_int* lda, float* wr, float* wi, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, double* a, const blas_int* lda, double* wr, double* wi, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general matrix (complex; advanced form) - void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info); - void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info); + void arma_fortran(arma_cgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* w, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, float* scale, float* abnrm, float* rconde, float* rcondv, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeevx)(const char* balanc, const char* jobvl, const char* jobvr, const char* sense, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* w, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_int* ilo, blas_int* ihi, double* scale, double* abnrm, double* rconde, double* rcondv, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of symmetric real matrices - void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_ssyev)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyev)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of hermitian matrices (complex) - void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info); - void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info); + void arma_fortran(arma_cheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zheev)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of symmetric real matrices by divide and conquer - void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info); - void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info); + void arma_fortran(arma_ssyevd)(const char* jobz, const char* uplo, const blas_int* n, float* a, const blas_int* lda, float* w, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dsyevd)(const char* jobz, const char* uplo, const blas_int* n, double* a, const blas_int* lda, double* w, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of hermitian matrices (complex) by divide and conquer - void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info); - void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info); + void arma_fortran(arma_cheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, float* w, blas_cxf* work, const blas_int* lwork, float* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zheevd)(const char* jobz, const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, double* w, blas_cxd* work, const blas_int* lwork, double* rwork, const blas_int* lrwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general real matrix pair - void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sggev)(const char* jobvl, const char* jobvr, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* alphar, float* alphai, float* beta, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dggev)(const char* jobvl, const char* jobvr, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* alphar, double* alphai, double* beta, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // eigen decomposition of general complex matrix pair - void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info); - void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info); + void arma_fortran(arma_cggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vl, const blas_int* ldvl, blas_cxf* vr, const blas_int* ldvr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zggev)(const char* jobvl, const char* jobvr, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vl, const blas_int* ldvl, blas_cxd* vr, const blas_int* ldvr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // Cholesky decomposition - void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info); + void arma_fortran(arma_spotrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations with pre-computed Cholesky decomposition - void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_spotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotrs)(const char* uplo, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // Cholesky decomposition (band matrices) - void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info); - void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info); - void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info); - void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info); + void arma_fortran(arma_spbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, float* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, double* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxf* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpbtrf)(const char* uplo, const blas_int* n, const blas_int* kd, blas_cxd* ab, const blas_int* ldab, blas_int* info) ARMA_NOEXCEPT; // matrix inversion (using pre-computed Cholesky decomposition) - void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info); - void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info); + void arma_fortran(arma_spotri)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpotri)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpotri)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpotri)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* info) ARMA_NOEXCEPT; // QR decomposition - void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgeqrf)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqrf)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgeqrf)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqrf)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (real matrices) + void arma_fortran(arma_sgeqp3)(const blas_int* m, const blas_int* n, float* a, const blas_int* lda, blas_int* jpvt, float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgeqp3)(const blas_int* m, const blas_int* n, double* a, const blas_int* lda, blas_int* jpvt, double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // QR decomposition with pivoting (complex matrices) + void arma_fortran(arma_cgeqp3)(const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* jpvt, blas_cxf* tau, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgeqp3)(const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* jpvt, blas_cxd* tau, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // Q matrix calculation from QR decomposition (real matrices) - void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, float* a, const blas_int* lda, const float* tau, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dorgqr)(const blas_int* m, const blas_int* n, const blas_int* k, double* a, const blas_int* lda, const double* tau, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // Q matrix calculation from QR decomposition (complex matrices) - void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_cungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxf* a, const blas_int* lda, const blas_cxf* tau, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zungqr)(const blas_int* m, const blas_int* n, const blas_int* k, blas_cxd* a, const blas_int* lda, const blas_cxd* tau, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // SVD (real matrices) - void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // SVD (complex matrices) - void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info); - void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info); + void arma_fortran(arma_cgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvd)(const char* jobu, const char* jobvt, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* info) ARMA_NOEXCEPT; // SVD (real matrices) by divide and conquer - void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgesdd)(const char* jobz, const blas_int* m, const blas_int* n, float* a, const blas_int* lda, float* s, float* u, const blas_int* ldu, float* vt, const blas_int* ldvt, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesdd)(const char* jobz, const blas_int* m, const blas_int* n, double* a, const blas_int* lda, double* s, double* u, const blas_int* ldu, double* vt, const blas_int* ldvt, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // SVD (complex matrices) by divide and conquer - void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_cgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxf* a, const blas_int* lda, float* s, blas_cxf* u, const blas_int* ldu, blas_cxf* vt, const blas_int* ldvt, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesdd)(const char* jobz, const blas_int* m, const blas_int* n, blas_cxd* a, const blas_int* lda, double* s, blas_cxd* u, const blas_int* ldu, blas_cxd* vt, const blas_int* ldvt, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix) - void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgesv)(const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesv)(const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgesv)(const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesv)(const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix, advanced form, real matrices) - void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general square matrix, advanced form, complex matrices) - void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgesvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (symmetric positive definite matrix) - void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zposv)(const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (symmetric positive definite matrix, advanced form, real matrices) - void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* af, const blas_int* ldaf, char* equed, float* s, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* af, const blas_int* ldaf, char* equed, double* s, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (hermitian positive definite matrix, advanced form, complex matrices) - void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* af, const blas_int* ldaf, char* equed, float* s, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zposvx)(const char* fact, const char* uplo, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* af, const blas_int* ldaf, char* equed, double* s, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // solve over/under-determined system of linear equations - void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info); + void arma_fortran(arma_sgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgels)(const char* trans, const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; // approximately solve system of linear equations using svd (real) - void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, float* a, const blas_int* lda, float* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, float* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, double* a, const blas_int* lda, double* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, double* work, const blas_int* lwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + // approximately solve system of linear equations using svd (complex) - void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info); - void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info); + void arma_fortran(arma_cgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, float* S, const float* rcond, blas_int* rank, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgelsd)(const blas_int* m, const blas_int* n, const blas_int* nrhs, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, double* S, const double* rcond, blas_int* rank, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (triangular matrix) - void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_strtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrtrs)(const char* uplo, const char* trans, const char* diag, const blas_int* n, const blas_int* nrhs, const blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // LU factorisation (general band matrix) - void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); - void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info); + void arma_fortran(arma_sgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, float* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, double* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrf)(const blas_int* m, const blas_int* n, const blas_int* kl, const blas_int* ku, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations using pre-computed LU decomposition (general band matrix) - void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const float* ab, const blas_int* ldab, const blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const double* ab, const blas_int* ldab, const blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbtrs)(const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix) - void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, blas_int* ipiv, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, blas_int* ipiv, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_int* ipiv, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsv)(const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_int* ipiv, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix, advanced form, real matrices) - void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, float* ab, const blas_int* ldab, float* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, double* ab, const blas_int* ldab, double* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (general band matrix, advanced form, complex matrices) - void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxf* ab, const blas_int* ldab, blas_cxf* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, float* r, float* c, blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_int* nrhs, blas_cxd* ab, const blas_int* ldab, blas_cxd* afb, const blas_int* ldafb, blas_int* ipiv, char* equed, double* r, double* c, blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix) - void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info); - void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info); + void arma_fortran(arma_sgtsv)(const blas_int* n, const blas_int* nrhs, float* dl, float* d, float* du, float* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsv)(const blas_int* n, const blas_int* nrhs, double* dl, double* d, double* du, double* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxf* dl, blas_cxf* d, blas_cxf* du, blas_cxf* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsv)(const blas_int* n, const blas_int* nrhs, blas_cxd* dl, blas_cxd* d, blas_cxd* du, blas_cxd* b, const blas_int* ldb, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix, advanced form, real matrices) - void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const float* dl, const float* d, const float* du, float* dlf, float* df, float* duf, float* du2, blas_int* ipiv, const float* b, const blas_int* ldb, float* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const double* dl, const double* d, const double* du, double* dlf, double* df, double* duf, double* du2, blas_int* ipiv, const double* b, const blas_int* ldb, double* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // solve system of linear equations (tridiagonal band matrix, advanced form, complex matrices) - void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxf* dl, const blas_cxf* d, const blas_cxf* du, blas_cxf* dlf, blas_cxf* df, blas_cxf* duf, blas_cxf* du2, blas_int* ipiv, const blas_cxf* b, const blas_int* ldb, blas_cxf* x, const blas_int* ldx, float* rcond, float* ferr, float* berr, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgtsvx)(const char* fact, const char* trans, const blas_int* n, const blas_int* nrhs, const blas_cxd* dl, const blas_cxd* d, const blas_cxd* du, blas_cxd* dlf, blas_cxd* df, blas_cxd* duf, blas_cxd* du2, blas_int* ipiv, const blas_cxd* b, const blas_int* ldb, blas_cxd* x, const blas_int* ldx, double* rcond, double* ferr, double* berr, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // Schur decomposition (real matrices) - void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info); - void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info); + void arma_fortran(arma_sgees)(const char* jobvs, const char* sort, fn_select_s2 select, const blas_int* n, float* a, const blas_int* lda, blas_int* sdim, float* wr, float* wi, float* vs, const blas_int* ldvs, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgees)(const char* jobvs, const char* sort, fn_select_d2 select, const blas_int* n, double* a, const blas_int* lda, blas_int* sdim, double* wr, double* wi, double* vs, const blas_int* ldvs, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; // Schur decomposition (complex matrices) - void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info); - void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info); + void arma_fortran(arma_cgees)(const char* jobvs, const char* sort, fn_select_c1 select, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* sdim, blas_cxf* w, blas_cxf* vs, const blas_int* ldvs, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgees)(const char* jobvs, const char* sort, fn_select_z1 select, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* sdim, blas_cxd* w, blas_cxd* vs, const blas_int* ldvs, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; // solve a Sylvester equation ax + xb = c, with a and b assumed to be in Schur form - void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info); - void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info); - void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info); - void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info); + void arma_fortran(arma_strsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, const float* b, const blas_int* ldb, float* c, const blas_int* ldc, float* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, const double* b, const blas_int* ldb, double* c, const blas_int* ldc, double* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ctrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, const blas_cxf* b, const blas_int* ldb, blas_cxf* c, const blas_int* ldc, float* scale, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrsyl)(const char* transa, const char* transb, const blas_int* isgn, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, const blas_cxd* b, const blas_int* ldb, blas_cxd* c, const blas_int* ldc, double* scale, blas_int* info) ARMA_NOEXCEPT; // QZ decomposition (real matrices) - void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info); - void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info); + void arma_fortran(arma_sgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_s3 selctg, const blas_int* n, float* a, const blas_int* lda, float* b, const blas_int* ldb, blas_int* sdim, float* alphar, float* alphai, float* beta, float* vsl, const blas_int* ldvsl, float* vsr, const blas_int* ldvsr, float* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_d3 selctg, const blas_int* n, double* a, const blas_int* lda, double* b, const blas_int* ldb, blas_int* sdim, double* alphar, double* alphai, double* beta, double* vsl, const blas_int* ldvsl, double* vsr, const blas_int* ldvsr, double* work, const blas_int* lwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; // QZ decomposition (complex matrices) - void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info); - void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info); + void arma_fortran(arma_cgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_c2 selctg, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_cxf* b, const blas_int* ldb, blas_int* sdim, blas_cxf* alpha, blas_cxf* beta, blas_cxf* vsl, const blas_int* ldvsl, blas_cxf* vsr, const blas_int* ldvsr, blas_cxf* work, const blas_int* lwork, float* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgges)(const char* jobvsl, const char* jobvsr, const char* sort, fn_select_z2 selctg, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_cxd* b, const blas_int* ldb, blas_int* sdim, blas_cxd* alpha, blas_cxd* beta, blas_cxd* vsl, const blas_int* ldvsl, blas_cxd* vsr, const blas_int* ldvsr, blas_cxd* work, const blas_int* lwork, double* rwork, blas_int* bwork, blas_int* info) ARMA_NOEXCEPT; // 1-norm (general matrix) - float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work); - double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work); - float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work); - double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work); + float arma_fortran(arma_slange)(const char* norm, const blas_int* m, const blas_int* n, const float* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlange)(const char* norm, const blas_int* m, const blas_int* n, const double* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlange)(const char* norm, const blas_int* m, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; // 1-norm (real symmetric matrix) - float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work); - double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work); - float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work); - double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work); + float arma_fortran(arma_slansy)(const char* norm, const char* uplo, const blas_int* n, const float* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlansy)(const char* norm, const char* uplo, const blas_int* n, const double* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlansy)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; // 1-norm (complex hermitian matrix) - float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work); - double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work); + float arma_fortran(arma_clanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlanhe)(const char* norm, const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* work) ARMA_NOEXCEPT; // 1-norm (band matrix) - float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work); - double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work); - float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work); - double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work); + float arma_fortran(arma_slangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_dlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, double* work) ARMA_NOEXCEPT; + float arma_fortran(arma_clangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, float* work) ARMA_NOEXCEPT; + double arma_fortran(arma_zlangb)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, double* work) ARMA_NOEXCEPT; // reciprocal of condition number (real, generic matrix) - void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgecon)(const char* norm, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgecon)(const char* norm, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (complex, generic matrix) - void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cgecon)(const char* norm, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgecon)(const char* norm, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (real, symmetric positive definite matrix) - void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_spocon)(const char* uplo, const blas_int* n, const float* a, const blas_int* lda, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpocon)(const char* uplo, const blas_int* n, const double* a, const blas_int* lda, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (complex, hermitian positive definite matrix) - void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cpocon)(const char* uplo, const blas_int* n, const blas_cxf* a, const blas_int* lda, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpocon)(const char* uplo, const blas_int* n, const blas_cxd* a, const blas_int* lda, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (real, triangular matrix) - void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_strcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const float* a, const blas_int* lda, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const double* a, const blas_int* lda, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (complex, triangular matrix) - void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_ctrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxf* a, const blas_int* lda, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_ztrcon)(const char* norm, const char* uplo, const char* diag, const blas_int* n, const blas_cxd* a, const blas_int* lda, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (real, band matrix) - void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info); - void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info); + void arma_fortran(arma_sgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const float* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, float* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const double* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, double* work, blas_int* iwork, blas_int* info) ARMA_NOEXCEPT; // reciprocal of condition number (complex, band matrix) - void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info); - void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info); + void arma_fortran(arma_cgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxf* ab, const blas_int* ldab, const blas_int* ipiv, const float* anorm, float* rcond, blas_cxf* work, float* rwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgbcon)(const char* norm, const blas_int* n, const blas_int* kl, const blas_int* ku, const blas_cxd* ab, const blas_int* ldab, const blas_int* ipiv, const double* anorm, double* rcond, blas_cxd* work, double* rwork, blas_int* info) ARMA_NOEXCEPT; // obtain parameters according to the local configuration of lapack // NOTE: DO NOT USE THIS FORM; kept only for compatibility // NOTE: this function takes 'name' and 'opts' argumments, which are strings with length != 1; their length needs to be given via "hidden" parameters, which this form lacks - blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4); + blas_int arma_fortran(arma_ilaenv)(const blas_int* ispec, const char* name, const char* opts, const blas_int* n1, const blas_int* n2, const blas_int* n3, const blas_int* n4) ARMA_NOEXCEPT; // calculate eigenvalues of an upper Hessenberg matrix - void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info); - void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info); + void arma_fortran(arma_slahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* h, const blas_int* ldh, float* wr, float* wi, const blas_int* iloz, const blas_int* ihiz, float* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dlahqr)(const blas_int* wantt, const blas_int* wantz, const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* h, const blas_int* ldh, double* wr, double* wi, const blas_int* iloz, const blas_int* ihiz, double* z, const blas_int* ldz, blas_int* info) ARMA_NOEXCEPT; // calculate eigenvalues of a symmetric tridiagonal matrix - void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info); - void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info); + void arma_fortran(arma_sstedc)(const char* compz, const blas_int* n, float* d, float* e, float* z, const blas_int* ldz, float* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dstedc)(const char* compz, const blas_int* n, double* d, double* e, double* z, const blas_int* ldz, double* work, const blas_int* lwork, blas_int* iwork, const blas_int* liwork, blas_int* info) ARMA_NOEXCEPT; // calculate eigenvectors of a Schur form matrix - void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info); - void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info); - - // generate a vector of random numbers - void arma_fortran(arma_slarnv)(const blas_int* idist, blas_int* iseed, const blas_int* n, float* x); - void arma_fortran(arma_dlarnv)(const blas_int* idist, blas_int* iseed, const blas_int* n, double* x); + void arma_fortran(arma_strevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const float* t, const blas_int* ldt, float* vl, const blas_int* ldvl, float* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dtrevc)(const char* side, const char* howmny, blas_int* select, const blas_int* n, const double* t, const blas_int* ldt, double* vl, const blas_int* ldvl, double* vr, const blas_int* ldvr, const blas_int* mm, blas_int* m, double* work, blas_int* info) ARMA_NOEXCEPT; // hessenberg decomposition - void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info); - void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info); - + void arma_fortran(arma_sgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, float* a, const blas_int* lda, float* tao, float* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, double* a, const blas_int* lda, double* tao, double* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxf* a, const blas_int* lda, blas_cxf* tao, blas_cxf* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zgehrd)(const blas_int* n, const blas_int* ilo, const blas_int* ihi, blas_cxd* a, const blas_int* lda, blas_cxd* tao, blas_cxd* work, const blas_int* lwork, blas_int* info) ARMA_NOEXCEPT; + + // pivoted cholesky + void arma_fortran(arma_spstrf)(const char* uplo, const blas_int* n, float* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_dpstrf)(const char* uplo, const blas_int* n, double* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_cpstrf)(const char* uplo, const blas_int* n, blas_cxf* a, const blas_int* lda, blas_int* piv, blas_int* rank, const float* tol, float* work, blas_int* info) ARMA_NOEXCEPT; + void arma_fortran(arma_zpstrf)(const char* uplo, const blas_int* n, blas_cxd* a, const blas_int* lda, blas_int* piv, blas_int* rank, const double* tol, double* work, blas_int* info) ARMA_NOEXCEPT; + #endif } +#undef ARMA_NOEXCEPT #endif diff --git a/src/armadillo_bits/def_superlu.hpp b/src/armadillo_bits/def_superlu.hpp index 06009075..81f6ac39 100644 --- a/src/armadillo_bits/def_superlu.hpp +++ b/src/armadillo_bits/def_superlu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,12 +29,37 @@ extern "C" extern void arma_wrapper(cgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, float*, float*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, float*, float*, float*, float*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); extern void arma_wrapper(zgssvx)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, int*, char*, double*, double*, superlu::SuperMatrix*, superlu::SuperMatrix*, void*, int, superlu::SuperMatrix*, superlu::SuperMatrix*, double*, double*, double*, double*, superlu::GlobalLU_t*, superlu::mem_usage_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(sgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgstrf)(superlu::superlu_options_t*, superlu::SuperMatrix*, int, int, int*, void*, int, int*, int*, superlu::SuperMatrix*, superlu::SuperMatrix*, superlu::GlobalLU_t*, superlu::SuperLUStat_t*, int*); + + extern void arma_wrapper(sgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(dgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(cgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + extern void arma_wrapper(zgstrs)(superlu::trans_t, superlu::SuperMatrix*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*, superlu::SuperLUStat_t*, int*); + + extern float arma_wrapper(slangs)(char* norm, superlu::SuperMatrix* A); + extern double arma_wrapper(dlangs)(char* norm, superlu::SuperMatrix* A); + extern float arma_wrapper(clangs)(char* norm, superlu::SuperMatrix* A); + extern double arma_wrapper(zlangs)(char* norm, superlu::SuperMatrix* A); + + extern void arma_wrapper(sgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, float anorm, float* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(dgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, double anorm, double* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(cgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, float anorm, float* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(zgscon)(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, double anorm, double* rcond, superlu::SuperLUStat_t* stat, int* info); + extern void arma_wrapper(StatInit)(superlu::SuperLUStat_t*); extern void arma_wrapper(StatFree)(superlu::SuperLUStat_t*); extern void arma_wrapper(set_default_options)(superlu::superlu_options_t*); + extern void arma_wrapper(get_perm_c)(int, superlu::SuperMatrix*, int*); + extern int arma_wrapper(sp_ienv)(int); + extern void arma_wrapper(sp_preorder)(superlu::superlu_options_t*, superlu::SuperMatrix*, int*, int*, superlu::SuperMatrix*); + extern void arma_wrapper(Destroy_SuperNode_Matrix)(superlu::SuperMatrix*); extern void arma_wrapper(Destroy_CompCol_Matrix)(superlu::SuperMatrix*); + extern void arma_wrapper(Destroy_CompCol_Permuted)(superlu::SuperMatrix*); extern void arma_wrapper(Destroy_SuperMatrix_Store)(superlu::SuperMatrix*); // We also need superlu_malloc() and superlu_free(). diff --git a/src/armadillo_bits/diagmat_proxy.hpp b/src/armadillo_bits/diagmat_proxy.hpp index 24202c15..262dd958 100644 --- a/src/armadillo_bits/diagmat_proxy.hpp +++ b/src/armadillo_bits/diagmat_proxy.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -90,7 +92,7 @@ class diagmat_proxy_default } - arma_inline bool is_alias(const Mat&) const { return false; } + inline bool is_alias(const Mat& X) const { return P.is_alias(X); } const Proxy P; const bool P_is_vec; @@ -143,6 +145,12 @@ class diagmat_proxy_fixed const T1& P; + //// this may require T1::n_elem etc to be declared as static constexpr inline variables (C++17) + //// see also the notes in Mat::fixed + // static constexpr bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); + // static constexpr uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows; + // static constexpr uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols; + static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); static const uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows; static const uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols; @@ -161,11 +169,11 @@ struct diagmat_proxy_redirect { typedef diagmat_proxy_fixed res template -class diagmat_proxy : public diagmat_proxy_redirect::value >::result +class diagmat_proxy : public diagmat_proxy_redirect::value>::result { public: inline diagmat_proxy(const T1& X) - : diagmat_proxy_redirect< T1, is_Mat_fixed::value >::result(X) + : diagmat_proxy_redirect::value>::result(X) { } }; @@ -226,7 +234,7 @@ class diagmat_proxy< Row > arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } - static const bool P_is_vec = true; + static constexpr bool P_is_vec = true; const Row& P; const uword n_rows; @@ -258,7 +266,7 @@ class diagmat_proxy< Col > arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&P)); } - static const bool P_is_vec = true; + static constexpr bool P_is_vec = true; const Col& P; const uword n_rows; @@ -290,7 +298,7 @@ class diagmat_proxy< subview_row > arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); } - static const bool P_is_vec = true; + static constexpr bool P_is_vec = true; const subview_row& P; const uword n_rows; @@ -322,7 +330,7 @@ class diagmat_proxy< subview_col > arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&(P.m))); } - static const bool P_is_vec = true; + static constexpr bool P_is_vec = true; const subview_col& P; const uword n_rows; @@ -331,255 +339,35 @@ class diagmat_proxy< subview_col > -// -// -// - - - -template -class diagmat_proxy_check_default - { - public: - - typedef typename T1::elem_type elem_type; - typedef typename get_pod_type::result pod_type; - - inline - diagmat_proxy_check_default(const T1& X, const Mat&) - : P(X) - , P_is_vec( (resolves_to_vector::yes) || (P.n_rows == 1) || (P.n_cols == 1) ) - , n_rows( P_is_vec ? P.n_elem : P.n_rows ) - , n_cols( P_is_vec ? P.n_elem : P.n_cols ) - { - arma_extra_debug_sigprint(); - } - - arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } - - const Mat P; - const bool P_is_vec; - const uword n_rows; - const uword n_cols; - }; - - - -template -class diagmat_proxy_check_fixed +template +class diagmat_proxy< Glue > { public: - typedef typename T1::elem_type eT; typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; inline - diagmat_proxy_check_fixed(const T1& X, const Mat& out) - : P( const_cast(X.memptr()), T1::n_rows, T1::n_cols, (&X == &out), false ) + diagmat_proxy(const Glue& X) { + op_diagmat::apply_times(P, X.A, X.B); + + n_rows = P.n_rows; + n_cols = P.n_cols; + arma_extra_debug_sigprint(); } + arma_inline elem_type operator[] (const uword i) const { return P.at(i,i); } + arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P.at(row,row) : elem_type(0); } - arma_inline eT operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } - arma_inline eT at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } - - const Mat P; // TODO: why not just store X directly as T1& ? test with fixed size vectors and matrices - - static const bool P_is_vec = (T1::n_rows == 1) || (T1::n_cols == 1); - static const uword n_rows = P_is_vec ? T1::n_elem : T1::n_rows; - static const uword n_cols = P_is_vec ? T1::n_elem : T1::n_cols; - }; - - - -template -struct diagmat_proxy_check_redirect {}; - -template -struct diagmat_proxy_check_redirect { typedef diagmat_proxy_check_default result; }; - -template -struct diagmat_proxy_check_redirect { typedef diagmat_proxy_check_fixed result; }; - - -template -class diagmat_proxy_check : public diagmat_proxy_check_redirect::value >::result - { - public: - inline diagmat_proxy_check(const T1& X, const Mat& out) - : diagmat_proxy_check_redirect< T1, is_Mat_fixed::value >::result(X, out) - { - } - }; - - - -template -class diagmat_proxy_check< Mat > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - - - inline - diagmat_proxy_check(const Mat& X, const Mat& out) - : P_local ( (&X == &out) ? new Mat(X) : 0 ) - , P ( (&X == &out) ? (*P_local) : X ) - , P_is_vec( (P.n_rows == 1) || (P.n_cols == 1) ) - , n_rows ( P_is_vec ? P.n_elem : P.n_rows ) - , n_cols ( P_is_vec ? P.n_elem : P.n_cols ) - { - arma_extra_debug_sigprint(); - } - - inline ~diagmat_proxy_check() - { - if(P_local) { delete P_local; } - } - - arma_inline elem_type operator[] (const uword i) const { return P_is_vec ? P[i] : P.at(i,i); } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? ( P_is_vec ? P[row] : P.at(row,row) ) : elem_type(0); } - - const Mat* P_local; - const Mat& P; - const bool P_is_vec; - const uword n_rows; - const uword n_cols; - }; - - - -template -class diagmat_proxy_check< Row > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - - inline - diagmat_proxy_check(const Row& X, const Mat& out) - : P_local ( (&X == reinterpret_cast*>(&out)) ? new Row(X) : 0 ) - , P ( (&X == reinterpret_cast*>(&out)) ? (*P_local) : X ) - , n_rows (X.n_elem) - , n_cols (X.n_elem) - { - arma_extra_debug_sigprint(); - } - - inline ~diagmat_proxy_check() - { - if(P_local) { delete P_local; } - } - - arma_inline elem_type operator[] (const uword i) const { return P[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } - - static const bool P_is_vec = true; - - const Row* P_local; - const Row& P; - const uword n_rows; - const uword n_cols; - }; - - - -template -class diagmat_proxy_check< Col > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - - inline - diagmat_proxy_check(const Col& X, const Mat& out) - : P_local ( (&X == reinterpret_cast*>(&out)) ? new Col(X) : 0 ) - , P ( (&X == reinterpret_cast*>(&out)) ? (*P_local) : X ) - , n_rows (X.n_elem) - , n_cols (X.n_elem) - { - arma_extra_debug_sigprint(); - } - - inline ~diagmat_proxy_check() - { - if(P_local) { delete P_local; } - } - - arma_inline elem_type operator[] (const uword i) const { return P[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } - - static const bool P_is_vec = true; - - const Col* P_local; - const Col& P; - const uword n_rows; - const uword n_cols; - }; - - - -template -class diagmat_proxy_check< subview_row > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - - inline - diagmat_proxy_check(const subview_row& X, const Mat&) - : P ( X ) - , n_rows ( X.n_elem ) - , n_cols ( X.n_elem ) - { - arma_extra_debug_sigprint(); - } - - arma_inline elem_type operator[] (const uword i) const { return P[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } - - static const bool P_is_vec = true; - - const Row P; - const uword n_rows; - const uword n_cols; - }; - - - -template -class diagmat_proxy_check< subview_col > - { - public: - - typedef eT elem_type; - typedef typename get_pod_type::result pod_type; - - inline - diagmat_proxy_check(const subview_col& X, const Mat& out) - : P ( const_cast(X.colptr(0)), X.n_rows, (&(X.m) == &out), false ) - , n_rows( X.n_elem ) - , n_cols( X.n_elem ) - { - arma_extra_debug_sigprint(); - } - - arma_inline elem_type operator[] (const uword i) const { return P[i]; } - arma_inline elem_type at (const uword row, const uword col) const { return (row == col) ? P[row] : elem_type(0); } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool P_is_vec = true; + static constexpr bool P_is_vec = false; - const Col P; - const uword n_rows; - const uword n_cols; + Mat P; + uword n_rows; + uword n_cols; }; diff --git a/src/armadillo_bits/diagview_bones.hpp b/src/armadillo_bits/diagview_bones.hpp index 022eb568..5aa4bcee 100644 --- a/src/armadillo_bits/diagview_bones.hpp +++ b/src/armadillo_bits/diagview_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,7 @@ //! Class for storing data required to extract and set the diagonals of a matrix template -class diagview : public Base > +class diagview : public Base< eT, diagview > { public: @@ -29,9 +31,9 @@ class diagview : public Base > arma_aligned const Mat& m; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; const uword row_offset; const uword col_offset; @@ -39,17 +41,21 @@ class diagview : public Base > const uword n_rows; // equal to n_elem const uword n_elem; - static const uword n_cols = 1; + static constexpr uword n_cols = 1; protected: arma_inline diagview(const Mat& in_m, const uword in_row_offset, const uword in_col_offset, const uword len); - + public: inline ~diagview(); + inline diagview() = delete; + + inline diagview(const diagview& in); + inline diagview( diagview&& in); inline void operator=(const diagview& x); @@ -83,12 +89,12 @@ class diagview : public Base > arma_inline eT operator()(const uword in_n_row, const uword in_n_col) const; - arma_inline const Op,op_htrans> t() const; - arma_inline const Op,op_htrans> ht() const; - arma_inline const Op,op_strans> st() const; - inline void replace(const eT old_val, const eT new_val); + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); @@ -103,13 +109,8 @@ class diagview : public Base > inline static void div_inplace(Mat& out, const diagview& in); - private: - friend class Mat; friend class subview; - - diagview(); - //diagview(const diagview&); // making this private causes an error under gcc 4.1/4.2, but not 4.3 }; diff --git a/src/armadillo_bits/diagview_meat.hpp b/src/armadillo_bits/diagview_meat.hpp index ea8d78f4..e35f8579 100644 --- a/src/armadillo_bits/diagview_meat.hpp +++ b/src/armadillo_bits/diagview_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,20 +24,56 @@ template inline diagview::~diagview() { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); } + template arma_inline diagview::diagview(const Mat& in_m, const uword in_row_offset, const uword in_col_offset, const uword in_len) - : m(in_m) + : m (in_m ) , row_offset(in_row_offset) , col_offset(in_col_offset) - , n_rows(in_len) - , n_elem(in_len) + , n_rows (in_len ) + , n_elem (in_len ) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); + } + + + +template +inline +diagview::diagview(const diagview& in) + : m (in.m ) + , row_offset(in.row_offset) + , col_offset(in.col_offset) + , n_rows (in.n_rows ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +diagview::diagview(diagview&& in) + : m (in.m ) + , row_offset(in.row_offset) + , col_offset(in.col_offset) + , n_rows (in.n_rows ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + + // for paranoia + + access::rw(in.row_offset) = 0; + access::rw(in.col_offset) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_elem ) = 0; } @@ -50,7 +88,7 @@ diagview::operator= (const diagview& x) diagview& d = *this; - arma_debug_check( (d.n_elem != x.n_elem), "diagview: diagonals have incompatible lengths"); + arma_debug_check( (d.n_elem != x.n_elem), "diagview: diagonals have incompatible lengths" ); Mat& d_m = const_cast< Mat& >(d.m); const Mat& x_m = x.m; @@ -775,7 +813,7 @@ arma_inline eT& diagview::operator()(const uword ii) { - arma_debug_check( (ii >= n_elem), "diagview::operator(): out of bounds" ); + arma_debug_check_bounds( (ii >= n_elem), "diagview::operator(): out of bounds" ); return (const_cast< Mat& >(m)).at(ii+row_offset, ii+col_offset); } @@ -787,7 +825,7 @@ arma_inline eT diagview::operator()(const uword ii) const { - arma_debug_check( (ii >= n_elem), "diagview::operator(): out of bounds" ); + arma_debug_check_bounds( (ii >= n_elem), "diagview::operator(): out of bounds" ); return m.at(ii+row_offset, ii+col_offset); } @@ -819,7 +857,7 @@ arma_inline eT& diagview::operator()(const uword row, const uword col) { - arma_debug_check( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); return (const_cast< Mat& >(m)).at(row+row_offset, row+col_offset); } @@ -831,43 +869,13 @@ arma_inline eT diagview::operator()(const uword row, const uword col) const { - arma_debug_check( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "diagview::operator(): out of bounds" ); return m.at(row+row_offset, row+col_offset); } -template -arma_inline -const Op,op_htrans> -diagview::t() const - { - return Op,op_htrans>(*this); - } - - - -template -arma_inline -const Op,op_htrans> -diagview::ht() const - { - return Op,op_htrans>(*this); - } - - - -template -arma_inline -const Op,op_strans> -diagview::st() const - { - return Op,op_strans>(*this); - } - - - template inline void @@ -901,6 +909,38 @@ diagview::replace(const eT old_val, const eT new_val) +template +inline +void +diagview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +diagview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + template inline void diff --git a/src/armadillo_bits/diskio_bones.hpp b/src/armadillo_bits/diskio_bones.hpp index 505dca08..03e1ac5f 100644 --- a/src/armadillo_bits/diskio_bones.hpp +++ b/src/armadillo_bits/diskio_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,24 +25,45 @@ class diskio { public: - template inline arma_cold static std::string gen_txt_header(const Mat&); - template inline arma_cold static std::string gen_bin_header(const Mat&); + arma_deprecated inline static file_type guess_file_type(std::istream& f); + + + private: + + template friend class Mat; + template friend class Cube; + template friend class SpMat; + template friend class field; + + friend class Mat_aux; + friend class Cube_aux; + friend class SpMat_aux; + friend class field_aux; - template inline arma_cold static std::string gen_bin_header(const SpMat&); + template arma_cold inline static std::string gen_txt_header(const Mat&); + template arma_cold inline static std::string gen_bin_header(const Mat&); + + template arma_cold inline static std::string gen_bin_header(const SpMat&); - template inline arma_cold static std::string gen_txt_header(const Cube&); - template inline arma_cold static std::string gen_bin_header(const Cube&); + template arma_cold inline static std::string gen_txt_header(const Cube&); + template arma_cold inline static std::string gen_bin_header(const Cube&); + + arma_cold inline static file_type guess_file_type_internal(std::istream& f); - inline arma_cold static file_type guess_file_type(std::istream& f); + arma_cold inline static std::string gen_tmp_name(const std::string& x); - inline arma_cold static std::string gen_tmp_name(const std::string& x); + arma_cold inline static bool safe_rename(const std::string& old_name, const std::string& new_name); - inline arma_cold static bool safe_rename(const std::string& old_name, const std::string& new_name); + arma_cold inline static bool is_readable(const std::string& name); + + arma_cold inline static void sanitise_token(std::string& token); template inline static bool convert_token(eT& val, const std::string& token); template inline static bool convert_token(std::complex& val, const std::string& token); - template arma_deprecated inline static bool convert_naninf(eT& val, const std::string& token); + template inline static bool convert_token_strict(eT& val, const std::string& token); + + template inline static std::streamsize prepare_stream(std::ostream& f); // @@ -49,7 +72,8 @@ class diskio template inline static bool save_raw_ascii (const Mat& x, const std::string& final_name); template inline static bool save_raw_binary (const Mat& x, const std::string& final_name); template inline static bool save_arma_ascii (const Mat& x, const std::string& final_name); - template inline static bool save_csv_ascii (const Mat& x, const std::string& final_name); + template inline static bool save_csv_ascii (const Mat& x, const std::string& final_name, const field& header, const bool with_header, const char separator); + template inline static bool save_coord_ascii(const Mat& x, const std::string& final_name); template inline static bool save_arma_binary(const Mat& x, const std::string& final_name); template inline static bool save_pgm_binary (const Mat& x, const std::string& final_name); template inline static bool save_pgm_binary (const Mat< std::complex >& x, const std::string& final_name); @@ -58,8 +82,10 @@ class diskio template inline static bool save_raw_ascii (const Mat& x, std::ostream& f); template inline static bool save_raw_binary (const Mat& x, std::ostream& f); template inline static bool save_arma_ascii (const Mat& x, std::ostream& f); - template inline static bool save_csv_ascii (const Mat& x, std::ostream& f); - template inline static bool save_csv_ascii (const Mat< std::complex >& x, std::ostream& f); + template inline static bool save_csv_ascii (const Mat& x, std::ostream& f, const char separator); + template inline static bool save_csv_ascii (const Mat< std::complex >& x, std::ostream& f, const char separator); + template inline static bool save_coord_ascii(const Mat& x, std::ostream& f); + template inline static bool save_coord_ascii(const Mat< std::complex >& x, std::ostream& f); template inline static bool save_arma_binary(const Mat& x, std::ostream& f); template inline static bool save_pgm_binary (const Mat& x, std::ostream& f); template inline static bool save_pgm_binary (const Mat< std::complex >& x, std::ostream& f); @@ -71,7 +97,8 @@ class diskio template inline static bool load_raw_ascii (Mat& x, const std::string& name, std::string& err_msg); template inline static bool load_raw_binary (Mat& x, const std::string& name, std::string& err_msg); template inline static bool load_arma_ascii (Mat& x, const std::string& name, std::string& err_msg); - template inline static bool load_csv_ascii (Mat& x, const std::string& name, std::string& err_msg); + template inline static bool load_csv_ascii (Mat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator, const bool strict); + template inline static bool load_coord_ascii(Mat& x, const std::string& name, std::string& err_msg); template inline static bool load_arma_binary(Mat& x, const std::string& name, std::string& err_msg); template inline static bool load_pgm_binary (Mat& x, const std::string& name, std::string& err_msg); template inline static bool load_pgm_binary (Mat< std::complex >& x, const std::string& name, std::string& err_msg); @@ -81,8 +108,10 @@ class diskio template inline static bool load_raw_ascii (Mat& x, std::istream& f, std::string& err_msg); template inline static bool load_raw_binary (Mat& x, std::istream& f, std::string& err_msg); template inline static bool load_arma_ascii (Mat& x, std::istream& f, std::string& err_msg); - template inline static bool load_csv_ascii (Mat& x, std::istream& f, std::string& err_msg); - template inline static bool load_csv_ascii (Mat< std::complex >& x, std::istream& f, std::string& err_msg); + template inline static bool load_csv_ascii (Mat& x, std::istream& f, std::string& err_msg, const char separator, const bool strict); + template inline static bool load_csv_ascii (Mat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator, const bool strict); + template inline static bool load_coord_ascii(Mat& x, std::istream& f, std::string& err_msg); + template inline static bool load_coord_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg); template inline static bool load_arma_binary(Mat& x, std::istream& f, std::string& err_msg); template inline static bool load_pgm_binary (Mat& x, std::istream& is, std::string& err_msg); template inline static bool load_pgm_binary (Mat< std::complex >& x, std::istream& is, std::string& err_msg); @@ -94,12 +123,12 @@ class diskio // // sparse matrix saving - template inline static bool save_csv_ascii (const SpMat& x, const std::string& final_name); + template inline static bool save_csv_ascii (const SpMat& x, const std::string& final_name, const field& header, const bool with_header, const char separator); template inline static bool save_coord_ascii(const SpMat& x, const std::string& final_name); template inline static bool save_arma_binary(const SpMat& x, const std::string& final_name); - template inline static bool save_csv_ascii (const SpMat& x, std::ostream& f); - template inline static bool save_csv_ascii (const SpMat< std::complex >& x, std::ostream& f); + template inline static bool save_csv_ascii (const SpMat& x, std::ostream& f, const char separator); + template inline static bool save_csv_ascii (const SpMat< std::complex >& x, std::ostream& f, const char separator); template inline static bool save_coord_ascii(const SpMat& x, std::ostream& f); template inline static bool save_coord_ascii(const SpMat< std::complex >& x, std::ostream& f); template inline static bool save_arma_binary(const SpMat& x, std::ostream& f); @@ -108,12 +137,12 @@ class diskio // // sparse matrix loading - template inline static bool load_csv_ascii (SpMat& x, const std::string& name, std::string& err_msg); + template inline static bool load_csv_ascii (SpMat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator); template inline static bool load_coord_ascii(SpMat& x, const std::string& name, std::string& err_msg); template inline static bool load_arma_binary(SpMat& x, const std::string& name, std::string& err_msg); - template inline static bool load_csv_ascii (SpMat& x, std::istream& f, std::string& err_msg); - template inline static bool load_csv_ascii (SpMat< std::complex >& x, std::istream& f, std::string& err_msg); + template inline static bool load_csv_ascii (SpMat& x, std::istream& f, std::string& err_msg, const char separator); + template inline static bool load_csv_ascii (SpMat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator); template inline static bool load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg); template inline static bool load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg); template inline static bool load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg); diff --git a/src/armadillo_bits/diskio_meat.hpp b/src/armadillo_bits/diskio_meat.hpp index c826c776..4f716cc0 100644 --- a/src/armadillo_bits/diskio_meat.hpp +++ b/src/armadillo_bits/diskio_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,37 +24,47 @@ //! Format: "ARMA_MAT_TXT_ABXYZ". //! A is one of: I (for integral types) or F (for floating point types). //! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). -//! XYZ specifies the width of each element in terms of bytes, e.g. "008" indicates eight bytes. +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. template inline -arma_cold std::string diskio::gen_txt_header(const Mat&) { arma_type_check(( is_supported_elem_type::value == false )); - if( is_u8::value) { return std::string("ARMA_MAT_TXT_IU001"); } - else if( is_s8::value) { return std::string("ARMA_MAT_TXT_IS001"); } - else if(is_u16::value) { return std::string("ARMA_MAT_TXT_IU002"); } - else if(is_s16::value) { return std::string("ARMA_MAT_TXT_IS002"); } - else if(is_u32::value) { return std::string("ARMA_MAT_TXT_IU004"); } - else if(is_s32::value) { return std::string("ARMA_MAT_TXT_IS004"); } -#if defined(ARMA_USE_U64S64) - else if(is_u64::value) { return std::string("ARMA_MAT_TXT_IU008"); } - else if(is_s64::value) { return std::string("ARMA_MAT_TXT_IS008"); } -#endif -#if defined(ARMA_ALLOW_LONG) - else if(is_ulng_t_32::value) { return std::string("ARMA_MAT_TXT_IU004"); } - else if(is_slng_t_32::value) { return std::string("ARMA_MAT_TXT_IS004"); } - else if(is_ulng_t_64::value) { return std::string("ARMA_MAT_TXT_IU008"); } - else if(is_slng_t_64::value) { return std::string("ARMA_MAT_TXT_IS008"); } -#endif - else if( is_float::value) { return std::string("ARMA_MAT_TXT_FN004"); } - else if( is_double::value) { return std::string("ARMA_MAT_TXT_FN008"); } - else if( is_cx_float::value) { return std::string("ARMA_MAT_TXT_FC008"); } - else if(is_cx_double::value) { return std::string("ARMA_MAT_TXT_FC016"); } - - return std::string(); + const char* ARMA_MAT_TXT_IU001 = "ARMA_MAT_TXT_IU001"; + const char* ARMA_MAT_TXT_IS001 = "ARMA_MAT_TXT_IS001"; + const char* ARMA_MAT_TXT_IU002 = "ARMA_MAT_TXT_IU002"; + const char* ARMA_MAT_TXT_IS002 = "ARMA_MAT_TXT_IS002"; + const char* ARMA_MAT_TXT_IU004 = "ARMA_MAT_TXT_IU004"; + const char* ARMA_MAT_TXT_IS004 = "ARMA_MAT_TXT_IS004"; + const char* ARMA_MAT_TXT_IU008 = "ARMA_MAT_TXT_IU008"; + const char* ARMA_MAT_TXT_IS008 = "ARMA_MAT_TXT_IS008"; + const char* ARMA_MAT_TXT_FN004 = "ARMA_MAT_TXT_FN004"; + const char* ARMA_MAT_TXT_FN008 = "ARMA_MAT_TXT_FN008"; + const char* ARMA_MAT_TXT_FC008 = "ARMA_MAT_TXT_FC008"; + const char* ARMA_MAT_TXT_FC016 = "ARMA_MAT_TXT_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_MAT_TXT_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_MAT_TXT_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_MAT_TXT_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_MAT_TXT_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_MAT_TXT_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_MAT_TXT_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_MAT_TXT_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_MAT_TXT_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_MAT_TXT_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_MAT_TXT_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_MAT_TXT_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_MAT_TXT_IS008); } + else if( is_float::value) { header = const_cast(ARMA_MAT_TXT_FN004); } + else if( is_double::value) { header = const_cast(ARMA_MAT_TXT_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_MAT_TXT_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_MAT_TXT_FC016); } + + return std::string(header); } @@ -61,37 +73,47 @@ diskio::gen_txt_header(const Mat&) //! Format: "ARMA_MAT_BIN_ABXYZ". //! A is one of: I (for integral types) or F (for floating point types). //! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). -//! XYZ specifies the width of each element in terms of bytes, e.g. "008" indicates eight bytes. +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. template inline -arma_cold std::string diskio::gen_bin_header(const Mat&) { arma_type_check(( is_supported_elem_type::value == false )); - if( is_u8::value) { return std::string("ARMA_MAT_BIN_IU001"); } - else if( is_s8::value) { return std::string("ARMA_MAT_BIN_IS001"); } - else if(is_u16::value) { return std::string("ARMA_MAT_BIN_IU002"); } - else if(is_s16::value) { return std::string("ARMA_MAT_BIN_IS002"); } - else if(is_u32::value) { return std::string("ARMA_MAT_BIN_IU004"); } - else if(is_s32::value) { return std::string("ARMA_MAT_BIN_IS004"); } -#if defined(ARMA_USE_U64S64) - else if(is_u64::value) { return std::string("ARMA_MAT_BIN_IU008"); } - else if(is_s64::value) { return std::string("ARMA_MAT_BIN_IS008"); } -#endif -#if defined(ARMA_ALLOW_LONG) - else if(is_ulng_t_32::value) { return std::string("ARMA_MAT_BIN_IU004"); } - else if(is_slng_t_32::value) { return std::string("ARMA_MAT_BIN_IS004"); } - else if(is_ulng_t_64::value) { return std::string("ARMA_MAT_BIN_IU008"); } - else if(is_slng_t_64::value) { return std::string("ARMA_MAT_BIN_IS008"); } -#endif - else if( is_float::value) { return std::string("ARMA_MAT_BIN_FN004"); } - else if( is_double::value) { return std::string("ARMA_MAT_BIN_FN008"); } - else if( is_cx_float::value) { return std::string("ARMA_MAT_BIN_FC008"); } - else if(is_cx_double::value) { return std::string("ARMA_MAT_BIN_FC016"); } - - return std::string(); + const char* ARMA_MAT_BIN_IU001 = "ARMA_MAT_BIN_IU001"; + const char* ARMA_MAT_BIN_IS001 = "ARMA_MAT_BIN_IS001"; + const char* ARMA_MAT_BIN_IU002 = "ARMA_MAT_BIN_IU002"; + const char* ARMA_MAT_BIN_IS002 = "ARMA_MAT_BIN_IS002"; + const char* ARMA_MAT_BIN_IU004 = "ARMA_MAT_BIN_IU004"; + const char* ARMA_MAT_BIN_IS004 = "ARMA_MAT_BIN_IS004"; + const char* ARMA_MAT_BIN_IU008 = "ARMA_MAT_BIN_IU008"; + const char* ARMA_MAT_BIN_IS008 = "ARMA_MAT_BIN_IS008"; + const char* ARMA_MAT_BIN_FN004 = "ARMA_MAT_BIN_FN004"; + const char* ARMA_MAT_BIN_FN008 = "ARMA_MAT_BIN_FN008"; + const char* ARMA_MAT_BIN_FC008 = "ARMA_MAT_BIN_FC008"; + const char* ARMA_MAT_BIN_FC016 = "ARMA_MAT_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_MAT_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_MAT_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_MAT_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_MAT_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_MAT_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_MAT_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_MAT_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_MAT_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_MAT_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_MAT_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_MAT_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_MAT_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_MAT_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_MAT_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_MAT_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_MAT_BIN_FC016); } + + return std::string(header); } @@ -100,37 +122,47 @@ diskio::gen_bin_header(const Mat&) //! Format: "ARMA_SPM_BIN_ABXYZ". //! A is one of: I (for integral types) or F (for floating point types). //! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). -//! XYZ specifies the width of each element in terms of bytes, e.g. "008" indicates eight bytes. +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. template inline -arma_cold std::string diskio::gen_bin_header(const SpMat&) { arma_type_check(( is_supported_elem_type::value == false )); - if( is_u8::value) { return std::string("ARMA_SPM_BIN_IU001"); } - else if( is_s8::value) { return std::string("ARMA_SPM_BIN_IS001"); } - else if(is_u16::value) { return std::string("ARMA_SPM_BIN_IU002"); } - else if(is_s16::value) { return std::string("ARMA_SPM_BIN_IS002"); } - else if(is_u32::value) { return std::string("ARMA_SPM_BIN_IU004"); } - else if(is_s32::value) { return std::string("ARMA_SPM_BIN_IS004"); } -#if defined(ARMA_USE_U64S64) - else if(is_u64::value) { return std::string("ARMA_SPM_BIN_IU008"); } - else if(is_s64::value) { return std::string("ARMA_SPM_BIN_IS008"); } -#endif -#if defined(ARMA_ALLOW_LONG) - else if(is_ulng_t_32::value) { return std::string("ARMA_SPM_BIN_IU004"); } - else if(is_slng_t_32::value) { return std::string("ARMA_SPM_BIN_IS004"); } - else if(is_ulng_t_64::value) { return std::string("ARMA_SPM_BIN_IU008"); } - else if(is_slng_t_64::value) { return std::string("ARMA_SPM_BIN_IS008"); } -#endif - else if( is_float::value) { return std::string("ARMA_SPM_BIN_FN004"); } - else if( is_double::value) { return std::string("ARMA_SPM_BIN_FN008"); } - else if( is_cx_float::value) { return std::string("ARMA_SPM_BIN_FC008"); } - else if(is_cx_double::value) { return std::string("ARMA_SPM_BIN_FC016"); } - - return std::string(); + const char* ARMA_SPM_BIN_IU001 = "ARMA_SPM_BIN_IU001"; + const char* ARMA_SPM_BIN_IS001 = "ARMA_SPM_BIN_IS001"; + const char* ARMA_SPM_BIN_IU002 = "ARMA_SPM_BIN_IU002"; + const char* ARMA_SPM_BIN_IS002 = "ARMA_SPM_BIN_IS002"; + const char* ARMA_SPM_BIN_IU004 = "ARMA_SPM_BIN_IU004"; + const char* ARMA_SPM_BIN_IS004 = "ARMA_SPM_BIN_IS004"; + const char* ARMA_SPM_BIN_IU008 = "ARMA_SPM_BIN_IU008"; + const char* ARMA_SPM_BIN_IS008 = "ARMA_SPM_BIN_IS008"; + const char* ARMA_SPM_BIN_FN004 = "ARMA_SPM_BIN_FN004"; + const char* ARMA_SPM_BIN_FN008 = "ARMA_SPM_BIN_FN008"; + const char* ARMA_SPM_BIN_FC008 = "ARMA_SPM_BIN_FC008"; + const char* ARMA_SPM_BIN_FC016 = "ARMA_SPM_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_SPM_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_SPM_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_SPM_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_SPM_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_SPM_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_SPM_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_SPM_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_SPM_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_SPM_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_SPM_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_SPM_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_SPM_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_SPM_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_SPM_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_SPM_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_SPM_BIN_FC016); } + + return std::string(header); } @@ -138,37 +170,47 @@ diskio::gen_bin_header(const SpMat&) //! Format: "ARMA_CUB_TXT_ABXYZ". //! A is one of: I (for integral types) or F (for floating point types). //! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). -//! XYZ specifies the width of each element in terms of bytes, e.g. "008" indicates eight bytes. +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. template inline -arma_cold std::string diskio::gen_txt_header(const Cube&) { arma_type_check(( is_supported_elem_type::value == false )); - - if( is_u8::value) { return std::string("ARMA_CUB_TXT_IU001"); } - else if( is_s8::value) { return std::string("ARMA_CUB_TXT_IS001"); } - else if(is_u16::value) { return std::string("ARMA_CUB_TXT_IU002"); } - else if(is_s16::value) { return std::string("ARMA_CUB_TXT_IS002"); } - else if(is_u32::value) { return std::string("ARMA_CUB_TXT_IU004"); } - else if(is_s32::value) { return std::string("ARMA_CUB_TXT_IS004"); } -#if defined(ARMA_USE_U64S64) - else if(is_u64::value) { return std::string("ARMA_CUB_TXT_IU008"); } - else if(is_s64::value) { return std::string("ARMA_CUB_TXT_IS008"); } -#endif -#if defined(ARMA_ALLOW_LONG) - else if(is_ulng_t_32::value) { return std::string("ARMA_CUB_TXT_IU004"); } - else if(is_slng_t_32::value) { return std::string("ARMA_CUB_TXT_IS004"); } - else if(is_ulng_t_64::value) { return std::string("ARMA_CUB_TXT_IU008"); } - else if(is_slng_t_64::value) { return std::string("ARMA_CUB_TXT_IS008"); } -#endif - else if( is_float::value) { return std::string("ARMA_CUB_TXT_FN004"); } - else if( is_double::value) { return std::string("ARMA_CUB_TXT_FN008"); } - else if( is_cx_float::value) { return std::string("ARMA_CUB_TXT_FC008"); } - else if(is_cx_double::value) { return std::string("ARMA_CUB_TXT_FC016"); } - - return std::string(); + + const char* ARMA_CUB_TXT_IU001 = "ARMA_CUB_TXT_IU001"; + const char* ARMA_CUB_TXT_IS001 = "ARMA_CUB_TXT_IS001"; + const char* ARMA_CUB_TXT_IU002 = "ARMA_CUB_TXT_IU002"; + const char* ARMA_CUB_TXT_IS002 = "ARMA_CUB_TXT_IS002"; + const char* ARMA_CUB_TXT_IU004 = "ARMA_CUB_TXT_IU004"; + const char* ARMA_CUB_TXT_IS004 = "ARMA_CUB_TXT_IS004"; + const char* ARMA_CUB_TXT_IU008 = "ARMA_CUB_TXT_IU008"; + const char* ARMA_CUB_TXT_IS008 = "ARMA_CUB_TXT_IS008"; + const char* ARMA_CUB_TXT_FN004 = "ARMA_CUB_TXT_FN004"; + const char* ARMA_CUB_TXT_FN008 = "ARMA_CUB_TXT_FN008"; + const char* ARMA_CUB_TXT_FC008 = "ARMA_CUB_TXT_FC008"; + const char* ARMA_CUB_TXT_FC016 = "ARMA_CUB_TXT_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_CUB_TXT_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_CUB_TXT_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_CUB_TXT_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_CUB_TXT_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_CUB_TXT_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_CUB_TXT_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_CUB_TXT_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_CUB_TXT_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_CUB_TXT_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_CUB_TXT_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_CUB_TXT_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_CUB_TXT_IS008); } + else if( is_float::value) { header = const_cast(ARMA_CUB_TXT_FN004); } + else if( is_double::value) { header = const_cast(ARMA_CUB_TXT_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_CUB_TXT_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_CUB_TXT_FC016); } + + return std::string(header); } @@ -177,48 +219,68 @@ diskio::gen_txt_header(const Cube&) //! Format: "ARMA_CUB_BIN_ABXYZ". //! A is one of: I (for integral types) or F (for floating point types). //! B is one of: U (for unsigned types), S (for signed types), N (for not applicable) or C (for complex types). -//! XYZ specifies the width of each element in terms of bytes, e.g. "008" indicates eight bytes. +//! XYZ specifies the width of each element in terms of bytes, eg. "008" indicates eight bytes. template inline -arma_cold std::string diskio::gen_bin_header(const Cube&) { arma_type_check(( is_supported_elem_type::value == false )); - if( is_u8::value) { return std::string("ARMA_CUB_BIN_IU001"); } - else if( is_s8::value) { return std::string("ARMA_CUB_BIN_IS001"); } - else if(is_u16::value) { return std::string("ARMA_CUB_BIN_IU002"); } - else if(is_s16::value) { return std::string("ARMA_CUB_BIN_IS002"); } - else if(is_u32::value) { return std::string("ARMA_CUB_BIN_IU004"); } - else if(is_s32::value) { return std::string("ARMA_CUB_BIN_IS004"); } -#if defined(ARMA_USE_U64S64) - else if(is_u64::value) { return std::string("ARMA_CUB_BIN_IU008"); } - else if(is_s64::value) { return std::string("ARMA_CUB_BIN_IS008"); } -#endif -#if defined(ARMA_ALLOW_LONG) - else if(is_ulng_t_32::value) { return std::string("ARMA_CUB_BIN_IU004"); } - else if(is_slng_t_32::value) { return std::string("ARMA_CUB_BIN_IS004"); } - else if(is_ulng_t_64::value) { return std::string("ARMA_CUB_BIN_IU008"); } - else if(is_slng_t_64::value) { return std::string("ARMA_CUB_BIN_IS008"); } -#endif - else if( is_float::value) { return std::string("ARMA_CUB_BIN_FN004"); } - else if( is_double::value) { return std::string("ARMA_CUB_BIN_FN008"); } - else if( is_cx_float::value) { return std::string("ARMA_CUB_BIN_FC008"); } - else if(is_cx_double::value) { return std::string("ARMA_CUB_BIN_FC016"); } - - return std::string(); + const char* ARMA_CUB_BIN_IU001 = "ARMA_CUB_BIN_IU001"; + const char* ARMA_CUB_BIN_IS001 = "ARMA_CUB_BIN_IS001"; + const char* ARMA_CUB_BIN_IU002 = "ARMA_CUB_BIN_IU002"; + const char* ARMA_CUB_BIN_IS002 = "ARMA_CUB_BIN_IS002"; + const char* ARMA_CUB_BIN_IU004 = "ARMA_CUB_BIN_IU004"; + const char* ARMA_CUB_BIN_IS004 = "ARMA_CUB_BIN_IS004"; + const char* ARMA_CUB_BIN_IU008 = "ARMA_CUB_BIN_IU008"; + const char* ARMA_CUB_BIN_IS008 = "ARMA_CUB_BIN_IS008"; + const char* ARMA_CUB_BIN_FN004 = "ARMA_CUB_BIN_FN004"; + const char* ARMA_CUB_BIN_FN008 = "ARMA_CUB_BIN_FN008"; + const char* ARMA_CUB_BIN_FC008 = "ARMA_CUB_BIN_FC008"; + const char* ARMA_CUB_BIN_FC016 = "ARMA_CUB_BIN_FC016"; + + char* header = nullptr; + + if( is_u8::value) { header = const_cast(ARMA_CUB_BIN_IU001); } + else if( is_s8::value) { header = const_cast(ARMA_CUB_BIN_IS001); } + else if( is_u16::value) { header = const_cast(ARMA_CUB_BIN_IU002); } + else if( is_s16::value) { header = const_cast(ARMA_CUB_BIN_IS002); } + else if( is_u32::value) { header = const_cast(ARMA_CUB_BIN_IU004); } + else if( is_s32::value) { header = const_cast(ARMA_CUB_BIN_IS004); } + else if( is_u64::value) { header = const_cast(ARMA_CUB_BIN_IU008); } + else if( is_s64::value) { header = const_cast(ARMA_CUB_BIN_IS008); } + else if(is_ulng_t_32::value) { header = const_cast(ARMA_CUB_BIN_IU004); } + else if(is_slng_t_32::value) { header = const_cast(ARMA_CUB_BIN_IS004); } + else if(is_ulng_t_64::value) { header = const_cast(ARMA_CUB_BIN_IU008); } + else if(is_slng_t_64::value) { header = const_cast(ARMA_CUB_BIN_IS008); } + else if( is_float::value) { header = const_cast(ARMA_CUB_BIN_FN004); } + else if( is_double::value) { header = const_cast(ARMA_CUB_BIN_FN008); } + else if( is_cx_float::value) { header = const_cast(ARMA_CUB_BIN_FC008); } + else if(is_cx_double::value) { header = const_cast(ARMA_CUB_BIN_FC016); } + + return std::string(header); } inline -arma_cold file_type diskio::guess_file_type(std::istream& f) { arma_extra_debug_sigprint(); + return diskio::guess_file_type_internal(f); + } + + + +inline +file_type +diskio::guess_file_type_internal(std::istream& f) + { + arma_extra_debug_sigprint(); + f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -252,24 +314,33 @@ diskio::guess_file_type(std::istream& f) if(load_okay == false) { return file_type_unknown; } - bool has_binary = false; - bool has_bracket = false; - bool has_comma = false; + bool has_binary = false; + bool has_bracket = false; + bool has_comma = false; + bool has_semicolon = false; for(uword i=0; i= 123) ) { has_binary = true; break; } // the range checking can be made more elaborate + if( (val <= 8) || (val >= 123) ) { has_binary = true; break; } // the range checking can be made more elaborate - if( (val == '(') || (val == ')') ) { has_bracket = true; } + if( (val == '(') || (val == ')') ) { has_bracket = true; } - if( (val == ',') ) { has_comma = true; } + if( (val == ';') ) { has_semicolon = true; } + + if( (val == ',') ) { has_comma = true; } } if(has_binary) { return raw_binary; } - if(has_comma && (has_bracket == false)) { return csv_ascii; } + // ssv_ascii has to be before csv_ascii; + // if the data has semicolons, it suggests a CSV file with semicolon as the separating character; + // the semicolon may be used to allow the comma character to represent the decimal seperator (eg. 1,2345 vs 1.2345) + + if(has_semicolon && (has_bracket == false)) { return ssv_ascii; } + + if(has_comma && (has_bracket == false)) { return csv_ascii; } return raw_ascii; } @@ -279,7 +350,6 @@ diskio::guess_file_type(std::istream& f) //! Append a quasi-random string to the given filename. //! Avoiding use of rand() to preserve its state. inline -arma_cold std::string diskio::gen_tmp_name(const std::string& x) { @@ -316,7 +386,6 @@ diskio::gen_tmp_name(const std::string& x) //! (i) overwriting files that are write protected, //! (ii) overwriting directories. inline -arma_cold bool diskio::safe_rename(const std::string& old_name, const std::string& new_name) { @@ -335,6 +404,40 @@ diskio::safe_rename(const std::string& old_name, const std::string& new_name) +inline +bool +diskio::is_readable(const std::string& name) + { + std::ifstream f; + + f.open(name, std::fstream::binary); + + // std::ifstream destructor will close the file + + return (f.is_open()); + } + + + +inline +void +diskio::sanitise_token(std::string& token) + { + // remove spaces, tabs, carriage returns + + if(token.length() == 0) { return; } + + const char c_front = token.front(); + const char c_back = token.back(); + + if( (c_front == ' ') || (c_front == '\t') || (c_front == '\r') || (c_back == ' ') || (c_back == '\t') || (c_back == '\r') ) + { + token.erase(std::remove_if(token.begin(), token.end(), [](char c) { return ((c == ' ') || (c == '\t') || (c == '\r')); }), token.end()); + } + } + + + template inline bool @@ -342,10 +445,10 @@ diskio::convert_token(eT& val, const std::string& token) { const size_t N = size_t(token.length()); - if(N == 0) { val = eT(0); return true; } - const char* str = token.c_str(); + if( (N == 0) || ((N == 1) && (str[0] == '0')) ) { val = eT(0); return true; } + if( (N == 3) || (N == 4) ) { const bool neg = (str[0] == '-'); @@ -372,8 +475,36 @@ diskio::convert_token(eT& val, const std::string& token) } } - - char* endptr = NULL; + // #if (defined(ARMA_HAVE_CXX17) && (__cpp_lib_to_chars >= 201611L)) + // { + // // std::from_chars() doesn't handle leading whitespace + // // std::from_chars() doesn't handle leading + sign + // // std::from_chars() handles only the decimal point (.) as the decimal seperator + // + // const char str0 = str[0]; + // const bool start_ok = ((str0 != ' ') && (str0 != '\t') && (str0 != '+')); + // + // bool has_comma = false; + // for(uword i=0; i::value) { @@ -385,31 +516,28 @@ diskio::convert_token(eT& val, const std::string& token) { // signed integer - #if defined(ARMA_USE_CXX11) || (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - { - val = eT( std::strtoll(str, &endptr, 10) ); - } - #else - { - val = eT( std::strtol(str, &endptr, 10) ); - } - #endif + val = eT( std::strtoll(str, &endptr, 10) ); } else { // unsigned integer - if(str[0] == '-') { val = eT(0); return true; } - - #if defined(ARMA_USE_CXX11) || (defined(_POSIX_C_SOURCE) && (_POSIX_C_SOURCE >= 200112L)) - { - val = eT( std::strtoull(str, &endptr, 10) ); - } - #else + if((str[0] == '-') && (N >= 2)) { - val = eT( std::strtoul(str, &endptr, 10) ); + val = eT(0); + + if((str[1] == '-') || (str[1] == '+')) { return false; } + + const char* str_offset1 = &(str[1]); + + std::strtoull(str_offset1, &endptr, 10); + + if(str_offset1 == endptr) { return false; } + + return true; } - #endif + + val = eT( std::strtoull(str, &endptr, 10) ); } } @@ -483,7 +611,7 @@ diskio::convert_token(std::complex& val, const std::string& token) const bool state_real = diskio::convert_token(val_real, token_real); const bool state_imag = diskio::convert_token(val_imag, token_imag); - state = ((state_real == true) && (state_imag == true)); + state = (state_real && state_imag); val = std::complex(val_real, val_imag); } @@ -494,18 +622,53 @@ diskio::convert_token(std::complex& val, const std::string& token) template -arma_deprecated inline bool -diskio::convert_naninf(eT& val, const std::string& token) +diskio::convert_token_strict(eT& val, const std::string& token) + { + const size_t N = size_t(token.length()); + + const bool status = (N > 0) ? diskio::convert_token(val, token) : false; + + if(status == false) { val = Datum::nan; } + + return status; + } + + + +template +inline +std::streamsize +diskio::prepare_stream(std::ostream& f) { - // TODO: remove this function; - // TODO: this function is kept only to allow compilation of old versions of mlpack + std::streamsize cell_width = f.width(); - arma_debug_warn("*** arma::diskio::convert_naninf() is an internal armadillo function subject to removal ***"); + if(is_real::value) + { + f.unsetf(ios::fixed); + f.setf(ios::scientific); + f.fill(' '); + + f.precision(16); + cell_width = 24; + + // NOTE: for 'float' the optimum settings are f.precision(8) and cell_width = 15 + // NOTE: however, to avoid introducing errors in case single precision data is loaded as double precision, + // NOTE: the same settings must be used for both 'float' and 'double' + } + else + if(is_cx::value) + { + f.unsetf(ios::fixed); + f.setf(ios::scientific); + + f.precision(16); + } - return diskio::convert_token(val, token); + return cell_width; } + @@ -520,7 +683,9 @@ diskio::save_raw_ascii(const Mat& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::fstream f(tmp_name.c_str(), std::fstream::out); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); @@ -548,22 +713,9 @@ diskio::save_raw_ascii(const Mat& x, std::ostream& f) { arma_extra_debug_sigprint(); - uword cell_width; - - if(is_real::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - cell_width = 22; - } + const arma_ostream_state stream_state(f); - if(is_cx::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + const std::streamsize cell_width = diskio::prepare_stream(f); for(uword row=0; row < x.n_rows; ++row) { @@ -571,15 +723,19 @@ diskio::save_raw_ascii(const Mat& x, std::ostream& f) { f.put(' '); - if(is_real::value) { f.width(std::streamsize(cell_width)); } + if(is_real::value) { f.width(cell_width); } - arma_ostream::print_elem(f, x.at(row,col), false); + arma_ostream::raw_print_elem(f, x.at(row,col)); } f.put('\n'); } - return f.good(); + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; } @@ -594,7 +750,7 @@ diskio::save_raw_binary(const Mat& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str(), std::fstream::binary); + std::ofstream f(tmp_name, std::fstream::binary); bool save_okay = f.is_open(); @@ -638,11 +794,13 @@ diskio::save_arma_ascii(const Mat& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str()); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); - - if(save_okay) + + if(save_okay) { save_okay = diskio::save_arma_ascii(x, f); @@ -666,27 +824,12 @@ diskio::save_arma_ascii(const Mat& x, std::ostream& f) { arma_extra_debug_sigprint(); - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); f << diskio::gen_txt_header(x) << '\n'; f << x.n_rows << ' ' << x.n_cols << '\n'; - uword cell_width; - - if(is_real::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - cell_width = 22; - } - - if(is_cx::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + const std::streamsize cell_width = diskio::prepare_stream(f); for(uword row=0; row < x.n_rows; ++row) { @@ -694,9 +837,9 @@ diskio::save_arma_ascii(const Mat& x, std::ostream& f) { f.put(' '); - if(is_real::value) { f.width(std::streamsize(cell_width)); } + if(is_real::value) { f.width(cell_width); } - arma_ostream::print_elem(f, x.at(row,col), false); + arma_ostream::raw_print_elem(f, x.at(row,col)); } f.put('\n'); @@ -704,7 +847,7 @@ diskio::save_arma_ascii(const Mat& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -715,26 +858,43 @@ diskio::save_arma_ascii(const Mat& x, std::ostream& f) template inline bool -diskio::save_csv_ascii(const Mat& x, const std::string& final_name) +diskio::save_csv_ascii(const Mat& x, const std::string& final_name, const field& header, const bool with_header, const char separator) { arma_extra_debug_sigprint(); const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str()); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); - if(save_okay) + if(save_okay == false) { return false; } + + if(with_header) { - save_okay = diskio::save_csv_ascii(x, f); + arma_extra_debug_print("diskio::save_csv_ascii(): writing header"); - f.flush(); - f.close(); + for(uword i=0; i < header.n_elem; ++i) + { + f << header.at(i); + + if(i != (header.n_elem-1)) { f.put(separator); } + } - if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + f.put('\n'); + + save_okay = f.good(); } + if(save_okay) { save_okay = diskio::save_csv_ascii(x, f, separator); } + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + return save_okay; } @@ -744,18 +904,13 @@ diskio::save_csv_ascii(const Mat& x, const std::string& final_name) template inline bool -diskio::save_csv_ascii(const Mat& x, std::ostream& f) +diskio::save_csv_ascii(const Mat& x, std::ostream& f, const char separator) { arma_extra_debug_sigprint(); - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); - if( (is_float::value) || (is_double::value) ) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + diskio::prepare_stream(f); uword x_n_rows = x.n_rows; uword x_n_cols = x.n_cols; @@ -764,9 +919,9 @@ diskio::save_csv_ascii(const Mat& x, std::ostream& f) { for(uword col=0; col < x_n_cols; ++col) { - arma_ostream::print_elem(f, x.at(row,col), false); + arma_ostream::raw_print_elem(f, x.at(row,col)); - if( col < (x_n_cols-1) ) { f.put(','); } + if( col < (x_n_cols-1) ) { f.put(separator); } } f.put('\n'); @@ -774,7 +929,7 @@ diskio::save_csv_ascii(const Mat& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -785,20 +940,15 @@ diskio::save_csv_ascii(const Mat& x, std::ostream& f) template inline bool -diskio::save_csv_ascii(const Mat< std::complex >& x, std::ostream& f) +diskio::save_csv_ascii(const Mat< std::complex >& x, std::ostream& f, const char separator) { arma_extra_debug_sigprint(); typedef typename std::complex eT; - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); - if( (is_float::value) || (is_double::value) ) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + diskio::prepare_stream(f); uword x_n_rows = x.n_rows; uword x_n_cols = x.n_cols; @@ -814,12 +964,12 @@ diskio::save_csv_ascii(const Mat< std::complex >& x, std::ostream& f) const T tmp_i_abs = (tmp_i < T(0)) ? T(-tmp_i) : T(tmp_i); const char tmp_sign = (tmp_i < T(0)) ? char('-') : char('+'); - arma_ostream::print_elem(f, tmp_r, false); + arma_ostream::raw_print_elem(f, tmp_r ); f.put(tmp_sign); - arma_ostream::print_elem(f, tmp_i_abs, false); + arma_ostream::raw_print_elem(f, tmp_i_abs); f.put('i'); - if( col < (x_n_cols-1) ) { f.put(','); } + if( col < (x_n_cols-1) ) { f.put(separator); } } f.put('\n'); @@ -827,7 +977,127 @@ diskio::save_csv_ascii(const Mat< std::complex >& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat& x, const std::string& final_name) + { + arma_extra_debug_sigprint(); + + const std::string tmp_name = diskio::gen_tmp_name(final_name); + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + + bool save_okay = f.is_open(); + + if(save_okay) + { + save_okay = diskio::save_coord_ascii(x, f); + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + } + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + for(uword col=0; col < x.n_cols; ++col) + for(uword row=0; row < x.n_rows; ++row) + { + const eT val = x.at(row,col); + + if(val != eT(0)) + { + f << row << ' ' << col << ' ' << val << '\n'; + } + } + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT(0) ) + { + f << max_row << ' ' << max_col << " 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; + } + + + +template +inline +bool +diskio::save_coord_ascii(const Mat< std::complex >& x, std::ostream& f) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); + + const eT eT_zero = eT(0); + + for(uword col=0; col < x.n_cols; ++col) + for(uword row=0; row < x.n_rows; ++row) + { + const eT& val = x.at(row,col); + + if(val != eT_zero) + { + f << row << ' ' << col << ' ' << val.real() << ' ' << val.imag() << '\n'; + } + } + + // make sure it's possible to figure out the matrix size later + if( (x.n_rows > 0) && (x.n_cols > 0) ) + { + const uword max_row = (x.n_rows > 0) ? x.n_rows-1 : 0; + const uword max_col = (x.n_cols > 0) ? x.n_cols-1 : 0; + + if( x.at(max_row, max_col) == eT_zero ) + { + f << max_row << ' ' << max_col << " 0 0\n"; + } + } + + const bool save_okay = f.good(); + + stream_state.restore(f); return save_okay; } @@ -845,7 +1115,7 @@ diskio::save_arma_binary(const Mat& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str(), std::fstream::binary); + std::ofstream f(tmp_name, std::fstream::binary); bool save_okay = f.is_open(); @@ -872,7 +1142,7 @@ bool diskio::save_arma_binary(const Mat& x, std::ostream& f) { arma_extra_debug_sigprint(); - + f << diskio::gen_bin_header(x) << '\n'; f << x.n_rows << ' ' << x.n_cols << '\n'; @@ -893,7 +1163,7 @@ diskio::save_pgm_binary(const Mat& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::fstream f(tmp_name.c_str(), std::fstream::out | std::fstream::binary); + std::fstream f(tmp_name, std::fstream::out | std::fstream::binary); bool save_okay = f.is_open(); @@ -912,11 +1182,6 @@ diskio::save_pgm_binary(const Mat& x, const std::string& final_name) -// -// TODO: -// add functionality to save the image in a normalised format, -// i.e. scaled so that every value falls in the [0,255] range. - //! Save a matrix as a PGM greyscale image template inline @@ -995,12 +1260,12 @@ diskio::save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& e const bool append = bool(spec.opts.flags & hdf5_opts::flag_append); const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); - const bool use_existing_file = ((append || replace) && (arma_H5Fis_hdf5(spec.filename.c_str()) > 0)); + const bool use_existing_file = ((append || replace) && (H5Fis_hdf5(spec.filename.c_str()) > 0)); const std::string tmp_name = (use_existing_file) ? std::string() : diskio::gen_tmp_name(spec.filename); // Set up the file according to HDF5's preferences - hid_t file = (use_existing_file) ? arma_H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : arma_H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); + hid_t file = (use_existing_file) ? H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); if(file < 0) { return false; } @@ -1009,7 +1274,7 @@ diskio::save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& e dims[1] = x.n_rows; dims[0] = x.n_cols; - hid_t dataspace = arma_H5Screate_simple(2, dims, NULL); // treat the matrix as a 2d array dataspace + hid_t dataspace = H5Screate_simple(2, dims, NULL); // treat the matrix as a 2d array dataspace hid_t datatype = hdf5_misc::get_hdf5_type(); // If this returned something invalid, well, it's time to crash. @@ -1022,16 +1287,16 @@ diskio::save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& e std::vector groups; std::string full_name = spec.dsname; size_t loc; - while ((loc = full_name.find("/")) != std::string::npos) + while((loc = full_name.find("/")) != std::string::npos) { // Create another group... - if (loc != 0) // Ignore the first /, if there is a leading /. + if(loc != 0) // Ignore the first /, if there is a leading /. { - hid_t gid = arma_H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + hid_t gid = H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); if((gid < 0) && use_existing_file) { - gid = arma_H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); + gid = H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); } groups.push_back(gid); @@ -1046,32 +1311,32 @@ diskio::save_hdf5_binary(const Mat& x, const hdf5_name& spec, std::string& e if(use_existing_file && replace) { - arma_H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); + H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); // NOTE: H5Ldelete() in HDF5 v1.8 doesn't reclaim the deleted space; use h5repack to reclaim space: h5repack oldfile.h5 newfile.h5 // NOTE: has this behaviour changed in HDF5 1.10 ? // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010482.html // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010486.html } - hid_t dataset = arma_H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + hid_t dataset = H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); if(dataset < 0) { save_okay = false; - err_msg = "couldn't create dataset in "; + err_msg = "failed to create dataset"; } else { - save_okay = (arma_H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); + save_okay = (H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); - arma_H5Dclose(dataset); + H5Dclose(dataset); } - arma_H5Tclose(datatype); - arma_H5Sclose(dataspace); - for (size_t i = 0; i < groups.size(); ++i) { arma_H5Gclose(groups[i]); } - arma_H5Fclose(file); + H5Tclose(datatype); + H5Sclose(dataspace); + for(size_t i = 0; i < groups.size(); ++i) { H5Gclose(groups[i]); } + H5Fclose(file); if((use_existing_file == false) && (save_okay == true)) { save_okay = diskio::safe_rename(tmp_name, spec.filename); } @@ -1101,9 +1366,10 @@ bool diskio::load_raw_ascii(Mat& x, const std::string& name, std::string& err_msg) { arma_extra_debug_sigprint(); - - std::fstream f; - f.open(name.c_str(), std::fstream::in); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); @@ -1158,7 +1424,7 @@ diskio::load_raw_ascii(Mat& x, std::istream& f, std::string& err_msg) uword line_n_cols = 0; - while (line_stream >> token) { ++line_n_cols; } + while(line_stream >> token) { ++line_n_cols; } if(f_n_cols_found == false) { @@ -1170,7 +1436,7 @@ diskio::load_raw_ascii(Mat& x, std::istream& f, std::string& err_msg) if(line_n_cols != f_n_cols) { load_okay = false; - err_msg = "inconsistent number of columns in "; + err_msg = "inconsistent number of columns"; } } @@ -1183,7 +1449,9 @@ diskio::load_raw_ascii(Mat& x, std::istream& f, std::string& err_msg) f.clear(); f.seekg(pos1); - x.set_size(f_n_rows, f_n_cols); + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } for(uword row=0; ((row < x.n_rows) && load_okay); ++row) for(uword col=0; ((col < x.n_cols) && load_okay); ++col) @@ -1193,7 +1461,7 @@ diskio::load_raw_ascii(Mat& x, std::istream& f, std::string& err_msg) if(diskio::convert_token(x.at(row,col), token) == false) { load_okay = false; - err_msg = "couldn't interpret data in "; + err_msg = "data interpretation failure"; } } } @@ -1218,7 +1486,7 @@ diskio::load_raw_binary(Mat& x, const std::string& name, std::string& err_ms arma_extra_debug_sigprint(); std::ifstream f; - f.open(name.c_str(), std::fstream::binary); + f.open(name, std::fstream::binary); bool load_okay = f.is_open(); @@ -1239,7 +1507,6 @@ bool diskio::load_raw_binary(Mat& x, std::istream& f, std::string& err_msg) { arma_extra_debug_sigprint(); - arma_ignore(err_msg); f.clear(); const std::streampos pos1 = f.tellg(); @@ -1256,7 +1523,9 @@ diskio::load_raw_binary(Mat& x, std::istream& f, std::string& err_msg) //f.seekg(0, ios::beg); f.seekg(pos1); - x.set_size(N / uword(sizeof(eT)), 1); + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.set_size(N / uword(sizeof(eT)), 1); } catch(...) { err_msg = "not enough memory"; return false; } f.clear(); f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem * uword(sizeof(eT))) ); @@ -1275,7 +1544,9 @@ diskio::load_arma_ascii(Mat& x, const std::string& name, std::string& err_ms { arma_extra_debug_sigprint(); - std::ifstream f(name.c_str()); + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); @@ -1300,7 +1571,7 @@ diskio::load_arma_ascii(Mat& x, std::istream& f, std::string& err_msg) arma_extra_debug_sigprint(); std::streampos pos = f.tellg(); - + bool load_okay = true; std::string f_header; @@ -1313,7 +1584,7 @@ diskio::load_arma_ascii(Mat& x, std::istream& f, std::string& err_msg) if(f_header == diskio::gen_txt_header(x)) { - x.zeros(f_n_rows, f_n_cols); + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } std::string token; @@ -1330,7 +1601,7 @@ diskio::load_arma_ascii(Mat& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "incorrect header in "; + err_msg = "incorrect header"; } @@ -1374,21 +1645,70 @@ diskio::load_arma_ascii(Mat& x, std::istream& f, std::string& err_msg) template inline bool -diskio::load_csv_ascii(Mat& x, const std::string& name, std::string& err_msg) +diskio::load_csv_ascii(Mat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator, const bool strict) { arma_extra_debug_sigprint(); - std::fstream f; - f.open(name.c_str(), std::fstream::in); + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); - if(load_okay) + if(load_okay == false) { return false; } + + if(with_header) { - load_okay = diskio::load_csv_ascii(x, f, err_msg); - f.close(); + arma_extra_debug_print("diskio::load_csv_ascii(): reading header"); + + std::string header_line; + std::stringstream header_stream; + std::vector header_tokens; + + std::getline(f, header_line); + + load_okay = f.good(); + + if(load_okay) + { + std::string token; + + header_stream.clear(); + header_stream.str(header_line); + + uword header_n_tokens = 0; + + while(header_stream.good()) + { + std::getline(header_stream, token, separator); + + diskio::sanitise_token(token); + + ++header_n_tokens; + + header_tokens.push_back(token); + } + + if(header_n_tokens == uword(0)) + { + header.reset(); + } + else + { + header.set_size(1,header_n_tokens); + + for(uword i=0; i < header_n_tokens; ++i) { header.at(i) = header_tokens[i]; } + } + } } + if(load_okay) + { + load_okay = diskio::load_csv_ascii(x, f, err_msg, separator, strict); + } + + f.close(); + return load_okay; } @@ -1398,13 +1718,13 @@ diskio::load_csv_ascii(Mat& x, const std::string& name, std::string& err_msg template inline bool -diskio::load_csv_ascii(Mat& x, std::istream& f, std::string&) +diskio::load_csv_ascii(Mat& x, std::istream& f, std::string& err_msg, const char separator, const bool strict) { arma_extra_debug_sigprint(); // TODO: replace with more efficient implementation - bool load_okay = f.good(); + if(f.good() == false) { return false; } f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -1420,7 +1740,7 @@ diskio::load_csv_ascii(Mat& x, std::istream& f, std::string&) std::string token; - while( f.good() && load_okay ) + while(f.good()) { std::getline(f, line_string); @@ -1433,7 +1753,7 @@ diskio::load_csv_ascii(Mat& x, std::istream& f, std::string&) while(line_stream.good()) { - std::getline(line_stream, token, ','); + std::getline(line_stream, token, separator); ++line_n_cols; } @@ -1445,34 +1765,106 @@ diskio::load_csv_ascii(Mat& x, std::istream& f, std::string&) f.clear(); f.seekg(pos1); - x.zeros(f_n_rows, f_n_cols); + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } - uword row = 0; + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } - while(f.good()) + if(strict) { x.fill(Datum::nan); } // take into account that each row may have a unique number of columns + + const bool use_mp = (arma_config::openmp) && (f_n_rows >= 2) && (f_n_cols >= 64); + + field token_array; + + bool token_array_ok = false; + + if(use_mp) { - std::getline(f, line_string); - - if(line_string.size() == 0) { break; } - - line_stream.clear(); - line_stream.str(line_string); - - uword col = 0; - - while(line_stream.good()) + try { - std::getline(line_stream, token, ','); + token_array.set_size(f_n_cols); - diskio::convert_token( x.at(row,col), token ); + for(uword i=0; i < f_n_cols; ++i) { token_array(i).reserve(32); } - ++col; + token_array_ok = true; } + catch(...) + { + token_array.reset(); + } + } + + if(use_mp && token_array_ok) + { + #if defined(ARMA_USE_OPENMP) + { + uword row = 0; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + for(uword i=0; i < f_n_cols; ++i) { token_array(i).clear(); } + + uword line_stream_col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token_array(line_stream_col), separator); + + ++line_stream_col; + } + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < line_stream_col; ++col) + { + eT& out_val = x.at(row,col); + + (strict) ? diskio::convert_token_strict( out_val, token_array(col) ) : diskio::convert_token( out_val, token_array(col) ); + } + + ++row; + } + } + #endif + } + else // serial implementation + { + uword row = 0; - ++row; + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword col = 0; + + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + + eT& out_val = x.at(row,col); + + (strict) ? diskio::convert_token_strict( out_val, token ) : diskio::convert_token( out_val, token ); + + ++col; + } + + ++row; + } } - return load_okay; + return true; } @@ -1481,13 +1873,13 @@ diskio::load_csv_ascii(Mat& x, std::istream& f, std::string&) template inline bool -diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) +diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator, const bool strict) { arma_extra_debug_sigprint(); // TODO: replace with more efficient implementation - bool load_okay = f.good(); + if(f.good() == false) { return false; } f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -1503,7 +1895,7 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) std::string token; - while( f.good() && load_okay ) + while(f.good()) { std::getline(f, line_string); @@ -1516,7 +1908,7 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) while(line_stream.good()) { - std::getline(line_stream, token, ','); + std::getline(line_stream, token, separator); ++line_n_cols; } @@ -1528,7 +1920,11 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) f.clear(); f.seekg(pos1); - x.zeros(f_n_rows, f_n_cols); + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { x.zeros(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + + if(strict) { x.fill(Datum< std::complex >::nan); } // take into account that each row may have a unique number of columns uword row = 0; @@ -1548,7 +1944,9 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) while(line_stream.good()) { - std::getline(line_stream, token, ','); + std::getline(line_stream, token, separator); + + diskio::sanitise_token(token); const size_t token_len = size_t( token.length() ); @@ -1587,7 +1985,9 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) if(found_val_real) { - x.at(row,col) = std::complex(val_real, T(0)); + const T val_imag = (strict) ? T(Datum::nan) : T(0); + + x.at(row,col) = std::complex(val_real, val_imag); col++; continue; // get next token } @@ -1691,8 +2091,8 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) T val_real = T(0); T val_imag = T(0); - diskio::convert_token(val_real, str_real); - diskio::convert_token(val_imag, str_imag); + (strict) ? diskio::convert_token_strict(val_real, str_real) : diskio::convert_token(val_real, str_real); + (strict) ? diskio::convert_token_strict(val_imag, str_imag) : diskio::convert_token(val_imag, str_imag); x.at(row,col) = std::complex(val_real, val_imag); @@ -1702,11 +2102,246 @@ diskio::load_csv_ascii(Mat< std::complex >& x, std::istream& f, std::string&) ++row; } + return true; + } + + + +template +inline +bool +diskio::load_coord_ascii(Mat& x, const std::string& name, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); + + bool load_okay = f.is_open(); + + if(load_okay == false) { return false; } + + if(load_okay) + { + load_okay = diskio::load_coord_ascii(x, f, err_msg); + } + + f.close(); + return load_okay; } +//! Load a matrix in CSV text format (human readable) +template +inline +bool +diskio::load_coord_ascii(Mat& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) { f_n_rows = line_row; } + if(f_n_cols < line_col) { f_n_cols = line_col; } + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + Mat tmp(f_n_rows, f_n_cols, arma_zeros_indicator()); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + eT val = eT(0); + + line_stream >> token; + + if(line_stream.fail() == false) { diskio::convert_token( val, token ); } + + if(val != eT(0)) { tmp(line_row,line_col) = val; } + } + + x.steal_mem(tmp); + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + +template +inline +bool +diskio::load_coord_ascii(Mat< std::complex >& x, std::istream& f, std::string& err_msg) + { + arma_extra_debug_sigprint(); + + if(f.good() == false) { return false; } + + f.clear(); + const std::fstream::pos_type pos1 = f.tellg(); + + // work out the size + + uword f_n_rows = 0; + uword f_n_cols = 0; + + bool size_found = false; + + std::string line_string; + std::stringstream line_stream; + + std::string token_real; + std::string token_imag; + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + // a valid line in co-ord format has at least 2 entries + + line_stream >> line_row; + + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } + + line_stream >> line_col; + + size_found = true; + + if(f_n_rows < line_row) f_n_rows = line_row; + if(f_n_cols < line_col) f_n_cols = line_col; + } + + // take into account that indices start at 0 + if(size_found) { ++f_n_rows; ++f_n_cols; } + + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try + { + Mat< std::complex > tmp(f_n_rows, f_n_cols, arma_zeros_indicator()); + + while(f.good()) + { + std::getline(f, line_string); + + if(line_string.size() == 0) { break; } + + line_stream.clear(); + line_stream.str(line_string); + + uword line_row = 0; + uword line_col = 0; + + line_stream >> line_row; + line_stream >> line_col; + + T val_real = T(0); + T val_imag = T(0); + + line_stream >> token_real; + + if(line_stream.fail() == false) { diskio::convert_token( val_real, token_real ); } + + line_stream >> token_imag; + + if(line_stream.fail() == false) { diskio::convert_token( val_imag, token_imag ); } + + if( (val_real != T(0)) || (val_imag != T(0)) ) + { + tmp(line_row,line_col) = std::complex(val_real, val_imag); + } + } + + x.steal_mem(tmp); + } + catch(...) + { + err_msg = "not enough memory"; + return false; + } + + return true; + } + + + //! Load a matrix in binary format, //! with a header that indicates the matrix type as well as its dimensions template @@ -1717,7 +2352,7 @@ diskio::load_arma_binary(Mat& x, const std::string& name, std::string& err_m arma_extra_debug_sigprint(); std::ifstream f; - f.open(name.c_str(), std::fstream::binary); + f.open(name, std::fstream::binary); bool load_okay = f.is_open(); @@ -1740,7 +2375,7 @@ diskio::load_arma_binary(Mat& x, std::istream& f, std::string& err_msg) arma_extra_debug_sigprint(); std::streampos pos = f.tellg(); - + bool load_okay = true; std::string f_header; @@ -1756,7 +2391,8 @@ diskio::load_arma_binary(Mat& x, std::istream& f, std::string& err_msg) //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters f.get(); - x.set_size(f_n_rows,f_n_cols); + try { x.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem*sizeof(eT)) ); load_okay = f.good(); @@ -1764,7 +2400,7 @@ diskio::load_arma_binary(Mat& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "incorrect header in "; + err_msg = "incorrect header"; } @@ -1811,7 +2447,7 @@ diskio::pnm_skip_comments(std::istream& f) while( isspace(f.peek()) ) { while( isspace(f.peek()) ) { f.get(); } - + if(f.peek() == '#') { while( (f.peek() != '\r') && (f.peek() != '\n') ) { f.get(); } @@ -1830,7 +2466,7 @@ diskio::load_pgm_binary(Mat& x, const std::string& name, std::string& err_ms arma_extra_debug_sigprint(); std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -1862,21 +2498,21 @@ diskio::load_pgm_binary(Mat& x, std::istream& f, std::string& err_msg) uword f_n_rows = 0; uword f_n_cols = 0; int f_maxval = 0; - + diskio::pnm_skip_comments(f); - + f >> f_n_cols; diskio::pnm_skip_comments(f); - + f >> f_n_rows; diskio::pnm_skip_comments(f); - + f >> f_maxval; f.get(); if( (f_maxval > 0) && (f_maxval <= 65535) ) { - x.set_size(f_n_rows,f_n_cols); + try { x.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } if(f_maxval <= 255) { @@ -1917,7 +2553,7 @@ diskio::load_pgm_binary(Mat& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "functionality unimplemented to handle loading "; + err_msg = "functionality unimplemented"; } if(f.good() == false) { load_okay = false; } @@ -1925,7 +2561,7 @@ diskio::load_pgm_binary(Mat& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "unsupported header in "; + err_msg = "unsupported header"; } return load_okay; @@ -1979,11 +2615,13 @@ diskio::load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg #if defined(ARMA_USE_HDF5) { + if(diskio::is_readable(spec.filename) == false) { return false; } + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; bool load_okay = false; - hid_t fid = arma_H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + hid_t fid = H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); if(fid >= 0) { @@ -2010,39 +2648,39 @@ diskio::load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg if(dataset >= 0) { - hid_t filespace = arma_H5Dget_space(dataset); + hid_t filespace = H5Dget_space(dataset); // This must be <= 2 due to our search rules. - const int ndims = arma_H5Sget_simple_extent_ndims(filespace); + const int ndims = H5Sget_simple_extent_ndims(filespace); hsize_t dims[2]; - const herr_t query_status = arma_H5Sget_simple_extent_dims(filespace, dims, NULL); + const herr_t query_status = H5Sget_simple_extent_dims(filespace, dims, NULL); // arma_check(query_status < 0, "Mat::load(): cannot get size of HDF5 dataset"); if(query_status < 0) { - err_msg = "cannot get size of HDF5 dataset in "; + err_msg = "cannot get size of HDF5 dataset"; - arma_H5Sclose(filespace); - arma_H5Dclose(dataset); - arma_H5Fclose(fid); + H5Sclose(filespace); + H5Dclose(dataset); + H5Fclose(fid); return false; } if(ndims == 1) { dims[1] = 1; } // Vector case; fake second dimension (one column). - x.set_size(dims[1], dims[0]); + try { x.set_size(dims[1], dims[0]); } catch(...) { err_msg = "not enough memory"; return false; } // Now we have to see what type is stored to figure out how to load it. - hid_t datatype = arma_H5Dget_type(dataset); + hid_t datatype = H5Dget_type(dataset); hid_t mat_type = hdf5_misc::get_hdf5_type(); // If these are the same type, it is simple. - if(arma_H5Tequal(datatype, mat_type) > 0) + if(H5Tequal(datatype, mat_type) > 0) { // Load directly; H5S_ALL used so that we load the entire dataset. - hid_t read_status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); + hid_t read_status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); if(read_status >= 0) { load_okay = true; } } @@ -2055,25 +2693,25 @@ diskio::load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg } // Now clean up. - arma_H5Tclose(datatype); - arma_H5Tclose(mat_type); - arma_H5Sclose(filespace); + H5Tclose(datatype); + H5Tclose(mat_type); + H5Sclose(filespace); } - arma_H5Dclose(dataset); - - arma_H5Fclose(fid); - + H5Dclose(dataset); + + H5Fclose(fid); + if(load_okay == false) { - err_msg = "unsupported or missing HDF5 data in "; + err_msg = "unsupported or missing HDF5 data"; } } else { - err_msg = "cannot open file "; + err_msg = "cannot open"; } - + return load_okay; } #else @@ -2081,9 +2719,9 @@ diskio::load_hdf5_binary(Mat& x, const hdf5_name& spec, std::string& err_msg arma_ignore(x); arma_ignore(spec); arma_ignore(err_msg); - + arma_stop_logic_error("Mat::load(): use of HDF5 must be enabled"); - + return false; } #endif @@ -2099,13 +2737,15 @@ diskio::load_auto_detect(Mat& x, const std::string& name, std::string& err_m { arma_extra_debug_sigprint(); + if(diskio::is_readable(name) == false) { return false; } + #if defined(ARMA_USE_HDF5) // We're currently using the C bindings for the HDF5 library, which don't support C++ streams - if( arma_H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } + if( H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } #endif - + std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -2164,12 +2804,16 @@ diskio::load_auto_detect(Mat& x, std::istream& f, std::string& err_msg) } else { - const file_type ft = guess_file_type(f); + const file_type ft = guess_file_type_internal(f); switch(ft) { case csv_ascii: - return load_csv_ascii(x, f, err_msg); + return load_csv_ascii(x, f, err_msg, char(','), false); + break; + + case ssv_ascii: + return load_csv_ascii(x, f, err_msg, char(';'), false); break; case raw_binary: @@ -2181,7 +2825,7 @@ diskio::load_auto_detect(Mat& x, std::istream& f, std::string& err_msg) break; default: - err_msg = "unknown data in "; + err_msg = "unknown data"; return false; } } @@ -2201,26 +2845,43 @@ diskio::load_auto_detect(Mat& x, std::istream& f, std::string& err_msg) template inline bool -diskio::save_csv_ascii(const SpMat& x, const std::string& final_name) +diskio::save_csv_ascii(const SpMat& x, const std::string& final_name, const field& header, const bool with_header, const char separator) { arma_extra_debug_sigprint(); const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str()); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); - if(save_okay) + if(save_okay == false) { return false; } + + if(with_header) { - save_okay = diskio::save_csv_ascii(x, f); + arma_extra_debug_print("diskio::save_csv_ascii(): writing header"); - f.flush(); - f.close(); + for(uword i=0; i < header.n_elem; ++i) + { + f << header(i); + + if(i != (header.n_elem-1)) { f.put(separator); } + } - if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + f.put('\n'); + + save_okay = f.good(); } + if(save_okay) { save_okay = diskio::save_csv_ascii(x, f, separator); } + + f.flush(); + f.close(); + + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } + return save_okay; } @@ -2230,33 +2891,37 @@ diskio::save_csv_ascii(const SpMat& x, const std::string& final_name) template inline bool -diskio::save_csv_ascii(const SpMat& x, std::ostream& f) +diskio::save_csv_ascii(const SpMat& x, std::ostream& f, const char separator) { arma_extra_debug_sigprint(); - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); - if( (is_float::value) || (is_double::value) ) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + diskio::prepare_stream(f); x.sync(); uword x_n_rows = x.n_rows; uword x_n_cols = x.n_cols; + const eT eT_zero = eT(0); + for(uword row=0; row < x_n_rows; ++row) { for(uword col=0; col < x_n_cols; ++col) { const eT val = x.at(row,col); - if(val != eT(0)) { arma_ostream::print_elem(f, val, false); } + if(val == eT_zero) + { + f.put('0'); + } + else + { + arma_ostream::raw_print_elem(f, val); + } - if( col < (x_n_cols-1) ) { f.put(','); } + if( col < (x_n_cols-1) ) { f.put(separator); } } f.put('\n'); @@ -2264,7 +2929,7 @@ diskio::save_csv_ascii(const SpMat& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -2275,14 +2940,15 @@ diskio::save_csv_ascii(const SpMat& x, std::ostream& f) template inline bool -diskio::save_csv_ascii(const SpMat< std::complex >& x, std::ostream& f) +diskio::save_csv_ascii(const SpMat< std::complex >& x, std::ostream& f, const char separator) { arma_extra_debug_sigprint(); arma_ignore(x); arma_ignore(f); + arma_ignore(separator); - arma_warn("saving complex sparse matrices as csv_ascii not yet implemented"); + arma_debug_warn_level(1, "saving complex sparse matrices as csv_ascii not yet implemented"); return false; } @@ -2296,23 +2962,25 @@ bool diskio::save_coord_ascii(const SpMat& x, const std::string& final_name) { arma_extra_debug_sigprint(); - + const std::string tmp_name = diskio::gen_tmp_name(final_name); - - std::ofstream f(tmp_name.c_str()); - + + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); + bool save_okay = f.is_open(); - + if(save_okay) { save_okay = diskio::save_coord_ascii(x, f); - + f.flush(); f.close(); - + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } } - + return save_okay; } @@ -2326,21 +2994,18 @@ diskio::save_coord_ascii(const SpMat& x, std::ostream& f) { arma_extra_debug_sigprint(); - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); + + diskio::prepare_stream(f); - if( (is_float::value) || (is_double::value) ) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } - typename SpMat::const_iterator iter = x.begin(); typename SpMat::const_iterator iter_end = x.end(); for(; iter != iter_end; ++iter) { - f << iter.row() << ' ' << iter.col() << ' ' << (*iter) << '\n'; + const eT val = (*iter); + + f << iter.row() << ' ' << iter.col() << ' ' << val << '\n'; } @@ -2358,7 +3023,7 @@ diskio::save_coord_ascii(const SpMat& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -2375,14 +3040,9 @@ diskio::save_coord_ascii(const SpMat< std::complex >& x, std::ostream& f) typedef typename std::complex eT; - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); - if( (is_float::value) || (is_double::value) ) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + diskio::prepare_stream(f); typename SpMat::const_iterator iter = x.begin(); typename SpMat::const_iterator iter_end = x.end(); @@ -2408,7 +3068,7 @@ diskio::save_coord_ascii(const SpMat< std::complex >& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -2423,23 +3083,23 @@ bool diskio::save_arma_binary(const SpMat& x, const std::string& final_name) { arma_extra_debug_sigprint(); - + const std::string tmp_name = diskio::gen_tmp_name(final_name); - - std::ofstream f(tmp_name.c_str(), std::fstream::binary); - + + std::ofstream f(tmp_name, std::fstream::binary); + bool save_okay = f.is_open(); - + if(save_okay) { save_okay = diskio::save_arma_binary(x, f); - + f.flush(); f.close(); - + if(save_okay) { save_okay = diskio::safe_rename(tmp_name, final_name); } } - + return save_okay; } @@ -2469,21 +3129,70 @@ diskio::save_arma_binary(const SpMat& x, std::ostream& f) template inline bool -diskio::load_csv_ascii(SpMat& x, const std::string& name, std::string& err_msg) +diskio::load_csv_ascii(SpMat& x, const std::string& name, std::string& err_msg, field& header, const bool with_header, const char separator) { arma_extra_debug_sigprint(); - std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); + if(load_okay == false) { return false; } + + if(with_header) + { + arma_extra_debug_print("diskio::load_csv_ascii(): reading header"); + + std::string header_line; + std::stringstream header_stream; + std::vector header_tokens; + + std::getline(f, header_line); + + load_okay = f.good(); + + if(load_okay) + { + std::string token; + + header_stream.clear(); + header_stream.str(header_line); + + uword header_n_tokens = 0; + + while(header_stream.good()) + { + std::getline(header_stream, token, separator); + + diskio::sanitise_token(token); + + ++header_n_tokens; + + header_tokens.push_back(token); + } + + if(header_n_tokens == uword(0)) + { + header.reset(); + } + else + { + header.set_size(1,header_n_tokens); + + for(uword i=0; i < header_n_tokens; ++i) { header.at(i) = header_tokens[i]; } + } + } + } + if(load_okay) { - load_okay = diskio::load_csv_ascii(x, f, err_msg); - f.close(); + load_okay = diskio::load_csv_ascii(x, f, err_msg, separator); } + f.close(); + return load_okay; } @@ -2492,14 +3201,13 @@ diskio::load_csv_ascii(SpMat& x, const std::string& name, std::string& err_m template inline bool -diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg) +diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg, const char separator) { arma_extra_debug_sigprint(); - arma_ignore(err_msg); // TODO: replace with more efficient implementation - bool load_okay = f.good(); + if(f.good() == false) { return false; } f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -2515,7 +3223,7 @@ diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg) std::string token; - while( f.good() && load_okay ) + while(f.good()) { std::getline(f, line_string); @@ -2528,7 +3236,7 @@ diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg) while(line_stream.good()) { - std::getline(line_stream, token, ','); + std::getline(line_stream, token, separator); ++line_n_cols; } @@ -2540,38 +3248,50 @@ diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg) f.clear(); f.seekg(pos1); - x.zeros(f_n_rows, f_n_cols); - - uword row = 0; + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } - while(f.good()) + try { - std::getline(f, line_string); - - if(line_string.size() == 0) { break; } - - line_stream.clear(); - line_stream.str(line_string); + MapMat tmp(f_n_rows, f_n_cols); - uword col = 0; + uword row = 0; - while(line_stream.good()) + while(f.good()) { - std::getline(line_stream, token, ','); + std::getline(f, line_string); - eT val = eT(0); + if(line_string.size() == 0) { break; } - diskio::convert_token( val, token ); + line_stream.clear(); + line_stream.str(line_string); - if(val != eT(0)) { x(row,col) = val; } + uword col = 0; - ++col; + while(line_stream.good()) + { + std::getline(line_stream, token, separator); + + eT val = eT(0); + + diskio::convert_token( val, token ); + + if(val != eT(0)) { tmp(row,col) = val; } + + ++col; + } + + ++row; } - ++row; + x = tmp; + } + catch(...) + { + err_msg = "not enough memory"; + return false; } - return load_okay; + return true; } @@ -2579,15 +3299,16 @@ diskio::load_csv_ascii(SpMat& x, std::istream& f, std::string& err_msg) template inline bool -diskio::load_csv_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg) +diskio::load_csv_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg, const char separator) { arma_extra_debug_sigprint(); arma_ignore(x); arma_ignore(f); arma_ignore(err_msg); + arma_ignore(separator); - arma_warn("loading complex sparse matrices as csv_ascii not yet implemented"); + arma_debug_warn_level(1, "loading complex sparse matrices as csv_ascii not yet implemented"); return false; } @@ -2601,8 +3322,9 @@ diskio::load_coord_ascii(SpMat& x, const std::string& name, std::string& err { arma_extra_debug_sigprint(); - std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); @@ -2623,9 +3345,8 @@ bool diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) { arma_extra_debug_sigprint(); - arma_ignore(err_msg); - bool load_okay = f.good(); + if(f.good() == false) { return false; } f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -2642,7 +3363,7 @@ diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) std::string token; - while( f.good() && load_okay ) + while(f.good()) { std::getline(f, line_string); @@ -2658,7 +3379,7 @@ diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) line_stream >> line_row; - if(line_stream.good() == false) { load_okay = false; break; } + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } line_stream >> line_col; @@ -2668,16 +3389,16 @@ diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) if(f_n_cols < line_col) { f_n_cols = line_col; } } - // take into account that indices start at 0 if(size_found) { ++f_n_rows; ++f_n_cols; } + f.clear(); + f.seekg(pos1); - if(load_okay) + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { - f.clear(); - f.seekg(pos1); - MapMat tmp(f_n_rows, f_n_cols); while(f.good()) @@ -2699,18 +3420,20 @@ diskio::load_coord_ascii(SpMat& x, std::istream& f, std::string& err_msg) line_stream >> token; - if(line_stream.fail() == false) - { - diskio::convert_token( val, token ); - } + if(line_stream.fail() == false) { diskio::convert_token( val, token ); } if(val != eT(0)) { tmp(line_row,line_col) = val; } } x = tmp; } + catch(...) + { + err_msg = "not enough memory"; + return false; + } - return load_okay; + return true; } @@ -2721,9 +3444,8 @@ bool diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::string& err_msg) { arma_extra_debug_sigprint(); - arma_ignore(err_msg); - bool load_okay = f.good(); + if(f.good() == false) { return false; } f.clear(); const std::fstream::pos_type pos1 = f.tellg(); @@ -2741,7 +3463,7 @@ diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::stri std::string token_real; std::string token_imag; - while( f.good() && load_okay ) + while(f.good()) { std::getline(f, line_string); @@ -2757,7 +3479,7 @@ diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::stri line_stream >> line_row; - if(line_stream.good() == false) { load_okay = false; break; } + if(line_stream.good() == false) { err_msg = "incorrect format"; return false; } line_stream >> line_col; @@ -2770,11 +3492,13 @@ diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::stri // take into account that indices start at 0 if(size_found) { ++f_n_rows; ++f_n_cols; } - if(load_okay) + f.clear(); + f.seekg(pos1); + + if(f.fail() || (f.tellg() != pos1)) { err_msg = "seek failure"; return false; } + + try { - f.clear(); - f.seekg(pos1); - MapMat< std::complex > tmp(f_n_rows, f_n_cols); while(f.good()) @@ -2797,18 +3521,11 @@ diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::stri line_stream >> token_real; - if(line_stream.fail() == false) - { - diskio::convert_token( val_real, token_real ); - } - + if(line_stream.fail() == false) { diskio::convert_token( val_real, token_real ); } line_stream >> token_imag; - if(line_stream.fail() == false) - { - diskio::convert_token( val_imag, token_imag ); - } + if(line_stream.fail() == false) { diskio::convert_token( val_imag, token_imag ); } if( (val_real != T(0)) || (val_imag != T(0)) ) { @@ -2818,8 +3535,13 @@ diskio::load_coord_ascii(SpMat< std::complex >& x, std::istream& f, std::stri x = tmp; } + catch(...) + { + err_msg = "not enough memory"; + return false; + } - return load_okay; + return true; } @@ -2832,18 +3554,18 @@ bool diskio::load_arma_binary(SpMat& x, const std::string& name, std::string& err_msg) { arma_extra_debug_sigprint(); - + std::ifstream f; - f.open(name.c_str(), std::fstream::binary); - + f.open(name, std::fstream::binary); + bool load_okay = f.is_open(); - + if(load_okay) { load_okay = diskio::load_arma_binary(x, f, err_msg); f.close(); } - + return load_okay; } @@ -2875,7 +3597,7 @@ diskio::load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg) //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters f.get(); - x.reserve(f_n_rows, f_n_cols, f_n_nz); + try { x.reserve(f_n_rows, f_n_cols, f_n_nz); } catch(...) { err_msg = "not enough memory"; return false; } f.read( reinterpret_cast(access::rwp(x.values)), std::streamsize(x.n_nonzero*sizeof(eT)) ); @@ -2928,7 +3650,7 @@ diskio::load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg) if((check1 == false) || (check2 == false) || (check3 == false)) { load_okay = false; - err_msg = "inconsistent data in "; + err_msg = "inconsistent data"; } else { @@ -2938,9 +3660,9 @@ diskio::load_arma_binary(SpMat& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "incorrect header in "; + err_msg = "incorrect header"; } - + return load_okay; } @@ -2960,7 +3682,9 @@ diskio::save_raw_ascii(const Cube& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::fstream f(tmp_name.c_str(), std::fstream::out); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); @@ -2987,22 +3711,9 @@ diskio::save_raw_ascii(const Cube& x, std::ostream& f) { arma_extra_debug_sigprint(); - uword cell_width; - - if(is_real::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - cell_width = 22; - } + const arma_ostream_state stream_state(f); - if(is_cx::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + const std::streamsize cell_width = diskio::prepare_stream(f); for(uword slice=0; slice < x.n_slices; ++slice) { @@ -3012,19 +3723,20 @@ diskio::save_raw_ascii(const Cube& x, std::ostream& f) { f.put(' '); - if(is_real::value) - { - f.width(std::streamsize(cell_width)); - } + if(is_real::value) { f.width(cell_width); } - arma_ostream::print_elem(f, x.at(row,col,slice), false); + arma_ostream::raw_print_elem(f, x.at(row,col,slice)); } f.put('\n'); } } - return f.good(); + const bool save_okay = f.good(); + + stream_state.restore(f); + + return save_okay; } @@ -3039,7 +3751,7 @@ diskio::save_raw_binary(const Cube& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str(), std::fstream::binary); + std::ofstream f(tmp_name, std::fstream::binary); bool save_okay = f.is_open(); @@ -3083,7 +3795,9 @@ diskio::save_arma_ascii(const Cube& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str()); + std::ofstream f; + + (arma_config::text_as_binary) ? f.open(tmp_name, std::fstream::binary) : f.open(tmp_name); bool save_okay = f.is_open(); @@ -3111,27 +3825,12 @@ diskio::save_arma_ascii(const Cube& x, std::ostream& f) { arma_extra_debug_sigprint(); - const ios::fmtflags orig_flags = f.flags(); + const arma_ostream_state stream_state(f); f << diskio::gen_txt_header(x) << '\n'; f << x.n_rows << ' ' << x.n_cols << ' ' << x.n_slices << '\n'; - uword cell_width; - - if(is_real::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - cell_width = 22; - } - - if(is_cx::value) - { - f.unsetf(ios::fixed); - f.setf(ios::scientific); - f.precision(14); - } + const std::streamsize cell_width = diskio::prepare_stream(f); for(uword slice=0; slice < x.n_slices; ++slice) { @@ -3141,12 +3840,9 @@ diskio::save_arma_ascii(const Cube& x, std::ostream& f) { f.put(' '); - if(is_real::value) - { - f.width(std::streamsize(cell_width)); - } + if(is_real::value) { f.width(cell_width); } - arma_ostream::print_elem(f, x.at(row,col,slice), false); + arma_ostream::raw_print_elem(f, x.at(row,col,slice)); } f.put('\n'); @@ -3155,7 +3851,7 @@ diskio::save_arma_ascii(const Cube& x, std::ostream& f) const bool save_okay = f.good(); - f.flags(orig_flags); + stream_state.restore(f); return save_okay; } @@ -3173,7 +3869,7 @@ diskio::save_arma_binary(const Cube& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f(tmp_name.c_str(), std::fstream::binary); + std::ofstream f(tmp_name, std::fstream::binary); bool save_okay = f.is_open(); @@ -3228,12 +3924,12 @@ diskio::save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& const bool append = bool(spec.opts.flags & hdf5_opts::flag_append); const bool replace = bool(spec.opts.flags & hdf5_opts::flag_replace); - const bool use_existing_file = ((append || replace) && (arma_H5Fis_hdf5(spec.filename.c_str()) > 0)); + const bool use_existing_file = ((append || replace) && (H5Fis_hdf5(spec.filename.c_str()) > 0)); const std::string tmp_name = (use_existing_file) ? std::string() : diskio::gen_tmp_name(spec.filename); // Set up the file according to HDF5's preferences - hid_t file = (use_existing_file) ? arma_H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : arma_H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); + hid_t file = (use_existing_file) ? H5Fopen(spec.filename.c_str(), H5F_ACC_RDWR, H5P_DEFAULT) : H5Fcreate(tmp_name.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); if(file < 0) { return false; } @@ -3243,7 +3939,7 @@ diskio::save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& dims[1] = x.n_cols; dims[0] = x.n_slices; - hid_t dataspace = arma_H5Screate_simple(3, dims, NULL); // treat the cube as a 3d array dataspace + hid_t dataspace = H5Screate_simple(3, dims, NULL); // treat the cube as a 3d array dataspace hid_t datatype = hdf5_misc::get_hdf5_type(); // If this returned something invalid, well, it's time to crash. @@ -3256,16 +3952,16 @@ diskio::save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& std::vector groups; std::string full_name = spec.dsname; size_t loc; - while ((loc = full_name.find("/")) != std::string::npos) + while((loc = full_name.find("/")) != std::string::npos) { // Create another group... - if (loc != 0) // Ignore the first /, if there is a leading /. + if(loc != 0) // Ignore the first /, if there is a leading /. { - hid_t gid = arma_H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + hid_t gid = H5Gcreate((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); if((gid < 0) && use_existing_file) { - gid = arma_H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); + gid = H5Gopen((groups.size() == 0) ? file : groups[groups.size() - 1], full_name.substr(0, loc).c_str(), H5P_DEFAULT); } groups.push_back(gid); @@ -3280,32 +3976,32 @@ diskio::save_hdf5_binary(const Cube& x, const hdf5_name& spec, std::string& if(use_existing_file && replace) { - arma_H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); + H5Ldelete(last_group, dataset_name.c_str(), H5P_DEFAULT); // NOTE: H5Ldelete() in HDF5 v1.8 doesn't reclaim the deleted space; use h5repack to reclaim space: h5repack oldfile.h5 newfile.h5 // NOTE: has this behaviour changed in HDF5 1.10 ? // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010482.html // NOTE: https://lists.hdfgroup.org/pipermail/hdf-forum_lists.hdfgroup.org/2017-August/010486.html } - hid_t dataset = arma_H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + hid_t dataset = H5Dcreate(last_group, dataset_name.c_str(), datatype, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); if(dataset < 0) { save_okay = false; - err_msg = "couldn't create dataset in "; + err_msg = "failed to create dataset"; } else { - save_okay = (arma_H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); + save_okay = (H5Dwrite(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, x.mem) >= 0); - arma_H5Dclose(dataset); + H5Dclose(dataset); } - arma_H5Tclose(datatype); - arma_H5Sclose(dataspace); - for (size_t i = 0; i < groups.size(); ++i) { arma_H5Gclose(groups[i]); } - arma_H5Fclose(file); + H5Tclose(datatype); + H5Sclose(dataspace); + for(size_t i = 0; i < groups.size(); ++i) { H5Gclose(groups[i]); } + H5Fclose(file); if((use_existing_file == false) && (save_okay == true)) { save_okay = diskio::safe_rename(tmp_name, spec.filename); } @@ -3342,7 +4038,7 @@ diskio::load_raw_ascii(Cube& x, const std::string& name, std::string& err_ms { if(tmp.is_empty() == false) { - x.set_size(tmp.n_rows, tmp.n_cols, 1); + try { x.set_size(tmp.n_rows, tmp.n_cols, 1); } catch(...) { err_msg = "not enough memory"; return false; } x.slice(0) = tmp; } @@ -3373,7 +4069,7 @@ diskio::load_raw_ascii(Cube& x, std::istream& f, std::string& err_msg) { if(tmp.is_empty() == false) { - x.set_size(tmp.n_rows, tmp.n_cols, 1); + try { x.set_size(tmp.n_rows, tmp.n_cols, 1); } catch(...) { err_msg = "not enough memory"; return false; } x.slice(0) = tmp; } @@ -3398,7 +4094,7 @@ diskio::load_raw_binary(Cube& x, const std::string& name, std::string& err_m arma_extra_debug_sigprint(); std::ifstream f; - f.open(name.c_str(), std::fstream::binary); + f.open(name, std::fstream::binary); bool load_okay = f.is_open(); @@ -3419,7 +4115,6 @@ bool diskio::load_raw_binary(Cube& x, std::istream& f, std::string& err_msg) { arma_extra_debug_sigprint(); - arma_ignore(err_msg); f.clear(); const std::streampos pos1 = f.tellg(); @@ -3436,7 +4131,7 @@ diskio::load_raw_binary(Cube& x, std::istream& f, std::string& err_msg) //f.seekg(0, ios::beg); f.seekg(pos1); - x.set_size(N / uword(sizeof(eT)), 1, 1); + try { x.set_size(N / uword(sizeof(eT)), 1, 1); } catch(...) { err_msg = "not enough memory"; return false; } f.clear(); f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem * uword(sizeof(eT))) ); @@ -3455,7 +4150,9 @@ diskio::load_arma_ascii(Cube& x, const std::string& name, std::string& err_m { arma_extra_debug_sigprint(); - std::ifstream f(name.c_str()); + std::ifstream f; + + (arma_config::text_as_binary) ? f.open(name, std::fstream::binary) : f.open(name); bool load_okay = f.is_open(); @@ -3480,7 +4177,7 @@ diskio::load_arma_ascii(Cube& x, std::istream& f, std::string& err_msg) arma_extra_debug_sigprint(); std::streampos pos = f.tellg(); - + bool load_okay = true; std::string f_header; @@ -3495,7 +4192,7 @@ diskio::load_arma_ascii(Cube& x, std::istream& f, std::string& err_msg) if(f_header == diskio::gen_txt_header(x)) { - x.set_size(f_n_rows, f_n_cols, f_n_slices); + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } for(uword slice = 0; slice < x.n_slices; ++slice) for(uword row = 0; row < x.n_rows; ++row ) @@ -3509,7 +4206,7 @@ diskio::load_arma_ascii(Cube& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "incorrect header in "; + err_msg = "incorrect header"; } @@ -3559,7 +4256,7 @@ diskio::load_arma_binary(Cube& x, const std::string& name, std::string& err_ arma_extra_debug_sigprint(); std::ifstream f; - f.open(name.c_str(), std::fstream::binary); + f.open(name, std::fstream::binary); bool load_okay = f.is_open(); @@ -3600,7 +4297,8 @@ diskio::load_arma_binary(Cube& x, std::istream& f, std::string& err_msg) //f.seekg(1, ios::cur); // NOTE: this may not be portable, as on a Windows machine a newline could be two characters f.get(); - x.set_size(f_n_rows, f_n_cols, f_n_slices); + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } + f.read( reinterpret_cast(x.memptr()), std::streamsize(x.n_elem*sizeof(eT)) ); load_okay = f.good(); @@ -3608,7 +4306,7 @@ diskio::load_arma_binary(Cube& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "incorrect header in "; + err_msg = "incorrect header"; } @@ -3655,15 +4353,17 @@ bool diskio::load_hdf5_binary(Cube& x, const hdf5_name& spec, std::string& err_msg) { arma_extra_debug_sigprint(); - + #if defined(ARMA_USE_HDF5) { + if(diskio::is_readable(spec.filename) == false) { return false; } + hdf5_misc::hdf5_suspend_printing_errors hdf5_print_suspender; - + bool load_okay = false; - - hid_t fid = arma_H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); - + + hid_t fid = H5Fopen(spec.filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + if(fid >= 0) { // MATLAB HDF5 dataset names are user-specified; @@ -3689,71 +4389,71 @@ diskio::load_hdf5_binary(Cube& x, const hdf5_name& spec, std::string& err_ms if(dataset >= 0) { - hid_t filespace = arma_H5Dget_space(dataset); - + hid_t filespace = H5Dget_space(dataset); + // This must be <= 3 due to our search rules. - const int ndims = arma_H5Sget_simple_extent_ndims(filespace); - + const int ndims = H5Sget_simple_extent_ndims(filespace); + hsize_t dims[3]; - const herr_t query_status = arma_H5Sget_simple_extent_dims(filespace, dims, NULL); - + const herr_t query_status = H5Sget_simple_extent_dims(filespace, dims, NULL); + // arma_check(query_status < 0, "Cube::load(): cannot get size of HDF5 dataset"); if(query_status < 0) { - err_msg = "cannot get size of HDF5 dataset in "; - - arma_H5Sclose(filespace); - arma_H5Dclose(dataset); - arma_H5Fclose(fid); - + err_msg = "cannot get size of HDF5 dataset"; + + H5Sclose(filespace); + H5Dclose(dataset); + H5Fclose(fid); + return false; } - - if (ndims == 1) { dims[1] = 1; dims[2] = 1; } // Vector case; one row/colum, several slices - if (ndims == 2) { dims[2] = 1; } // Matrix case; one column, several rows/slices - - x.set_size(dims[2], dims[1], dims[0]); - + + if(ndims == 1) { dims[1] = 1; dims[2] = 1; } // Vector case; one row/colum, several slices + if(ndims == 2) { dims[2] = 1; } // Matrix case; one column, several rows/slices + + try { x.set_size(dims[2], dims[1], dims[0]); } catch(...) { err_msg = "not enough memory"; return false; } + // Now we have to see what type is stored to figure out how to load it. - hid_t datatype = arma_H5Dget_type(dataset); + hid_t datatype = H5Dget_type(dataset); hid_t mat_type = hdf5_misc::get_hdf5_type(); - + // If these are the same type, it is simple. - if(arma_H5Tequal(datatype, mat_type) > 0) + if(H5Tequal(datatype, mat_type) > 0) { // Load directly; H5S_ALL used so that we load the entire dataset. - hid_t read_status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); - + hid_t read_status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(x.memptr())); + if(read_status >= 0) { load_okay = true; } } else { // Load into another array and convert its type accordingly. hid_t read_status = hdf5_misc::load_and_convert_hdf5(x.memptr(), dataset, datatype, x.n_elem); - + if(read_status >= 0) { load_okay = true; } } - + // Now clean up. - arma_H5Tclose(datatype); - arma_H5Tclose(mat_type); - arma_H5Sclose(filespace); + H5Tclose(datatype); + H5Tclose(mat_type); + H5Sclose(filespace); } - - arma_H5Dclose(dataset); - - arma_H5Fclose(fid); - + + H5Dclose(dataset); + + H5Fclose(fid); + if(load_okay == false) { - err_msg = "unsupported or missing HDF5 data in "; + err_msg = "unsupported or missing HDF5 data"; } } else { - err_msg = "cannot open file "; + err_msg = "cannot open"; } - + return load_okay; } #else @@ -3779,13 +4479,15 @@ diskio::load_auto_detect(Cube& x, const std::string& name, std::string& err_ { arma_extra_debug_sigprint(); + if(diskio::is_readable(name) == false) { return false; } + #if defined(ARMA_USE_HDF5) // We're currently using the C bindings for the HDF5 library, which don't support C++ streams - if( arma_H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } + if( H5Fis_hdf5(name.c_str()) ) { return load_hdf5_binary(x, name, err_msg); } #endif - + std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -3844,7 +4546,7 @@ diskio::load_auto_detect(Cube& x, std::istream& f, std::string& err_msg) } else { - const file_type ft = guess_file_type(f); + const file_type ft = guess_file_type_internal(f); switch(ft) { @@ -3861,7 +4563,7 @@ diskio::load_auto_detect(Cube& x, std::istream& f, std::string& err_msg) break; default: - err_msg = "unknown data in "; + err_msg = "unknown data"; return false; } } @@ -3886,7 +4588,7 @@ diskio::save_arma_binary(const field& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f( tmp_name.c_str(), std::fstream::binary ); + std::ofstream f( tmp_name, std::fstream::binary ); bool save_okay = f.is_open(); @@ -3949,7 +4651,7 @@ diskio::load_arma_binary(field& x, const std::string& name, std::string& err { arma_extra_debug_sigprint(); - std::ifstream f( name.c_str(), std::fstream::binary ); + std::ifstream f( name, std::fstream::binary ); bool load_okay = f.is_open(); @@ -3982,13 +4684,13 @@ diskio::load_arma_binary(field& x, std::istream& f, std::string& err_msg) { uword f_n_rows; uword f_n_cols; - + f >> f_n_rows; f >> f_n_cols; - x.set_size(f_n_rows, f_n_cols); + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } - f.get(); + f.get(); for(uword i=0; i& x, std::istream& f, std::string& err_msg) uword f_n_rows; uword f_n_cols; uword f_n_slices; - + f >> f_n_rows; f >> f_n_cols; f >> f_n_slices; - x.set_size(f_n_rows, f_n_cols, f_n_slices); + try { x.set_size(f_n_rows, f_n_cols, f_n_slices); } catch(...) { err_msg = "not enough memory"; return false; } - f.get(); + f.get(); for(uword i=0; i& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "unsupported field type in "; + err_msg = "unsupported field type"; } return load_okay; @@ -4038,7 +4740,7 @@ diskio::save_std_string(const field& x, const std::string& final_na const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f( tmp_name.c_str(), std::fstream::binary ); + std::ofstream f( tmp_name, std::fstream::binary ); bool save_okay = f.is_open(); @@ -4089,7 +4791,7 @@ diskio::load_std_string(field& x, const std::string& name, std::str { arma_extra_debug_sigprint(); - std::ifstream f( name.c_str() ); + std::ifstream f(name); bool load_okay = f.is_open(); @@ -4133,7 +4835,7 @@ diskio::load_std_string(field& x, std::istream& f, std::string& err uword line_n_cols = 0; - while (line_stream >> token) { line_n_cols++; } + while(line_stream >> token) { line_n_cols++; } if(f_n_cols_found == false) { @@ -4145,7 +4847,7 @@ diskio::load_std_string(field& x, std::istream& f, std::string& err if(line_n_cols != f_n_cols) { load_okay = false; - err_msg = "inconsistent number of columns in "; + err_msg = "inconsistent number of columns"; } } @@ -4158,8 +4860,8 @@ diskio::load_std_string(field& x, std::istream& f, std::string& err f.seekg(0, ios::beg); //f.seekg(start); - x.set_size(f_n_rows, f_n_cols); - + try { x.set_size(f_n_rows, f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + for(uword row=0; row < x.n_rows; ++row) for(uword col=0; col < x.n_cols; ++col) { @@ -4183,7 +4885,7 @@ diskio::load_auto_detect(field& x, const std::string& name, std::string& err arma_extra_debug_sigprint(); std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -4241,7 +4943,7 @@ diskio::load_auto_detect(field& x, std::istream& f, std::string& err_msg) } else { - err_msg = "unsupported header in "; + err_msg = "unsupported header"; return false; } } @@ -4260,7 +4962,7 @@ diskio::load_ppm_binary(Cube& x, const std::string& name, std::string& err_m arma_extra_debug_sigprint(); std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -4293,21 +4995,21 @@ diskio::load_ppm_binary(Cube& x, std::istream& f, std::string& err_msg) uword f_n_rows = 0; uword f_n_cols = 0; int f_maxval = 0; - + diskio::pnm_skip_comments(f); - + f >> f_n_cols; diskio::pnm_skip_comments(f); - + f >> f_n_rows; diskio::pnm_skip_comments(f); - + f >> f_maxval; f.get(); - if( (f_maxval > 0) || (f_maxval <= 65535) ) + if( (f_maxval > 0) && (f_maxval <= 65535) ) { - x.set_size(f_n_rows, f_n_cols, 3); + try { x.set_size(f_n_rows, f_n_cols, 3); } catch(...) { err_msg = "not enough memory"; return false; } if(f_maxval <= 255) { @@ -4352,15 +5054,15 @@ diskio::load_ppm_binary(Cube& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "currently no code available to handle loading "; + err_msg = "functionality unimplemented"; } - + if(f.good() == false) { load_okay = false; } } else { load_okay = false; - err_msg = "unsupported header in "; + err_msg = "unsupported header"; } return load_okay; @@ -4377,7 +5079,7 @@ diskio::save_ppm_binary(const Cube& x, const std::string& final_name) const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f( tmp_name.c_str(), std::fstream::binary ); + std::ofstream f( tmp_name, std::fstream::binary ); bool save_okay = f.is_open(); @@ -4446,7 +5148,7 @@ diskio::load_ppm_binary(field& x, const std::string& name, std::string& err_ arma_extra_debug_sigprint(); std::fstream f; - f.open(name.c_str(), std::fstream::in | std::fstream::binary); + f.open(name, std::fstream::in | std::fstream::binary); bool load_okay = f.is_open(); @@ -4482,28 +5184,28 @@ diskio::load_ppm_binary(field& x, std::istream& f, std::string& err_msg) uword f_n_rows = 0; uword f_n_cols = 0; int f_maxval = 0; - + diskio::pnm_skip_comments(f); - + f >> f_n_cols; diskio::pnm_skip_comments(f); - + f >> f_n_rows; diskio::pnm_skip_comments(f); - + f >> f_maxval; f.get(); - if( (f_maxval > 0) || (f_maxval <= 65535) ) + if( (f_maxval > 0) && (f_maxval <= 65535) ) { x.set_size(3); Mat& R = x(0); Mat& G = x(1); Mat& B = x(2); - R.set_size(f_n_rows,f_n_cols); - G.set_size(f_n_rows,f_n_cols); - B.set_size(f_n_rows,f_n_cols); + try { R.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + try { G.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } + try { B.set_size(f_n_rows,f_n_cols); } catch(...) { err_msg = "not enough memory"; return false; } if(f_maxval <= 255) { @@ -4552,7 +5254,7 @@ diskio::load_ppm_binary(field& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "currently no code available to handle loading "; + err_msg = "functionality unimplemented"; } if(f.good() == false) { load_okay = false; } @@ -4560,7 +5262,7 @@ diskio::load_ppm_binary(field& x, std::istream& f, std::string& err_msg) else { load_okay = false; - err_msg = "unsupported header in "; + err_msg = "unsupported header"; } return load_okay; @@ -4576,7 +5278,7 @@ diskio::save_ppm_binary(const field& x, const std::string& final_name) arma_extra_debug_sigprint(); const std::string tmp_name = diskio::gen_tmp_name(final_name); - std::ofstream f( tmp_name.c_str(), std::fstream::binary ); + std::ofstream f( tmp_name, std::fstream::binary ); bool save_okay = f.is_open(); diff --git a/src/armadillo_bits/distr_param.hpp b/src/armadillo_bits/distr_param.hpp index 7208469c..61f3c234 100644 --- a/src/armadillo_bits/distr_param.hpp +++ b/src/armadillo_bits/distr_param.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -24,41 +26,64 @@ class distr_param { public: - uword state; + const uword state; - union - { - int a_int; - double a_double; - }; + private: - union - { - int b_int; - double b_double; - }; + int a_int; + int b_int; + + double a_double; + double b_double; + public: inline distr_param() - : state(0) + : state (0) + , a_int (0) + , b_int (0) + , a_double(0) + , b_double(0) { } inline explicit distr_param(const int a, const int b) - : state(1) - , a_int(a) - , b_int(b) + : state (1) + , a_int (a) + , b_int (b) + , a_double(double(a)) + , b_double(double(b)) { } inline explicit distr_param(const double a, const double b) - : state(2) + : state (2) + , a_int (int(a)) + , b_int (int(b)) , a_double(a) , b_double(b) { } + + + inline void get_int_vals(int& out_a, int& out_b) const + { + if(state == 0) { return; } + + out_a = a_int; + out_b = b_int; + } + + + inline void get_double_vals(double& out_a, double& out_b) const + { + if(state == 0) { return; } + + out_a = a_double; + out_b = b_double; + } }; diff --git a/src/armadillo_bits/eGlueCube_bones.hpp b/src/armadillo_bits/eGlueCube_bones.hpp index c0aa4269..8c157d8f 100644 --- a/src/armadillo_bits/eGlueCube_bones.hpp +++ b/src/armadillo_bits/eGlueCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,16 +21,16 @@ template -class eGlueCube : public BaseCube > +class eGlueCube : public BaseCube< typename T1::elem_type, eGlueCube > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - static const bool use_at = (ProxyCube::use_at || ProxyCube::use_at ); - static const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp ); - static const bool has_subview = (ProxyCube::has_subview || ProxyCube::has_subview); + static constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at ); + static constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp ); + static constexpr bool has_subview = (ProxyCube::has_subview || ProxyCube::has_subview); arma_aligned const ProxyCube P1; arma_aligned const ProxyCube P2; diff --git a/src/armadillo_bits/eGlueCube_meat.hpp b/src/armadillo_bits/eGlueCube_meat.hpp index 312c8e19..59b30d39 100644 --- a/src/armadillo_bits/eGlueCube_meat.hpp +++ b/src/armadillo_bits/eGlueCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/eGlue_bones.hpp b/src/armadillo_bits/eGlue_bones.hpp index c4d96ec9..097dc6cb 100644 --- a/src/armadillo_bits/eGlue_bones.hpp +++ b/src/armadillo_bits/eGlue_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,7 @@ template -class eGlue : public Base > +class eGlue : public Base< typename T1::elem_type, eGlue > { public: @@ -28,14 +30,13 @@ class eGlue : public Base > typedef Proxy proxy1_type; typedef Proxy proxy2_type; - static const bool use_at = (Proxy::use_at || Proxy::use_at ); - static const bool use_mp = (Proxy::use_mp || Proxy::use_mp ); - static const bool has_subview = (Proxy::has_subview || Proxy::has_subview); - static const bool fake_mat = (Proxy::fake_mat || Proxy::fake_mat ); + static constexpr bool use_at = (Proxy::use_at || Proxy::use_at ); + static constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp ); + static constexpr bool has_subview = (Proxy::has_subview || Proxy::has_subview); - static const bool is_col = (Proxy::is_col || Proxy::is_col ); - static const bool is_row = (Proxy::is_row || Proxy::is_row ); - static const bool is_xvec = (Proxy::is_xvec || Proxy::is_xvec); + static constexpr bool is_col = (Proxy::is_col || Proxy::is_col ); + static constexpr bool is_row = (Proxy::is_row || Proxy::is_row ); + static constexpr bool is_xvec = (Proxy::is_xvec || Proxy::is_xvec); arma_aligned const Proxy P1; arma_aligned const Proxy P2; diff --git a/src/armadillo_bits/eGlue_meat.hpp b/src/armadillo_bits/eGlue_meat.hpp index 0e88489c..30fb5072 100644 --- a/src/armadillo_bits/eGlue_meat.hpp +++ b/src/armadillo_bits/eGlue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/eOpCube_bones.hpp b/src/armadillo_bits/eOpCube_bones.hpp index 8639a042..b6bcaba4 100644 --- a/src/armadillo_bits/eOpCube_bones.hpp +++ b/src/armadillo_bits/eOpCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class eOpCube : public BaseCube > +class eOpCube : public BaseCube< typename T1::elem_type, eOpCube > { public: typedef typename T1::elem_type elem_type; typedef typename get_pod_type::result pod_type; - static const bool use_at = ProxyCube::use_at; - static const bool use_mp = ProxyCube::use_mp || eop_type::use_mp; - static const bool has_subview = ProxyCube::has_subview; + static constexpr bool use_at = ProxyCube::use_at; + static constexpr bool use_mp = ProxyCube::use_mp || eop_type::use_mp; + static constexpr bool has_subview = ProxyCube::has_subview; arma_aligned const ProxyCube P; arma_aligned elem_type aux; //!< storage of auxiliary data, user defined format diff --git a/src/armadillo_bits/eOpCube_meat.hpp b/src/armadillo_bits/eOpCube_meat.hpp index efabbef7..6a165f44 100644 --- a/src/armadillo_bits/eOpCube_meat.hpp +++ b/src/armadillo_bits/eOpCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/eOp_bones.hpp b/src/armadillo_bits/eOp_bones.hpp index 602a3bb2..d32abddb 100644 --- a/src/armadillo_bits/eOp_bones.hpp +++ b/src/armadillo_bits/eOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,7 @@ template -class eOp : public Base > +class eOp : public Base< typename T1::elem_type, eOp > { public: @@ -28,14 +30,13 @@ class eOp : public Base > typedef typename get_pod_type::result pod_type; typedef Proxy proxy_type; - static const bool use_at = Proxy::use_at; - static const bool use_mp = Proxy::use_mp || eop_type::use_mp; - static const bool has_subview = Proxy::has_subview; - static const bool fake_mat = Proxy::fake_mat; + static constexpr bool use_at = Proxy::use_at; + static constexpr bool use_mp = Proxy::use_mp || eop_type::use_mp; + static constexpr bool has_subview = Proxy::has_subview; - static const bool is_row = Proxy::is_row; - static const bool is_col = Proxy::is_col; - static const bool is_xvec = Proxy::is_xvec; + static constexpr bool is_row = Proxy::is_row; + static constexpr bool is_col = Proxy::is_col; + static constexpr bool is_xvec = Proxy::is_xvec; arma_aligned const Proxy P; diff --git a/src/armadillo_bits/eOp_meat.hpp b/src/armadillo_bits/eOp_meat.hpp index f5314d61..e087505c 100644 --- a/src/armadillo_bits/eOp_meat.hpp +++ b/src/armadillo_bits/eOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/eglue_core_bones.hpp b/src/armadillo_bits/eglue_core_bones.hpp index ad826bd5..67db2438 100644 --- a/src/armadillo_bits/eglue_core_bones.hpp +++ b/src/armadillo_bits/eglue_core_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/eglue_core_meat.hpp b/src/armadillo_bits/eglue_core_meat.hpp index 9de342e4..a36f3895 100644 --- a/src/armadillo_bits/eglue_core_meat.hpp +++ b/src/armadillo_bits/eglue_core_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -177,7 +179,7 @@ -#if (defined(ARMA_USE_OPENMP) && defined(ARMA_USE_CXX11)) +#if defined(ARMA_USE_OPENMP) #define arma_applier_1_mp(operatorA, operatorB) \ {\ @@ -253,7 +255,6 @@ template template -arma_hot inline void eglue_core::apply(outT& out, const eGlue& x) @@ -262,8 +263,8 @@ eglue_core::apply(outT& out, const eGlue& x) typedef typename T1::elem_type eT; - const bool use_at = (Proxy::use_at || Proxy::use_at); - const bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); // NOTE: we're assuming that the matrix has already been set to the correct size and there is no aliasing; // size setting and alias checking is done by either the Mat contructor or operator=() @@ -353,7 +354,6 @@ eglue_core::apply(outT& out, const eGlue& x) template template -arma_hot inline void eglue_core::apply_inplace_plus(Mat& out, const eGlue& x) @@ -369,8 +369,8 @@ eglue_core::apply_inplace_plus(Mat& out, con eT* out_mem = out.memptr(); - const bool use_at = (Proxy::use_at || Proxy::use_at); - const bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -451,7 +451,6 @@ eglue_core::apply_inplace_plus(Mat& out, con template template -arma_hot inline void eglue_core::apply_inplace_minus(Mat& out, const eGlue& x) @@ -467,8 +466,8 @@ eglue_core::apply_inplace_minus(Mat& out, co eT* out_mem = out.memptr(); - const bool use_at = (Proxy::use_at || Proxy::use_at); - const bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -549,7 +548,6 @@ eglue_core::apply_inplace_minus(Mat& out, co template template -arma_hot inline void eglue_core::apply_inplace_schur(Mat& out, const eGlue& x) @@ -565,8 +563,8 @@ eglue_core::apply_inplace_schur(Mat& out, co eT* out_mem = out.memptr(); - const bool use_at = (Proxy::use_at || Proxy::use_at); - const bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -647,7 +645,6 @@ eglue_core::apply_inplace_schur(Mat& out, co template template -arma_hot inline void eglue_core::apply_inplace_div(Mat& out, const eGlue& x) @@ -663,8 +660,8 @@ eglue_core::apply_inplace_div(Mat& out, cons eT* out_mem = out.memptr(); - const bool use_at = (Proxy::use_at || Proxy::use_at); - const bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (Proxy::use_at || Proxy::use_at); + constexpr bool use_mp = (Proxy::use_mp || Proxy::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -750,7 +747,6 @@ eglue_core::apply_inplace_div(Mat& out, cons template template -arma_hot inline void eglue_core::apply(Cube& out, const eGlueCube& x) @@ -759,8 +755,8 @@ eglue_core::apply(Cube& out, const eGlueCube typedef typename T1::elem_type eT; - const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); - const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); // NOTE: we're assuming that the cube has already been set to the correct size and there is no aliasing; // size setting and alias checking is done by either the Cube contructor or operator=() @@ -851,7 +847,6 @@ eglue_core::apply(Cube& out, const eGlueCube template template -arma_hot inline void eglue_core::apply_inplace_plus(Cube& out, const eGlueCube& x) @@ -868,8 +863,8 @@ eglue_core::apply_inplace_plus(Cube& out, co eT* out_mem = out.memptr(); - const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); - const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -950,7 +945,6 @@ eglue_core::apply_inplace_plus(Cube& out, co template template -arma_hot inline void eglue_core::apply_inplace_minus(Cube& out, const eGlueCube& x) @@ -967,8 +961,8 @@ eglue_core::apply_inplace_minus(Cube& out, c eT* out_mem = out.memptr(); - const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); - const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -1049,7 +1043,6 @@ eglue_core::apply_inplace_minus(Cube& out, c template template -arma_hot inline void eglue_core::apply_inplace_schur(Cube& out, const eGlueCube& x) @@ -1066,8 +1059,8 @@ eglue_core::apply_inplace_schur(Cube& out, c eT* out_mem = out.memptr(); - const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); - const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); if(use_at == false) { @@ -1148,7 +1141,6 @@ eglue_core::apply_inplace_schur(Cube& out, c template template -arma_hot inline void eglue_core::apply_inplace_div(Cube& out, const eGlueCube& x) @@ -1165,8 +1157,8 @@ eglue_core::apply_inplace_div(Cube& out, con eT* out_mem = out.memptr(); - const bool use_at = (ProxyCube::use_at || ProxyCube::use_at); - const bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::cxx11 && arma_config::openmp); + constexpr bool use_at = (ProxyCube::use_at || ProxyCube::use_at); + constexpr bool use_mp = (ProxyCube::use_mp || ProxyCube::use_mp) && (arma_config::openmp); if(use_at == false) { diff --git a/src/armadillo_bits/eop_aux.hpp b/src/armadillo_bits/eop_aux.hpp index c95a6f6e..2b66ef2a 100644 --- a/src/armadillo_bits/eop_aux.hpp +++ b/src/armadillo_bits/eop_aux.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -34,17 +36,17 @@ class eop_aux template arma_inline static typename arma_real_only::result asin (const eT x) { return std::asin(x); } template arma_inline static typename arma_real_only::result atan (const eT x) { return std::atan(x); } - template arma_inline static typename arma_cx_only::result acos (const eT x) { return arma_acos(x); } - template arma_inline static typename arma_cx_only::result asin (const eT x) { return arma_asin(x); } - template arma_inline static typename arma_cx_only::result atan (const eT x) { return arma_atan(x); } + template arma_inline static typename arma_cx_only::result acos (const eT x) { return std::acos(x); } + template arma_inline static typename arma_cx_only::result asin (const eT x) { return std::asin(x); } + template arma_inline static typename arma_cx_only::result atan (const eT x) { return std::atan(x); } - template arma_inline static typename arma_integral_only::result acosh (const eT x) { return eT( arma_acosh(double(x)) ); } - template arma_inline static typename arma_integral_only::result asinh (const eT x) { return eT( arma_asinh(double(x)) ); } - template arma_inline static typename arma_integral_only::result atanh (const eT x) { return eT( arma_atanh(double(x)) ); } + template arma_inline static typename arma_integral_only::result acosh (const eT x) { return eT( std::acosh(double(x)) ); } + template arma_inline static typename arma_integral_only::result asinh (const eT x) { return eT( std::asinh(double(x)) ); } + template arma_inline static typename arma_integral_only::result atanh (const eT x) { return eT( std::atanh(double(x)) ); } - template arma_inline static typename arma_real_or_cx_only::result acosh (const eT x) { return arma_acosh(x); } - template arma_inline static typename arma_real_or_cx_only::result asinh (const eT x) { return arma_asinh(x); } - template arma_inline static typename arma_real_or_cx_only::result atanh (const eT x) { return arma_atanh(x); } + template arma_inline static typename arma_real_or_cx_only::result acosh (const eT x) { return std::acosh(x); } + template arma_inline static typename arma_real_or_cx_only::result asinh (const eT x) { return std::asinh(x); } + template arma_inline static typename arma_real_or_cx_only::result atanh (const eT x) { return std::atanh(x); } template arma_inline static typename arma_not_cx::result conj(const eT x) { return x; } template arma_inline static std::complex conj(const std::complex& x) { return std::conj(x); } @@ -82,123 +84,56 @@ class eop_aux template arma_inline static typename arma_real_only::result ceil (const eT x) { return std::ceil(x); } template arma_inline static typename arma_cx_only::result ceil (const eT& x) { return eT( std::ceil(x.real()), std::ceil(x.imag()) ); } - - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result round (const eT x) { return x; } template arma_inline static typename arma_real_only::result round (const eT x) { return std::round(x); } template arma_inline static typename arma_cx_only::result round (const eT& x) { return eT( std::round(x.real()), std::round(x.imag()) ); } - #else - template arma_inline static typename arma_integral_only::result round (const eT x) { return x; } - template arma_inline static typename arma_real_only::result round (const eT x) { return (x >= eT(0)) ? std::floor(x+0.5) : std::ceil(x-0.5); } - template arma_inline static typename arma_cx_only::result round (const eT& x) { return eT( eop_aux::round(x.real()), eop_aux::round(x.imag()) ); } - #endif - - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result trunc (const eT x) { return x; } template arma_inline static typename arma_real_only::result trunc (const eT x) { return std::trunc(x); } template arma_inline static typename arma_cx_only::result trunc (const eT& x) { return eT( std::trunc(x.real()), std::trunc(x.imag()) ); } - #else - template arma_inline static typename arma_integral_only::result trunc (const eT x) { return x; } - template arma_inline static typename arma_real_only::result trunc (const eT x) { return (x >= eT(0)) ? std::floor(x) : std::ceil(x); } - template arma_inline static typename arma_cx_only::result trunc (const eT& x) { return eT( eop_aux::trunc(x.real()), eop_aux::trunc(x.imag()) ); } - #endif - - #if defined(ARMA_USE_CXX11) - template arma_inline static typename arma_integral_only::result log2 (const eT x) { return eT( std::log(double(x))/ double(0.69314718055994530942) ); } + template arma_inline static typename arma_integral_only::result log2 (const eT x) { return eT( std::log2(double(x)) ); } template arma_inline static typename arma_real_only::result log2 (const eT x) { return std::log2(x); } template arma_inline static typename arma_cx_only::result log2 (const eT& x) { typedef typename get_pod_type::result T; return std::log(x) / T(0.69314718055994530942); } - #else - template arma_inline static typename arma_integral_only::result log2 (const eT x) { return eT( std::log(double(x))/ double(0.69314718055994530942) ); } - template arma_inline static typename arma_real_or_cx_only::result log2 (const eT x) { typedef typename get_pod_type::result T; return std::log(x) / T(0.69314718055994530942); } - #endif - - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result log1p (const eT x) { return eT( std::log1p(double(x)) ); } template arma_inline static typename arma_real_only::result log1p (const eT x) { return std::log1p(x); } template arma_inline static typename arma_cx_only::result log1p (const eT& x) { arma_ignore(x); return eT(0); } - #elif defined(ARMA_HAVE_TR1) - template arma_inline static typename arma_integral_only::result log1p (const eT x) { return eT( std::tr1::log1p(double(x)) ); } - template arma_inline static typename arma_real_only::result log1p (const eT x) { return std::tr1::log1p(x); } - template arma_inline static typename arma_cx_only::result log1p (const eT& x) { arma_ignore(x); return eT(0); } - #else - template arma_inline static eT log1p (const eT x) { arma_ignore(x); arma_stop_logic_error("log1p(): C++11 compiler required"); return eT(0); } - #endif - - #if defined(ARMA_USE_CXX11) - template arma_inline static typename arma_integral_only::result exp2 (const eT x) { return eT( std::pow(double(2), double(x)) ); } + template arma_inline static typename arma_integral_only::result exp2 (const eT x) { return eT( std::exp2(double(x)) ); } template arma_inline static typename arma_real_only::result exp2 (const eT x) { return std::exp2(x); } template arma_inline static typename arma_cx_only::result exp2 (const eT& x) { typedef typename get_pod_type::result T; return std::pow( T(2), x); } - #else - template arma_inline static typename arma_integral_only::result exp2 (const eT x) { return eT( std::pow(double(2), double(x)) ); } - template arma_inline static typename arma_real_or_cx_only::result exp2 (const eT x) { typedef typename get_pod_type::result T; return std::pow( T(2), x); } - #endif - template arma_inline static typename arma_integral_only::result exp10 (const eT x) { return eT( std::pow(double(10), double(x)) ); } template arma_inline static typename arma_real_or_cx_only::result exp10 (const eT x) { typedef typename get_pod_type::result T; return std::pow( T(10), x); } - - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result expm1 (const eT x) { return eT( std::expm1(double(x)) ); } template arma_inline static typename arma_real_only::result expm1 (const eT x) { return std::expm1(x); } template arma_inline static typename arma_cx_only::result expm1 (const eT& x) { arma_ignore(x); return eT(0); } - #elif defined(ARMA_HAVE_TR1) - template arma_inline static typename arma_integral_only::result expm1 (const eT x) { return eT( std::tr1::expm1(double(x)) ); } - template arma_inline static typename arma_real_only::result expm1 (const eT x) { return std::tr1::expm1(x); } - template arma_inline static typename arma_cx_only::result expm1 (const eT& x) { arma_ignore(x); return eT(0); } - #else - template arma_inline static eT expm1 (const eT x) { arma_ignore(x); arma_stop_logic_error("expm1(): C++11 compiler required"); return eT(0); } - #endif - template arma_inline static typename arma_unsigned_integral_only::result arma_abs (const eT x) { return x; } template arma_inline static typename arma_signed_integral_only::result arma_abs (const eT x) { return std::abs(x); } template arma_inline static typename arma_real_only::result arma_abs (const eT x) { return std::abs(x); } template arma_inline static typename arma_real_only< T>::result arma_abs (const std::complex& x) { return std::abs(x); } - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result erf (const eT x) { return eT( std::erf(double(x)) ); } template arma_inline static typename arma_real_only::result erf (const eT x) { return std::erf(x); } template arma_inline static typename arma_cx_only::result erf (const eT& x) { arma_ignore(x); return eT(0); } - #elif defined(ARMA_HAVE_TR1) - template arma_inline static typename arma_integral_only::result erf (const eT x) { return eT( std::tr1::erf(double(x)) ); } - template arma_inline static typename arma_real_only::result erf (const eT x) { return std::tr1::erf(x); } - template arma_inline static typename arma_cx_only::result erf (const eT& x) { arma_ignore(x); return eT(0); } - #else - template arma_inline static eT erf (const eT x) { arma_ignore(x); arma_stop_logic_error("erf(): C++11 compiler required"); return eT(0); } - #endif - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result erfc (const eT x) { return eT( std::erfc(double(x)) ); } template arma_inline static typename arma_real_only::result erfc (const eT x) { return std::erfc(x); } template arma_inline static typename arma_cx_only::result erfc (const eT& x) { arma_ignore(x); return eT(0); } - #elif defined(ARMA_HAVE_TR1) - template arma_inline static typename arma_integral_only::result erfc (const eT x) { return eT( std::tr1::erfc(double(x)) ); } - template arma_inline static typename arma_real_only::result erfc (const eT x) { return std::tr1::erfc(x); } - template arma_inline static typename arma_cx_only::result erfc (const eT& x) { arma_ignore(x); return eT(0); } - #else - template arma_inline static eT erfc (const eT x) { arma_ignore(x); arma_stop_logic_error("erfc(): C++11 compiler required"); return eT(0); } - #endif - #if defined(ARMA_USE_CXX11) template arma_inline static typename arma_integral_only::result lgamma (const eT x) { return eT( std::lgamma(double(x)) ); } template arma_inline static typename arma_real_only::result lgamma (const eT x) { return std::lgamma(x); } template arma_inline static typename arma_cx_only::result lgamma (const eT& x) { arma_ignore(x); return eT(0); } - #elif defined(ARMA_HAVE_TR1) - template arma_inline static typename arma_integral_only::result lgamma (const eT x) { return eT( std::tr1::lgamma(double(x)) ); } - template arma_inline static typename arma_real_only::result lgamma (const eT x) { return std::tr1::lgamma(x); } - template arma_inline static typename arma_cx_only::result lgamma (const eT& x) { arma_ignore(x); return eT(0); } - #else - template arma_inline static eT lgamma (const eT x) { arma_ignore(x); arma_stop_logic_error("lgamma(): C++11 compiler required"); return eT(0); } - #endif + template arma_inline static typename arma_integral_only::result tgamma (const eT x) { return eT( std::tgamma(double(x)) ); } + template arma_inline static typename arma_real_only::result tgamma (const eT x) { return std::tgamma(x); } + template arma_inline static typename arma_cx_only::result tgamma (const eT& x) { arma_ignore(x); return eT(0); } template arma_inline static typename arma_integral_only::result pow (const T1 base, const T2 exponent) { return T1( std::pow( double(base), double(exponent) ) ); } - template arma_inline static typename arma_real_or_cx_only::result pow (const T1 base, const T2 exponent) { return std::pow(base, exponent); } + template arma_inline static typename arma_real_or_cx_only::result pow (const T1 base, const T2 exponent) { return T1( std::pow( base, exponent ) ); } template diff --git a/src/armadillo_bits/eop_core_bones.hpp b/src/armadillo_bits/eop_core_bones.hpp index 1a36d918..8b9c75cc 100644 --- a/src/armadillo_bits/eop_core_bones.hpp +++ b/src/armadillo_bits/eop_core_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -46,12 +48,12 @@ class eop_core // common - template arma_hot arma_inline static eT process(const eT val, const eT k); + template arma_inline static eT process(const eT val, const eT k); }; -struct eop_use_mp_true { static const bool use_mp = true; }; -struct eop_use_mp_false { static const bool use_mp = false; }; +struct eop_use_mp_true { static constexpr bool use_mp = true; }; +struct eop_use_mp_false { static constexpr bool use_mp = false; }; class eop_neg : public eop_core , public eop_use_mp_false {}; @@ -99,6 +101,7 @@ class eop_sign : public eop_core , public eo class eop_erf : public eop_core , public eop_use_mp_true {}; class eop_erfc : public eop_core , public eop_use_mp_true {}; class eop_lgamma : public eop_core , public eop_use_mp_true {}; +class eop_tgamma : public eop_core , public eop_use_mp_true {}; diff --git a/src/armadillo_bits/eop_core_meat.hpp b/src/armadillo_bits/eop_core_meat.hpp index cfad1749..4bc0c7fe 100644 --- a/src/armadillo_bits/eop_core_meat.hpp +++ b/src/armadillo_bits/eop_core_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -163,7 +165,7 @@ -#if (defined(ARMA_USE_OPENMP) && defined(ARMA_USE_CXX11)) +#if defined(ARMA_USE_OPENMP) #define arma_applier_1_mp(operatorA) \ {\ @@ -239,7 +241,6 @@ template template -arma_hot inline void eop_core::apply(outT& out, const eOp& x) @@ -254,7 +255,7 @@ eop_core::apply(outT& out, const eOp& x) const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(Proxy::use_at == false) { @@ -315,7 +316,6 @@ eop_core::apply(outT& out, const eOp& x) template template -arma_hot inline void eop_core::apply_inplace_plus(Mat& out, const eOp& x) @@ -332,7 +332,7 @@ eop_core::apply_inplace_plus(Mat& out, const e const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(Proxy::use_at == false) { @@ -390,7 +390,7 @@ eop_core::apply_inplace_plus(Mat& out, const e template template -arma_hot + inline void eop_core::apply_inplace_minus(Mat& out, const eOp& x) @@ -407,7 +407,7 @@ eop_core::apply_inplace_minus(Mat& out, const const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(Proxy::use_at == false) { @@ -465,7 +465,7 @@ eop_core::apply_inplace_minus(Mat& out, const template template -arma_hot + inline void eop_core::apply_inplace_schur(Mat& out, const eOp& x) @@ -482,7 +482,7 @@ eop_core::apply_inplace_schur(Mat& out, const const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(Proxy::use_at == false) { @@ -540,7 +540,7 @@ eop_core::apply_inplace_schur(Mat& out, const template template -arma_hot + inline void eop_core::apply_inplace_div(Mat& out, const eOp& x) @@ -557,7 +557,7 @@ eop_core::apply_inplace_div(Mat& out, const eO const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOp::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(Proxy::use_at == false) { @@ -620,7 +620,7 @@ eop_core::apply_inplace_div(Mat& out, const eO template template -arma_hot + inline void eop_core::apply(Cube& out, const eOpCube& x) @@ -635,7 +635,7 @@ eop_core::apply(Cube& out, const eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(ProxyCube::use_at == false) { @@ -697,7 +697,7 @@ eop_core::apply(Cube& out, const eOpCube template -arma_hot + inline void eop_core::apply_inplace_plus(Cube& out, const eOpCube& x) @@ -715,7 +715,7 @@ eop_core::apply_inplace_plus(Cube& out, const const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(ProxyCube::use_at == false) { @@ -773,7 +773,7 @@ eop_core::apply_inplace_plus(Cube& out, const template template -arma_hot + inline void eop_core::apply_inplace_minus(Cube& out, const eOpCube& x) @@ -791,7 +791,7 @@ eop_core::apply_inplace_minus(Cube& out, const const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(ProxyCube::use_at == false) { @@ -849,7 +849,7 @@ eop_core::apply_inplace_minus(Cube& out, const template template -arma_hot + inline void eop_core::apply_inplace_schur(Cube& out, const eOpCube& x) @@ -867,7 +867,7 @@ eop_core::apply_inplace_schur(Cube& out, const const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(ProxyCube::use_at == false) { @@ -925,7 +925,7 @@ eop_core::apply_inplace_schur(Cube& out, const template template -arma_hot + inline void eop_core::apply_inplace_div(Cube& out, const eOpCube& x) @@ -943,7 +943,7 @@ eop_core::apply_inplace_div(Cube& out, const e const eT k = x.aux; eT* out_mem = out.memptr(); - const bool use_mp = (arma_config::cxx11 && arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); + const bool use_mp = (arma_config::openmp) && (eOpCube::use_mp || (is_same_type::value && (is_cx::yes || x.aux != eT(2)))); if(ProxyCube::use_at == false) { @@ -1006,7 +1006,6 @@ eop_core::apply_inplace_div(Cube& out, const e template template -arma_hot arma_inline eT eop_core::process(const eT, const eT) @@ -1017,141 +1016,144 @@ eop_core::process(const eT, const eT) -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return val + k; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return k - val; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return val - k; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return val * k; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return k / val; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return val / k; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return val*val; } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::neg(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::sqrt(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::log(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::log2(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::log10(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return arma::trunc_log(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::log1p(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::exp(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::exp2(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::exp10(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return arma::trunc_exp(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::expm1(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::cos(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::sin(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::tan(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::acos(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::asin(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::atan(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::cosh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::sinh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::tanh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::acosh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::asinh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::atanh(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return arma_sinc(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::direct_eps(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::arma_abs(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return arma_arg::eval(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::conj(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT k) { return eop_aux::pow(val, k); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::floor(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::ceil(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::round(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::trunc(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return arma_sign(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::erf(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::erfc(val); } -template<> template arma_hot arma_inline eT +template<> template arma_inline eT eop_core::process(const eT val, const eT ) { return eop_aux::lgamma(val); } +template<> template arma_inline eT +eop_core::process(const eT val, const eT ) { return eop_aux::tgamma(val); } + #undef arma_applier_1u #undef arma_applier_1a diff --git a/src/armadillo_bits/fft_engine_fftw3.hpp b/src/armadillo_bits/fft_engine_fftw3.hpp new file mode 100644 index 00000000..cabe5c96 --- /dev/null +++ b/src/armadillo_bits/fft_engine_fftw3.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ------------------------------------------------------------------------ + + +//! \addtogroup fft_engine_fftw3 +//! @{ + + +#if defined(ARMA_USE_FFTW3) + +template +class fft_engine_fftw3 + { + public: + + constexpr static int fftw3_sign_forward = -1; + constexpr static int fftw3_sign_backward = +1; + + constexpr static unsigned int fftw3_flag_destroy = (1u << 0); + constexpr static unsigned int fftw3_flag_preserve = (1u << 4); + constexpr static unsigned int fftw3_flag_estimate = (1u << 6); + + const uword N; + + void_ptr fftw3_plan; + + podarray X_work; // for storing copy of input (can be overwritten by FFTW3) + podarray Y_work; // for storing output + + inline + ~fft_engine_fftw3() + { + arma_extra_debug_sigprint(); + + if(fftw3_plan != nullptr) { fftw3::destroy_plan(fftw3_plan); } + + // fftw3::cleanup(); // NOTE: this also removes any wisdom acquired by FFTW3 + } + + inline + fft_engine_fftw3(const uword in_N) + : N (in_N ) + , fftw3_plan(nullptr) + { + arma_extra_debug_sigprint(); + + if(N == 0) { return; } + + if(N > uword(std::numeric_limits::max())) + { + arma_stop_runtime_error("integer overflow: FFT size too large for integer type used by FFTW3"); + } + + arma_extra_debug_print("fft_engine_fftw3::constructor: allocating work arrays"); + X_work.set_size(N); + Y_work.set_size(N); + + const int fftw3_sign = (inverse) ? fftw3_sign_backward : fftw3_sign_forward; + const int fftw3_flags = fftw3_flag_destroy | fftw3_flag_estimate; + + arma_extra_debug_print("fft_engine_fftw3::constructor: generating 1D plan"); + fftw3_plan = fftw3::plan_dft_1d(N, X_work.memptr(), Y_work.memptr(), fftw3_sign, fftw3_flags); + + if(fftw3_plan == nullptr) { arma_stop_runtime_error("fft_engine_fftw3::constructor: failed to create plan"); } + } + + inline + void + run(cx_type* Y, const cx_type* X) + { + arma_extra_debug_sigprint(); + + if(fftw3_plan == nullptr) { return; } + + arma_extra_debug_print("fft_engine_fftw3::run(): copying input array"); + arrayops::copy(X_work.memptr(), X, N); + + arma_extra_debug_print("fft_engine_fftw3::run(): executing plan"); + fftw3::execute(fftw3_plan); + + arma_extra_debug_print("fft_engine_fftw3::run(): copying output array"); + arrayops::copy(Y, Y_work.memptr(), N); + } + }; + +#endif + + +//! @} diff --git a/src/armadillo_bits/fft_engine.hpp b/src/armadillo_bits/fft_engine_kissfft.hpp similarity index 82% rename from src/armadillo_bits/fft_engine.hpp rename to src/armadillo_bits/fft_engine_kissfft.hpp index 56359e27..0c8c40c0 100644 --- a/src/armadillo_bits/fft_engine.hpp +++ b/src/armadillo_bits/fft_engine_kissfft.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,7 +20,6 @@ // licensed under the following conditions. // // Copyright (c) 2003-2010 Mark Borgerding -// // All rights reserved. // // Redistribution and use in source and binary forms, with or without modification, @@ -50,59 +51,25 @@ // ------------------------------------------------------------------------ -//! \addtogroup fft_engine +//! \addtogroup fft_engine_kissfft //! @{ -template struct fft_store {}; - -template -struct fft_store +template +class fft_engine_kissfft { - static const uword N = fixed_N; - - arma_aligned cx_type coeffs_array[fixed_N]; + public: - inline fft_store() {} - inline fft_store(uword) {} + typedef typename get_pod_type::result T; - arma_inline cx_type* coeffs_ptr() { return &coeffs_array[0]; } - arma_inline const cx_type* coeffs_ptr() const { return &coeffs_array[0]; } - }; - - - -template -struct fft_store - { const uword N; podarray coeffs_array; - - inline fft_store() : N(0) {} - inline fft_store(uword in_N) : N(in_N) { coeffs_array.set_size(N); } - - arma_inline cx_type* coeffs_ptr() { return coeffs_array.memptr(); } - arma_inline const cx_type* coeffs_ptr() const { return coeffs_array.memptr(); } - }; - - - -template -class fft_engine : public fft_store 0)> - { - public: - - typedef typename get_pod_type::result T; - - using fft_store 0)>::N; - using fft_store 0)>::coeffs_ptr; + podarray tmp_array; podarray residue; podarray radix; - podarray tmp_array; - template inline @@ -140,8 +107,8 @@ class fft_engine : public fft_store 0)> inline - fft_engine(const uword in_N) - : fft_store< cx_type, fixed_N, (fixed_N > 0) >(in_N) + fft_engine_kissfft(const uword in_N) + : N(in_N) { arma_extra_debug_sigprint(); @@ -155,7 +122,9 @@ class fft_engine : public fft_store 0)> // calculate the constant coefficients - cx_type* coeffs = coeffs_ptr(); + coeffs_array.set_size(N); + + cx_type* coeffs = coeffs_array.memptr(); const T k = T( (inverse) ? +2 : -2 ) * std::acos( T(-1) ) / T(N); @@ -167,11 +136,11 @@ class fft_engine : public fft_store 0)> arma_hot inline void - butterfly_2(cx_type* Y, const uword stride, const uword m) + butterfly_2(cx_type* Y, const uword stride, const uword m) const { - arma_extra_debug_sigprint(); + // arma_extra_debug_sigprint(); - const cx_type* coeffs = coeffs_ptr(); + const cx_type* coeffs = coeffs_array.memptr(); for(uword i=0; i < m; ++i) { @@ -187,14 +156,14 @@ class fft_engine : public fft_store 0)> arma_hot inline void - butterfly_3(cx_type* Y, const uword stride, const uword m) + butterfly_3(cx_type* Y, const uword stride, const uword m) const { - arma_extra_debug_sigprint(); + // arma_extra_debug_sigprint(); arma_aligned cx_type tmp[5]; - cx_type* coeffs1 = coeffs_ptr(); - cx_type* coeffs2 = coeffs1; + const cx_type* coeffs1 = coeffs_array.memptr(); + const cx_type* coeffs2 = coeffs1; const T coeff_sm_imag = coeffs1[stride*m].imag(); @@ -233,13 +202,13 @@ class fft_engine : public fft_store 0)> arma_hot inline void - butterfly_4(cx_type* Y, const uword stride, const uword m) + butterfly_4(cx_type* Y, const uword stride, const uword m) const { - arma_extra_debug_sigprint(); + // arma_extra_debug_sigprint(); arma_aligned cx_type tmp[7]; - const cx_type* coeffs = coeffs_ptr(); + const cx_type* coeffs = coeffs_array.memptr(); const uword m2 = m*2; const uword m3 = m*3; @@ -273,16 +242,16 @@ class fft_engine : public fft_store 0)> - inline arma_hot + inline void - butterfly_5(cx_type* Y, const uword stride, const uword m) + butterfly_5(cx_type* Y, const uword stride, const uword m) const { - arma_extra_debug_sigprint(); + // arma_extra_debug_sigprint(); arma_aligned cx_type tmp[13]; - const cx_type* coeffs = coeffs_ptr(); + const cx_type* coeffs = coeffs_array.memptr(); const T a_real = coeffs[stride*1*m].real(); const T a_imag = coeffs[stride*1*m].imag(); @@ -342,9 +311,9 @@ class fft_engine : public fft_store 0)> void butterfly_N(cx_type* Y, const uword stride, const uword m, const uword r) { - arma_extra_debug_sigprint(); + // arma_extra_debug_sigprint(); - const cx_type* coeffs = coeffs_ptr(); + const cx_type* coeffs = coeffs_array.memptr(); tmp_array.set_min_size(r); cx_type* tmp = tmp_array.memptr(); diff --git a/src/armadillo_bits/field_bones.hpp b/src/armadillo_bits/field_bones.hpp index b3c01fcf..d3d2b053 100644 --- a/src/armadillo_bits/field_bones.hpp +++ b/src/armadillo_bits/field_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,7 @@ struct field_prealloc_n_elem { - static const uword val = 16; + static constexpr uword val = 16; }; @@ -65,13 +67,15 @@ class field inline explicit field(const SizeMat& s); inline explicit field(const SizeCube& s); - inline void set_size(const uword n_obj_in); - inline void set_size(const uword n_rows_in, const uword n_cols_in); - inline void set_size(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in); - inline void set_size(const SizeMat& s); - inline void set_size(const SizeCube& s); + inline field& set_size(const uword n_obj_in); + inline field& set_size(const uword n_rows_in, const uword n_cols_in); + inline field& set_size(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in); + inline field& set_size(const SizeMat& s); + inline field& set_size(const SizeCube& s); + + inline field(const std::vector& x); + inline field& operator=(const std::vector& x); - #if defined(ARMA_USE_CXX11) inline field(const std::initializer_list& list); inline field& operator=(const std::initializer_list& list); @@ -80,34 +84,51 @@ class field inline field(field&& X); inline field& operator=(field&& X); - #endif template - inline void copy_size(const field& x); + inline field& copy_size(const field& x); - arma_inline oT& operator[](const uword i); - arma_inline const oT& operator[](const uword i) const; + arma_warn_unused arma_inline oT& operator[](const uword i); + arma_warn_unused arma_inline const oT& operator[](const uword i) const; - arma_inline oT& at(const uword i); - arma_inline const oT& at(const uword i) const; + arma_warn_unused arma_inline oT& at(const uword i); + arma_warn_unused arma_inline const oT& at(const uword i) const; - arma_inline oT& operator()(const uword i); - arma_inline const oT& operator()(const uword i) const; + arma_warn_unused arma_inline oT& operator()(const uword i); + arma_warn_unused arma_inline const oT& operator()(const uword i) const; - arma_inline oT& at(const uword row, const uword col); - arma_inline const oT& at(const uword row, const uword col) const; - - arma_inline oT& at(const uword row, const uword col, const uword slice); - arma_inline const oT& at(const uword row, const uword col, const uword slice) const; + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline oT& operator[](const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator[](const uword row, const uword col) const; + #endif + + arma_warn_unused arma_inline oT& at(const uword row, const uword col); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col) const; + + #if defined(__cpp_multidimensional_subscript) + arma_warn_unused arma_inline oT& operator[](const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator[](const uword row, const uword col, const uword slice) const; + #endif + + arma_warn_unused arma_inline oT& at(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col, const uword slice) const; - arma_inline oT& operator()(const uword row, const uword col); - arma_inline const oT& operator()(const uword row, const uword col) const; + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col) const; - arma_inline oT& operator()(const uword row, const uword col, const uword slice); - arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; + + + arma_warn_unused arma_inline oT& front(); + arma_warn_unused arma_inline const oT& front() const; - inline field_injector operator<<(const oT& val); - inline field_injector operator<<(const injector_end_of_row<>& x); + arma_warn_unused arma_inline oT& back(); + arma_warn_unused arma_inline const oT& back() const; + + + arma_frown("use braced initialiser list instead") inline field_injector operator<<(const oT& val); + arma_frown("use braced initialiser list instead") inline field_injector operator<<(const injector_end_of_row<>& x); inline subview_field row(const uword row_num); @@ -162,50 +183,45 @@ class field arma_cold inline void print( const std::string extra_text = "") const; arma_cold inline void print(std::ostream& user_stream, const std::string extra_text = "") const; - #if defined(ARMA_USE_CXX11) - inline const field& for_each(const std::function< void( oT&) >& F); + inline field& for_each(const std::function< void( oT&) >& F); inline const field& for_each(const std::function< void(const oT&) >& F) const; - #else - template inline const field& for_each(functor F); - template inline const field& for_each(functor F) const; - #endif - inline const field& fill(const oT& x); + inline field& fill(const oT& x); inline void reset(); inline void reset_objects(); - arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_empty() const; - arma_inline arma_warn_unused bool in_range(const uword i) const; - arma_inline arma_warn_unused bool in_range(const span& x) const; + arma_warn_unused arma_inline bool in_range(const uword i) const; + arma_warn_unused arma_inline bool in_range(const span& x) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const uword in_col) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const span& col_span) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const uword in_col) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const span& col_span) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const SizeMat& s) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; - arma_inline arma_warn_unused bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice) const; + arma_warn_unused arma_inline bool in_range(const span& row_span, const span& col_span, const span& slice_span) const; - arma_inline arma_warn_unused bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; + arma_warn_unused arma_inline bool in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const; - inline arma_cold bool save(const std::string name, const file_type type = arma_binary, const bool print_status = true) const; - inline arma_cold bool save( std::ostream& os, const file_type type = arma_binary, const bool print_status = true) const; + arma_cold inline bool save(const std::string name, const file_type type = arma_binary) const; + arma_cold inline bool save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool load(const std::string name, const file_type type = auto_detect, const bool print_status = true); - inline arma_cold bool load( std::istream& is, const file_type type = auto_detect, const bool print_status = true); + arma_cold inline bool load(const std::string name, const file_type type = auto_detect); + arma_cold inline bool load( std::istream& is, const file_type type = auto_detect); - inline arma_cold bool quiet_save(const std::string name, const file_type type = arma_binary) const; - inline arma_cold bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save(const std::string name, const file_type type = arma_binary) const; + arma_deprecated inline bool quiet_save( std::ostream& os, const file_type type = arma_binary) const; - inline arma_cold bool quiet_load(const std::string name, const file_type type = auto_detect); - inline arma_cold bool quiet_load( std::istream& is, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load(const std::string name, const file_type type = auto_detect); + arma_deprecated inline bool quiet_load( std::istream& is, const file_type type = auto_detect); // for container-like functionality @@ -286,7 +302,7 @@ class field public: - #ifdef ARMA_EXTRA_FIELD_PROTO + #if defined(ARMA_EXTRA_FIELD_PROTO) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_FIELD_PROTO) #endif }; diff --git a/src/armadillo_bits/field_meat.hpp b/src/armadillo_bits/field_meat.hpp index aa1bf3ba..b65bb84f 100644 --- a/src/armadillo_bits/field_meat.hpp +++ b/src/armadillo_bits/field_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ field::~field() if(n_elem > field_prealloc_n_elem::val) { delete [] mem; } // try to expose buggy user code that accesses deleted objects - if(arma_config::debug) { mem = 0; } + if(arma_config::debug) { mem = nullptr; } } @@ -41,7 +43,7 @@ field::field() , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); } @@ -56,7 +58,7 @@ field::field(const field& x) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint(arma_str::format("this = %x x = %x") % this % &x); @@ -74,12 +76,13 @@ field::operator=(const field& x) arma_extra_debug_sigprint(); init(x); + return *this; } -//! construct a field from subview_field (e.g. construct a field from a delayed subfield operation) +//! construct a field from subview_field (eg. construct a field from a delayed subfield operation) template inline field::field(const subview_field& X) @@ -87,7 +90,7 @@ field::field(const subview_field& X) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -96,7 +99,7 @@ field::field(const subview_field& X) -//! construct a field from subview_field (e.g. construct a field from a delayed subfield operation) +//! construct a field from subview_field (eg. construct a field from a delayed subfield operation) template inline field& @@ -105,6 +108,7 @@ field::operator=(const subview_field& X) arma_extra_debug_sigprint(); subview_field::extract(*this, X); + return *this; } @@ -119,7 +123,7 @@ field::field(const uword n_elem_in) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -136,7 +140,7 @@ field::field(const uword n_rows_in, const uword n_cols_in) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -153,7 +157,7 @@ field::field(const uword n_rows_in, const uword n_cols_in, const uword n_sli , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -169,7 +173,7 @@ field::field(const SizeMat& s) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -185,7 +189,7 @@ field::field(const SizeCube& s) , n_cols(0) , n_slices(0) , n_elem(0) - , mem(0) + , mem(nullptr) { arma_extra_debug_sigprint_this(this); @@ -198,12 +202,14 @@ field::field(const SizeCube& s) //! assuming a column-major layout (data is not preserved) template inline -void +field& field::set_size(const uword n_elem_in) { - arma_extra_debug_sigprint(arma_str::format("n_elem_in = %d") % n_elem_in); + arma_extra_debug_sigprint(arma_str::format("n_elem_in = %u") % n_elem_in); init(n_elem_in, 1); + + return *this; } @@ -211,12 +217,14 @@ field::set_size(const uword n_elem_in) //! change the field to have the specified dimensions (data is not preserved) template inline -void +field& field::set_size(const uword n_rows_in, const uword n_cols_in) { - arma_extra_debug_sigprint(arma_str::format("n_rows_in = %d, n_cols_in = %d") % n_rows_in % n_cols_in); + arma_extra_debug_sigprint(arma_str::format("n_rows_in = %u, n_cols_in = %u") % n_rows_in % n_cols_in); init(n_rows_in, n_cols_in); + + return *this; } @@ -224,212 +232,242 @@ field::set_size(const uword n_rows_in, const uword n_cols_in) //! change the field to have the specified dimensions (data is not preserved) template inline -void +field& field::set_size(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in) { - arma_extra_debug_sigprint(arma_str::format("n_rows_in = %d, n_cols_in = %d, n_slices_in = %d") % n_rows_in % n_cols_in % n_slices_in); + arma_extra_debug_sigprint(arma_str::format("n_rows_in = %u, n_cols_in = %u, n_slices_in = %u") % n_rows_in % n_cols_in % n_slices_in); init(n_rows_in, n_cols_in, n_slices_in); + + return *this; } template inline -void +field& field::set_size(const SizeMat& s) { + arma_extra_debug_sigprint(); + init(s.n_rows, s.n_cols); + + return *this; } template inline -void +field& field::set_size(const SizeCube& s) { + arma_extra_debug_sigprint(); + init(s.n_rows, s.n_cols, s.n_slices); + + return *this; } -#if defined(ARMA_USE_CXX11) +template +inline +field::field(const std::vector& x) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); - template - inline - field::field(const std::initializer_list& list) - : n_rows (0) - , n_cols (0) - , n_slices(0) - , n_elem (0) - { - arma_extra_debug_sigprint_this(this); - - (*this).operator=(list); - } + (*this).operator=(x); + } + + + +template +inline +field& +field::operator=(const std::vector& x) + { + arma_extra_debug_sigprint(); + const uword N = uword(x.size()); + set_size(N, 1); - template - inline - field& - field::operator=(const std::initializer_list& list) - { - arma_extra_debug_sigprint(); - - const uword N = uword(list.size()); - - set_size(1, N); - - const oT* item_ptr = list.begin(); - - for(uword i=0; i +inline +field::field(const std::initializer_list& list) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); + (*this).operator=(list); + } + + + +template +inline +field& +field::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); - template - inline - field::field(const std::initializer_list< std::initializer_list >& list) - : n_rows (0) - , n_cols (0) - , n_slices(0) - , n_elem (0) - { - arma_extra_debug_sigprint_this(this); - - (*this).operator=(list); - } + const uword N = uword(list.size()); + set_size(1, N); + const oT* item_ptr = list.begin(); - template - inline - field& - field::operator=(const std::initializer_list< std::initializer_list >& list) + for(uword i=0; i +inline +field::field(const std::initializer_list< std::initializer_list >& list) + : n_rows (0) + , n_cols (0) + , n_slices(0) + , n_elem (0) + { + arma_extra_debug_sigprint_this(this); + + (*this).operator=(list); + } + + + +template +inline +field& +field::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + uword x_n_rows = uword(list.size()); + uword x_n_cols = 0; + + auto it = list.begin(); + auto it_end = list.end(); + + for(; it != it_end; ++it) { x_n_cols = (std::max)(x_n_cols, uword((*it).size())); } + + field& t = (*this); + + t.set_size(x_n_rows, x_n_cols); + + uword row_num = 0; + + auto row_it = list.begin(); + auto row_it_end = list.end(); + + for(; row_it != row_it_end; ++row_it) { - arma_extra_debug_sigprint(); + uword col_num = 0; - uword x_n_rows = uword(list.size()); - uword x_n_cols = 0; + auto col_it = (*row_it).begin(); + auto col_it_end = (*row_it).end(); - bool x_n_cols_found = false; - - auto it = list.begin(); - auto it_end = list.end(); - - for(; it != it_end; ++it) + for(; col_it != col_it_end; ++col_it) { - if(x_n_cols_found == false) - { - x_n_cols = uword( (*it).size() ); - x_n_cols_found = true; - } - else - { - arma_check( (uword((*it).size()) != x_n_cols), "field::init(): inconsistent number of columns in initialiser list" ); - } + t.at(row_num, col_num) = (*col_it); + ++col_num; } - field& t = (*this); - - t.set_size(x_n_rows, x_n_cols); - - uword row_num = 0; - - auto row_it = list.begin(); - auto row_it_end = list.end(); - - for(; row_it != row_it_end; ++row_it) + for(uword c=col_num; c < x_n_cols; ++c) { - uword col_num = 0; - - auto col_it = (*row_it).begin(); - auto col_it_end = (*row_it).end(); - - for(; col_it != col_it_end; ++col_it) - { - t.at(row_num, col_num) = (*col_it); - ++col_num; - } - - ++row_num; + t.at(row_num, c) = oT(); } - return *this; + ++row_num; } + return *this; + } + + + +template +inline +field::field(field&& X) + : n_rows (X.n_rows ) + , n_cols (X.n_cols ) + , n_slices(X.n_slices) + , n_elem (X.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - - template - inline - field::field(field&& X) - : n_rows (X.n_rows ) - , n_cols (X.n_cols ) - , n_slices(X.n_slices) - , n_elem (X.n_elem ) + if(n_elem > field_prealloc_n_elem::val) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - - if(n_elem > field_prealloc_n_elem::val) - { - mem = X.mem; - } - else - { - arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); - mem = mem_local; - } - - access::rw(X.n_rows ) = 0; - access::rw(X.n_cols ) = 0; - access::rw(X.n_slices) = 0; - access::rw(X.n_elem ) = 0; - access::rw(X.mem ) = 0; + mem = X.mem; } + else + { + arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); + mem = mem_local; + } + + access::rw(X.n_rows ) = 0; + access::rw(X.n_cols ) = 0; + access::rw(X.n_slices) = 0; + access::rw(X.n_elem ) = 0; + access::rw(X.mem ) = nullptr; + } + + + +template +inline +field& +field::operator=(field&& X) + { + arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); + if(this == &X) { return *this; } + reset(); - template - inline - field& - field::operator=(field&& X) + access::rw(n_rows ) = X.n_rows; + access::rw(n_cols ) = X.n_cols; + access::rw(n_slices) = X.n_slices; + access::rw(n_elem ) = X.n_elem; + + if(n_elem > field_prealloc_n_elem::val) { - arma_extra_debug_sigprint(arma_str::format("this = %x X = %x") % this % &X); - - reset(); - - access::rw(n_rows ) = X.n_rows; - access::rw(n_cols ) = X.n_cols; - access::rw(n_slices) = X.n_slices; - access::rw(n_elem ) = X.n_elem; - - if(n_elem > field_prealloc_n_elem::val) - { - mem = X.mem; - } - else - { - arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); - mem = mem_local; - } - - access::rw(X.n_rows ) = 0; - access::rw(X.n_cols ) = 0; - access::rw(X.n_elem ) = 0; - access::rw(X.n_slices) = 0; - access::rw(X.mem ) = 0; - - return *this; + mem = X.mem; + } + else + { + arrayops::copy(&mem_local[0], &X.mem_local[0], n_elem); + mem = mem_local; } -#endif + access::rw(X.n_rows ) = 0; + access::rw(X.n_cols ) = 0; + access::rw(X.n_elem ) = 0; + access::rw(X.n_slices) = 0; + access::rw(X.mem ) = nullptr; + + return *this; + } @@ -437,12 +475,14 @@ field::set_size(const SizeCube& s) template template inline -void +field& field::copy_size(const field& x) { arma_extra_debug_sigprint(); init(x.n_rows, x.n_cols, x.n_slices); + + return *this; } @@ -497,7 +537,8 @@ arma_inline oT& field::operator() (const uword i) { - arma_debug_check( (i >= n_elem), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "field::operator(): index out of bounds" ); + return (*mem[i]); } @@ -509,7 +550,8 @@ arma_inline const oT& field::operator() (const uword i) const { - arma_debug_check( (i >= n_elem), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "field::operator(): index out of bounds" ); + return (*mem[i]); } @@ -521,7 +563,8 @@ arma_inline oT& field::operator() (const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + return (*mem[in_row + in_col*n_rows]); } @@ -533,7 +576,8 @@ arma_inline const oT& field::operator() (const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (0 >= n_slices) ), "field::operator(): index out of bounds" ); + return (*mem[in_row + in_col*n_rows]); } @@ -545,7 +589,8 @@ arma_inline oT& field::operator() (const uword in_row, const uword in_col, const uword in_slice) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); } @@ -557,12 +602,39 @@ arma_inline const oT& field::operator() (const uword in_row, const uword in_col, const uword in_slice) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "field::operator(): index out of bounds" ); + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); } +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + oT& + field::operator[] (const uword in_row, const uword in_col) + { + return (*mem[in_row + in_col*n_rows]); + } + + + + //! element accessor; no bounds check + template + arma_inline + const oT& + field::operator[] (const uword in_row, const uword in_col) const + { + return (*mem[in_row + in_col*n_rows]); + } + +#endif + + + //! element accessor; no bounds check template arma_inline @@ -585,6 +657,32 @@ field::at(const uword in_row, const uword in_col) const +#if defined(__cpp_multidimensional_subscript) + + //! element accessor; no bounds check + template + arma_inline + oT& + field::operator[] (const uword in_row, const uword in_col, const uword in_slice) + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + + + + //! element accessor; no bounds check + template + arma_inline + const oT& + field::operator[] (const uword in_row, const uword in_col, const uword in_slice) const + { + return (*mem[in_row + in_col*n_rows + in_slice*(n_rows*n_cols)]); + } + +#endif + + + //! element accessor; no bounds check template arma_inline @@ -607,6 +705,54 @@ field::at(const uword in_row, const uword in_col, const uword in_slice) cons +template +arma_inline +oT& +field::front() + { + arma_debug_check( (n_elem == 0), "field::front(): field is empty" ); + + return (*mem[0]); + } + + + +template +arma_inline +const oT& +field::front() const + { + arma_debug_check( (n_elem == 0), "field::front(): field is empty" ); + + return (*mem[0]); + } + + + +template +arma_inline +oT& +field::back() + { + arma_debug_check( (n_elem == 0), "field::back(): field is empty" ); + + return (*mem[n_elem-1]); + } + + + +template +arma_inline +const oT& +field::back() const + { + arma_debug_check( (n_elem == 0), "field::back(): field is empty" ); + + return (*mem[n_elem-1]); + } + + + template inline field_injector< field > @@ -637,7 +783,7 @@ field::row(const uword row_num) arma_debug_check( (n_slices >= 2), "field::row(): field must be 2D" ); - arma_debug_check( (row_num >= n_rows), "field::row(): row out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "field::row(): row out of bounds" ); return subview_field(*this, row_num, 0, 1, n_cols); } @@ -654,7 +800,7 @@ field::row(const uword row_num) const arma_debug_check( (n_slices >= 2), "field::row(): field must be 2D" ); - arma_debug_check( (row_num >= n_rows), "field::row(): row out of bounds" ); + arma_debug_check_bounds( (row_num >= n_rows), "field::row(): row out of bounds" ); return subview_field(*this, row_num, 0, 1, n_cols); } @@ -671,7 +817,7 @@ field::col(const uword col_num) arma_debug_check( (n_slices >= 2), "field::col(): field must be 2D" ); - arma_debug_check( (col_num >= n_cols), "field::col(): out of bounds" ); + arma_debug_check_bounds( (col_num >= n_cols), "field::col(): out of bounds" ); return subview_field(*this, 0, col_num, n_rows, 1); } @@ -688,7 +834,7 @@ field::col(const uword col_num) const arma_debug_check( (n_slices >= 2), "field::col(): field must be 2D" ); - arma_debug_check( (col_num >= n_cols), "field::col(): out of bounds" ); + arma_debug_check_bounds( (col_num >= n_cols), "field::col(): out of bounds" ); return subview_field(*this, 0, col_num, n_rows, 1); } @@ -703,7 +849,7 @@ field::slice(const uword slice_num) { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "field::slice(): out of bounds" ); + arma_debug_check_bounds( (slice_num >= n_slices), "field::slice(): out of bounds" ); return subview_field(*this, 0, 0, slice_num, n_rows, n_cols, 1); } @@ -718,7 +864,7 @@ field::slice(const uword slice_num) const { arma_extra_debug_sigprint(); - arma_debug_check( (slice_num >= n_slices), "field::slice(): out of bounds" ); + arma_debug_check_bounds( (slice_num >= n_slices), "field::slice(): out of bounds" ); return subview_field(*this, 0, 0, slice_num, n_rows, n_cols, 1); } @@ -735,7 +881,7 @@ field::rows(const uword in_row1, const uword in_row2) arma_debug_check( (n_slices >= 2), "field::rows(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( ( (in_row1 > in_row2) || (in_row2 >= n_rows) ), "field::rows(): indicies out of bounds or incorrectly used" @@ -758,7 +904,7 @@ field::rows(const uword in_row1, const uword in_row2) const arma_debug_check( (n_slices >= 2), "field::rows(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( ( (in_row1 > in_row2) || (in_row2 >= n_rows) ), "field::rows(): indicies out of bounds or incorrectly used" @@ -781,7 +927,7 @@ field::cols(const uword in_col1, const uword in_col2) arma_debug_check( (n_slices >= 2), "field::cols(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( ( (in_col1 > in_col2) || (in_col2 >= n_cols) ), "field::cols(): indicies out of bounds or incorrectly used" @@ -804,7 +950,7 @@ field::cols(const uword in_col1, const uword in_col2) const arma_debug_check( (n_slices >= 2), "field::cols(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( ( (in_col1 > in_col2) || (in_col2 >= n_cols) ), "field::cols(): indicies out of bounds or incorrectly used" @@ -825,7 +971,7 @@ field::slices(const uword in_slice1, const uword in_slice2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices) ), "field::slices(): indicies out of bounds or incorrectly used" @@ -846,7 +992,7 @@ field::slices(const uword in_slice1, const uword in_slice2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( ( (in_slice1 > in_slice2) || (in_slice2 >= n_slices) ), "field::slices(): indicies out of bounds or incorrectly used" @@ -869,7 +1015,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_row arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "field::subfield(): indices out of bounds or incorrectly used" @@ -893,7 +1039,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_row arma_debug_check( (n_slices >= 2), "field::subfield(): field must be 2D" ); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "field::subfield(): indices out of bounds or incorrectly used" @@ -915,7 +1061,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_sli { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), "field::subfield(): indices out of bounds or incorrectly used" @@ -938,7 +1084,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_sli { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_slice1 > in_slice2) || (in_row2 >= n_rows) || (in_col2 >= n_cols) || (in_slice2 >= n_slices), "field::subfield(): indices out of bounds or incorrectly used" @@ -969,7 +1115,7 @@ field::subfield(const uword in_row1, const uword in_col1, const SizeMat& s) const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "field::subfield(): indices or size out of bounds" @@ -996,7 +1142,7 @@ field::subfield(const uword in_row1, const uword in_col1, const SizeMat& s) const uword s_n_rows = s.n_rows; const uword s_n_cols = s.n_cols; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols)), "field::subfield(): indices or size out of bounds" @@ -1023,7 +1169,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_sli const uword s_n_cols = s.n_cols; const uword sub_n_slices = s.n_slices; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || (in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + sub_n_slices) > l_n_slices)), "field::subfield(): indices or size out of bounds" @@ -1050,7 +1196,7 @@ field::subfield(const uword in_row1, const uword in_col1, const uword in_sli const uword s_n_cols = s.n_cols; const uword sub_n_slices = s.n_slices; - arma_debug_check + arma_debug_check_bounds ( ((in_row1 >= l_n_rows) || (in_col1 >= l_n_cols) || (in_slice1 >= l_n_slices) || ((in_row1 + s_n_rows) > l_n_rows) || ((in_col1 + s_n_cols) > l_n_cols) || ((in_slice1 + sub_n_slices) > l_n_slices)), "field::subfield(): indices or size out of bounds" @@ -1085,7 +1231,7 @@ field::subfield(const span& row_span, const span& col_span) const uword in_col2 = col_span.b; const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1123,7 +1269,7 @@ field::subfield(const span& row_span, const span& col_span) const const uword in_col2 = col_span.b; const uword sub_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1165,7 +1311,7 @@ field::subfield(const span& row_span, const span& col_span, const span& slic const uword in_slice2 = slice_span.b; const uword sub_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1209,7 +1355,7 @@ field::subfield(const span& row_span, const span& col_span, const span& slic const uword in_slice2 = slice_span.b; const uword sub_n_slices = slice_all ? local_n_slices : in_slice2 - in_slice1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -1327,11 +1473,10 @@ field::operator()(const uword in_row1, const uword in_col1, const uword in_s //! but the associated operator<< function for type oT //! may still modify the stream's parameters. //! NOTE: this function assumes that type oT can be printed, -//! i.e. the function "std::ostream& operator<< (std::ostream&, const oT&)" +//! ie. the function "std::ostream& operator<< (std::ostream&, const oT&)" //! has been defined. template -arma_cold inline void field::print(const std::string extra_text) const @@ -1358,11 +1503,10 @@ field::print(const std::string extra_text) const //! but the associated operator<< function for type oT //! may still modify the stream's parameters. //! NOTE: this function assumes that type oT can be printed, -//! i.e. the function "std::ostream& operator<< (std::ostream&, const oT&)" +//! ie. the function "std::ostream& operator<< (std::ostream&, const oT&)" //! has been defined. template -arma_cold inline void field::print(std::ostream& user_stream, const std::string extra_text) const @@ -1383,103 +1527,53 @@ field::print(std::ostream& user_stream, const std::string extra_text) const -#if defined(ARMA_USE_CXX11) - - //! apply a lambda function to each object - template - inline - const field& - field::for_each(const std::function< void(oT&) >& F) - { - arma_extra_debug_sigprint(); - - for(uword i=0; i < n_elem; ++i) - { - F(operator[](i)); - } - - return *this; - } - - - - template - inline - const field& - field::for_each(const std::function< void(const oT&) >& F) const - { - arma_extra_debug_sigprint(); - - for(uword i=0; i < n_elem; ++i) - { - F(operator[](i)); - } - - return *this; - } +//! apply a lambda function to each object +template +inline +field& +field::for_each(const std::function< void(oT&) >& F) + { + arma_extra_debug_sigprint(); -#else - - //! apply a functor to each object - template - template - inline - const field& - field::for_each(functor F) - { - arma_extra_debug_sigprint(); - - for(uword i=0; i < n_elem; ++i) - { - F(operator[](i)); - } - - return *this; - } + for(uword i=0; i < n_elem; ++i) { F(operator[](i)); } + return *this; + } + + + +template +inline +const field& +field::for_each(const std::function< void(const oT&) >& F) const + { + arma_extra_debug_sigprint(); + for(uword i=0; i < n_elem; ++i) { F(operator[](i)); } - template - template - inline - const field& - field::for_each(functor F) const - { - arma_extra_debug_sigprint(); - - for(uword i=0; i < n_elem; ++i) - { - F(operator[](i)); - } - - return *this; - } - -#endif + return *this; + } //! fill the field with an object template inline -const field& +field& field::fill(const oT& x) { arma_extra_debug_sigprint(); field& t = *this; - for(uword i=0; i inline void @@ -1519,7 +1613,6 @@ field::is_empty() const //! returns true if the given index is currently in range template arma_inline -arma_warn_unused bool field::in_range(const uword i) const { @@ -1531,13 +1624,12 @@ field::in_range(const uword i) const //! returns true if the given start and end indices are currently in range template arma_inline -arma_warn_unused bool field::in_range(const span& x) const { arma_extra_debug_sigprint(); - if(x.whole == true) + if(x.whole) { return true; } @@ -1555,7 +1647,6 @@ field::in_range(const span& x) const //! returns true if the given location is currently in range template arma_inline -arma_warn_unused bool field::in_range(const uword in_row, const uword in_col) const { @@ -1566,13 +1657,12 @@ field::in_range(const uword in_row, const uword in_col) const template arma_inline -arma_warn_unused bool field::in_range(const span& row_span, const uword in_col) const { arma_extra_debug_sigprint(); - if(row_span.whole == true) + if(row_span.whole) { return (in_col < n_cols); } @@ -1589,13 +1679,12 @@ field::in_range(const span& row_span, const uword in_col) const template arma_inline -arma_warn_unused bool field::in_range(const uword in_row, const span& col_span) const { arma_extra_debug_sigprint(); - if(col_span.whole == true) + if(col_span.whole) { return (in_row < n_rows); } @@ -1612,7 +1701,6 @@ field::in_range(const uword in_row, const span& col_span) const template arma_inline -arma_warn_unused bool field::in_range(const span& row_span, const span& col_span) const { @@ -1627,14 +1715,13 @@ field::in_range(const span& row_span, const span& col_span) const const bool rows_ok = row_span.whole ? true : ( (in_row1 <= in_row2) && (in_row2 < n_rows) ); const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2) && (in_col2 < n_cols) ); - return ( (rows_ok == true) && (cols_ok == true) ); + return ( rows_ok && cols_ok ); } template arma_inline -arma_warn_unused bool field::in_range(const uword in_row, const uword in_col, const SizeMat& s) const { @@ -1655,7 +1742,6 @@ field::in_range(const uword in_row, const uword in_col, const SizeMat& s) co template arma_inline -arma_warn_unused bool field::in_range(const uword in_row, const uword in_col, const uword in_slice) const { @@ -1666,7 +1752,6 @@ field::in_range(const uword in_row, const uword in_col, const uword in_slice template arma_inline -arma_warn_unused bool field::in_range(const span& row_span, const span& col_span, const span& slice_span) const { @@ -1685,14 +1770,13 @@ field::in_range(const span& row_span, const span& col_span, const span& slic const bool cols_ok = col_span.whole ? true : ( (in_col1 <= in_col2 ) && (in_col2 < n_cols ) ); const bool slices_ok = slice_span.whole ? true : ( (in_slice1 <= in_slice2) && (in_slice2 < n_slices) ); - return ( (rows_ok == true) && (cols_ok == true) && (slices_ok == true) ); + return ( rows_ok && cols_ok && slices_ok ); } template arma_inline -arma_warn_unused bool field::in_range(const uword in_row, const uword in_col, const uword in_slice, const SizeCube& s) const { @@ -1714,24 +1798,24 @@ field::in_range(const uword in_row, const uword in_col, const uword in_slice template inline -arma_cold bool -field::save(const std::string name, const file_type type, const bool print_status) const +field::save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); std::string err_msg; + const bool save_okay = field_aux::save(*this, name, type, err_msg); - if( (print_status == true) && (save_okay == false) ) + if(save_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("field::save(): ", err_msg, name); + arma_debug_warn_level(3, "field::save(): ", err_msg, "; file: ", name); } else { - arma_debug_warn("field::save(): couldn't write to ", name); + arma_debug_warn_level(3, "field::save(): write failed; file: ", name); } } @@ -1742,24 +1826,24 @@ field::save(const std::string name, const file_type type, const bool print_s template inline -arma_cold bool -field::save(std::ostream& os, const file_type type, const bool print_status) const +field::save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); std::string err_msg; + const bool save_okay = field_aux::save(*this, os, type, err_msg); - if( (print_status == true) && (save_okay == false) ) + if(save_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("field::save(): ", err_msg, "[ostream]"); + arma_debug_warn_level(3, "field::save(): ", err_msg); } else { - arma_debug_warn("field::save(): couldn't write to [ostream]"); + arma_debug_warn_level(3, "field::save(): stream write failed"); } } @@ -1770,31 +1854,28 @@ field::save(std::ostream& os, const file_type type, const bool print_status) template inline -arma_cold bool -field::load(const std::string name, const file_type type, const bool print_status) +field::load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); std::string err_msg; + const bool load_okay = field_aux::load(*this, name, type, err_msg); - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("field::load(): ", err_msg, name); + arma_debug_warn_level(3, "field::load(): ", err_msg, "; file: ", name); } else { - arma_debug_warn("field::load(): couldn't read from ", name); + arma_debug_warn_level(3, "field::load(): read failed; file: ", name); } } - if(load_okay == false) - { - (*this).reset(); - } + if(load_okay == false) { (*this).reset(); } return load_okay; } @@ -1803,31 +1884,27 @@ field::load(const std::string name, const file_type type, const bool print_s template inline -arma_cold bool -field::load(std::istream& is, const file_type type, const bool print_status) +field::load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); std::string err_msg; const bool load_okay = field_aux::load(*this, is, type, err_msg); - if( (print_status == true) && (load_okay == false) ) + if(load_okay == false) { if(err_msg.length() > 0) { - arma_debug_warn("field::load(): ", err_msg, "[istream]"); + arma_debug_warn_level(3, "field::load(): ", err_msg); } else { - arma_debug_warn("field::load(): couldn't read from [istream]"); + arma_debug_warn_level(3, "field::load(): stream read failed"); } } - if(load_okay == false) - { - (*this).reset(); - } + if(load_okay == false) { (*this).reset(); } return load_okay; } @@ -1836,52 +1913,48 @@ field::load(std::istream& is, const file_type type, const bool print_status) template inline -arma_cold bool field::quiet_save(const std::string name, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(name, type, false); + return (*this).save(name, type); } template inline -arma_cold bool field::quiet_save(std::ostream& os, const file_type type) const { arma_extra_debug_sigprint(); - return (*this).save(os, type, false); + return (*this).save(os, type); } template inline -arma_cold bool field::quiet_load(const std::string name, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(name, type, false); + return (*this).load(name, type); } template inline -arma_cold bool field::quiet_load(std::istream& is, const file_type type) { arma_extra_debug_sigprint(); - return (*this).load(is, type, false); + return (*this).load(is, type); } @@ -1941,12 +2014,12 @@ inline void field::init(const uword n_rows_in, const uword n_cols_in, const uword n_slices_in) { - arma_extra_debug_sigprint( arma_str::format("n_rows_in = %d, n_cols_in = %d, n_slices_in = %d") % n_rows_in % n_cols_in % n_slices_in ); + arma_extra_debug_sigprint( arma_str::format("n_rows_in = %u, n_cols_in = %u, n_slices_in = %u") % n_rows_in % n_cols_in % n_slices_in ); - #if (defined(ARMA_USE_CXX11) || defined(ARMA_64BIT_WORD)) + #if defined(ARMA_64BIT_WORD) const char* error_message = "field::init(): requested size is too large"; #else - const char* error_message = "field::init(): requested size is too large; suggest to compile in C++11 mode or enable ARMA_64BIT_WORD"; + const char* error_message = "field::init(): requested size is too large; suggest to enable ARMA_64BIT_WORD"; #endif arma_debug_check @@ -1973,26 +2046,17 @@ field::init(const uword n_rows_in, const uword n_cols_in, const uword n_slic { delete_objects(); - if(n_elem > field_prealloc_n_elem::val) - { - delete [] mem; - } + if(n_elem > field_prealloc_n_elem::val) { delete [] mem; } if(n_elem_new <= field_prealloc_n_elem::val) { - if(n_elem_new == 0) - { - mem = NULL; - } - else - { - mem = mem_local; - } + mem = (n_elem_new == 0) ? nullptr : mem_local; } else { mem = new(std::nothrow) oT* [n_elem_new]; - arma_check_bad_alloc( (mem == 0), "field::init(): out of memory" ); + + arma_check_bad_alloc( (mem == nullptr), "field::init(): out of memory" ); } access::rw(n_rows) = n_rows_in; @@ -2011,14 +2075,14 @@ inline void field::delete_objects() { - arma_extra_debug_sigprint( arma_str::format("n_elem = %d") % n_elem ); + arma_extra_debug_sigprint( arma_str::format("n_elem = %u") % n_elem ); for(uword i=0; i::create_objects() { - arma_extra_debug_sigprint( arma_str::format("n_elem = %d") % n_elem ); + arma_extra_debug_sigprint( arma_str::format("n_elem = %u") % n_elem ); - for(uword i=0; i::iterator& field::iterator::operator--() { - if(i > 0) - { - --i; - } + if(i > 0) { --i; } return *this; } @@ -2187,10 +2245,7 @@ inline typename field::const_iterator& field::const_iterator::operator--() { - if(i > 0) - { - --i; - } + if(i > 0) { --i; } return *this; } @@ -2355,10 +2410,7 @@ field_aux::reset_objects(field< Mat >& x) { arma_extra_debug_sigprint(); - for(uword i=0; i >& x) { arma_extra_debug_sigprint(); - for(uword i=0; i >& x) { arma_extra_debug_sigprint(); - for(uword i=0; i >& x) { arma_extra_debug_sigprint(); - for(uword i=0; i& x) { arma_extra_debug_sigprint(); - for(uword i=0; i&, const std::string&, const file_type, std::stri { arma_extra_debug_sigprint(); - err_msg = " [saving/loading this type of field is currently not supported] filename = "; + err_msg = "saving/loading this type of field is currently not supported"; return false; } @@ -2449,7 +2489,7 @@ field_aux::save(const field&, std::ostream&, const file_type, std::string& e { arma_extra_debug_sigprint(); - err_msg = " [saving/loading this type of field is currently not supported] filename = "; + err_msg = "saving/loading this type of field is currently not supported"; return false; } @@ -2463,7 +2503,7 @@ field_aux::load(field&, const std::string&, const file_type, std::string& er { arma_extra_debug_sigprint(); - err_msg = " [saving/loading this type of field is currently not supported] filename = "; + err_msg = "saving/loading this type of field is currently not supported"; return false; } @@ -2477,7 +2517,7 @@ field_aux::load(field&, std::istream&, const file_type, std::string& err_msg { arma_extra_debug_sigprint(); - err_msg = " [saving/loading this type of field is currently not supported] filename = "; + err_msg = "saving/loading this type of field is currently not supported"; return false; } @@ -2502,7 +2542,7 @@ field_aux::save(const field< Mat >& x, const std::string& name, const file_t break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2527,7 +2567,7 @@ field_aux::save(const field< Mat >& x, std::ostream& os, const file_type typ break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2556,7 +2596,7 @@ field_aux::load(field< Mat >& x, const std::string& name, const file_type ty break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2585,7 +2625,7 @@ field_aux::load(field< Mat >& x, std::istream& is, const file_type type, std break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2610,7 +2650,7 @@ field_aux::save(const field< Col >& x, const std::string& name, const file_t break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2635,7 +2675,7 @@ field_aux::save(const field< Col >& x, std::ostream& os, const file_type typ break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2664,7 +2704,7 @@ field_aux::load(field< Col >& x, const std::string& name, const file_type ty break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2693,7 +2733,7 @@ field_aux::load(field< Col >& x, std::istream& is, const file_type type, std break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2718,7 +2758,7 @@ field_aux::save(const field< Row >& x, const std::string& name, const file_t break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2743,7 +2783,7 @@ field_aux::save(const field< Row >& x, std::ostream& os, const file_type typ break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2772,7 +2812,7 @@ field_aux::load(field< Row >& x, const std::string& name, const file_type ty break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2801,7 +2841,7 @@ field_aux::load(field< Row >& x, std::istream& is, const file_type type, std break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2822,7 +2862,7 @@ field_aux::save(const field< Cube >& x, const std::string& name, const file_ break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2843,7 +2883,7 @@ field_aux::save(const field< Cube >& x, std::ostream& os, const file_type ty break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2865,7 +2905,7 @@ field_aux::load(field< Cube >& x, const std::string& name, const file_type t break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2887,7 +2927,7 @@ field_aux::load(field< Cube >& x, std::istream& is, const file_type type, st break; default: - err_msg = " [unsupported type] filename = "; + err_msg = "unsupported type"; return false; } } @@ -2950,7 +2990,7 @@ field_aux::load(field< std::string >& x, std::istream& is, const file_type type, -#ifdef ARMA_EXTRA_FIELD_MEAT +#if defined(ARMA_EXTRA_FIELD_MEAT) #include ARMA_INCFILE_WRAP(ARMA_EXTRA_FIELD_MEAT) #endif diff --git a/src/armadillo_bits/fill.hpp b/src/armadillo_bits/fill.hpp new file mode 100644 index 00000000..8b410977 --- /dev/null +++ b/src/armadillo_bits/fill.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fill +//! @{ + + +namespace fill + { + struct fill_none {}; + struct fill_zeros {}; + struct fill_ones {}; + struct fill_eye {}; + struct fill_randu {}; + struct fill_randn {}; + + template + struct fill_class { inline constexpr fill_class() {} }; + + static constexpr fill_class none; + static constexpr fill_class zeros; + static constexpr fill_class ones; + static constexpr fill_class eye; + static constexpr fill_class randu; + static constexpr fill_class randn; + + // + + template + struct allow_conversion + { + static constexpr bool value = true; + }; + + template<> struct allow_conversion, double> { static constexpr bool value = false; }; + template<> struct allow_conversion, float > { static constexpr bool value = false; }; + template<> struct allow_conversion, u64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u8 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s8 > { static constexpr bool value = false; }; + + template<> struct allow_conversion, double> { static constexpr bool value = false; }; + template<> struct allow_conversion, float > { static constexpr bool value = false; }; + template<> struct allow_conversion, u64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s64 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s32 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s16 > { static constexpr bool value = false; }; + template<> struct allow_conversion, u8 > { static constexpr bool value = false; }; + template<> struct allow_conversion, s8 > { static constexpr bool value = false; }; + + // + + template inline bool isfinite_wrapper(eT ) { return true; } + template<> inline bool isfinite_wrapper(float x) { return std::isfinite(x); } + template<> inline bool isfinite_wrapper(double x) { return std::isfinite(x); } + template inline bool isfinite_wrapper(std::complex& x) { return std::isfinite(x.real()) && std::isfinite(x.imag()); } + + // + + template + struct scalar_holder + { + const scalar_type1 scalar; + + inline explicit scalar_holder(const scalar_type1& in_scalar) : scalar(in_scalar) {} + + inline scalar_holder() = delete; + + template + < + typename scalar_type2, + typename arma::enable_if2::value, int>::result = 0 + > + inline + operator scalar_holder() const + { + const bool ok_conversion = (std::is_integral::value && std::is_floating_point::value) ? isfinite_wrapper(scalar) : true; + + return scalar_holder( ok_conversion ? scalar_type2(scalar) : scalar_type2(0) ); + } + }; + + // + + template + inline + typename enable_if2< is_supported_elem_type::value, scalar_holder >::result + value(const scalar_type& in_scalar) + { + return scalar_holder(in_scalar); + } + } + + +//! @} diff --git a/src/armadillo_bits/fn_accu.hpp b/src/armadillo_bits/fn_accu.hpp index 283481b5..957459cd 100644 --- a/src/armadillo_bits/fn_accu.hpp +++ b/src/armadillo_bits/fn_accu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -68,7 +70,7 @@ accu_proxy_linear(const Proxy& P) } else { - #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0) + #if defined(__FAST_MATH__) { if(P.is_aligned()) { @@ -413,6 +415,118 @@ accu(const mtOp& X) +template +arma_warn_unused +inline +uword +accu(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_assert_same_size(PA, PB, "operator!="); + + uword n_nonzero = 0; + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typedef typename Proxy::ea_type PA_ea_type; + typedef typename Proxy::ea_type PB_ea_type; + + PA_ea_type A = PA.get_ea(); + PB_ea_type B = PB.get_ea(); + const uword n_elem = PA.get_n_elem(); + + for(uword i=0; i < n_elem; ++i) + { + n_nonzero += (A[i] != B[i]) ? uword(1) : uword(0); + } + } + else + { + const uword PA_n_cols = PA.get_n_cols(); + const uword PA_n_rows = PA.get_n_rows(); + + if(PA_n_rows == 1) + { + for(uword col=0; col < PA_n_cols; ++col) + { + n_nonzero += (PA.at(0,col) != PB.at(0,col)) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < PA_n_cols; ++col) + for(uword row=0; row < PA_n_rows; ++row) + { + n_nonzero += (PA.at(row,col) != PB.at(row,col)) ? uword(1) : uword(0); + } + } + } + + return n_nonzero; + } + + + +template +arma_warn_unused +inline +uword +accu(const mtGlue& X) + { + arma_extra_debug_sigprint(); + + const Proxy PA(X.A); + const Proxy PB(X.B); + + arma_debug_assert_same_size(PA, PB, "operator=="); + + uword n_nonzero = 0; + + if( (Proxy::use_at == false) && (Proxy::use_at == false) ) + { + typedef typename Proxy::ea_type PA_ea_type; + typedef typename Proxy::ea_type PB_ea_type; + + PA_ea_type A = PA.get_ea(); + PB_ea_type B = PB.get_ea(); + const uword n_elem = PA.get_n_elem(); + + for(uword i=0; i < n_elem; ++i) + { + n_nonzero += (A[i] == B[i]) ? uword(1) : uword(0); + } + } + else + { + const uword PA_n_cols = PA.get_n_cols(); + const uword PA_n_rows = PA.get_n_rows(); + + if(PA_n_rows == 1) + { + for(uword col=0; col < PA_n_cols; ++col) + { + n_nonzero += (PA.at(0,col) == PB.at(0,col)) ? uword(1) : uword(0); + } + } + else + { + for(uword col=0; col < PA_n_cols; ++col) + for(uword row=0; row < PA_n_rows; ++row) + { + n_nonzero += (PA.at(row,col) == PB.at(row,col)) ? uword(1) : uword(0); + } + } + } + + return n_nonzero; + } + + + //! accumulate the elements of a subview (submatrix) template arma_warn_unused @@ -426,29 +540,35 @@ accu(const subview& X) const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; - eT val = eT(0); - if(X_n_rows == 1) { - typedef subview_row sv_type; + const Mat& m = X.m; - const sv_type& sv = reinterpret_cast(X); // subview_row is a child class of subview and has no extra data + const uword col_offset = X.aux_col1; + const uword row_offset = X.aux_row1; - const Proxy P(sv); + eT val1 = eT(0); + eT val2 = eT(0); - val = accu_proxy_linear(P); - } - else - if(X_n_cols == 1) - { - val = arrayops::accumulate( X.colptr(0), X_n_rows ); - } - else - { - for(uword col=0; col < X_n_cols; ++col) + uword i,j; + for(i=0, j=1; j < X_n_cols; i+=2, j+=2) { - val += arrayops::accumulate( X.colptr(col), X_n_rows ); + val1 += m.at(row_offset, col_offset + i); + val2 += m.at(row_offset, col_offset + j); } + + if(i < X_n_cols) { val1 += m.at(row_offset, col_offset + i); } + + return val1 + val2; + } + + if(X_n_cols == 1) { return arrayops::accumulate( X.colptr(0), X_n_rows ); } + + eT val = eT(0); + + for(uword col=0; col < X_n_cols; ++col) + { + val += arrayops::accumulate( X.colptr(col), X_n_rows ); } return val; @@ -465,7 +585,7 @@ accu(const subview_col& X) { arma_extra_debug_sigprint(); - return arrayops::accumulate( X.colptr(0), X.n_rows ); + return arrayops::accumulate( X.colmem, X.n_rows ); } @@ -523,7 +643,7 @@ accu_cube_proxy_linear(const ProxyCube& P) } else { - #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0) + #if defined(__FAST_MATH__) { if(P.is_aligned()) { @@ -725,23 +845,36 @@ accu(const SpBase& expr) const SpProxy P(expr.get_ref()); + const uword N = P.get_n_nonzero(); + + if(N == 0) { return eT(0); } + if(SpProxy::use_iterator == false) { // direct counting - return arrayops::accumulate(P.get_values(), P.get_n_nonzero()); + return arrayops::accumulate(P.get_values(), N); } - else + + if(is_SpSubview::stored_type>::value) { - typename SpProxy::const_iterator_type it = P.begin(); - - const uword P_n_nz = P.get_n_nonzero(); - - eT val = eT(0); - - for(uword i=0; i < P_n_nz; ++i) { val += (*it); ++it; } + const SpSubview& sv = reinterpret_cast< const SpSubview& >(P.Q); - return val; + if(sv.n_rows == sv.m.n_rows) + { + const SpMat& m = sv.m; + const uword col = sv.aux_col1; + + return arrayops::accumulate(&(m.values[ m.col_ptrs[col] ]), N); + } } + + typename SpProxy::const_iterator_type it = P.begin(); + + eT val = eT(0); + + for(uword i=0; i < N; ++i) { val += (*it); ++it; } + + return val; } diff --git a/src/armadillo_bits/fn_all.hpp b/src/armadillo_bits/fn_all.hpp index 53f229b9..c69095d0 100644 --- a/src/armadillo_bits/fn_all.hpp +++ b/src/armadillo_bits/fn_all.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_any.hpp b/src/armadillo_bits/fn_any.hpp index cf2e5b24..9038059d 100644 --- a/src/armadillo_bits/fn_any.hpp +++ b/src/armadillo_bits/fn_any.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_approx_equal.hpp b/src/armadillo_bits/fn_approx_equal.hpp index 70a89a13..e92b94fd 100644 --- a/src/armadillo_bits/fn_approx_equal.hpp +++ b/src/armadillo_bits/fn_approx_equal.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -287,7 +289,7 @@ internal_approx_equal_handler(const T1& A, const T2& B, const char* method, cons typedef typename T1::pod_type T; - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); @@ -322,7 +324,7 @@ internal_approx_equal_handler(const T1& A, const T2& B, const char* method, cons typedef typename T1::pod_type T; - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 'a') && (sig != 'r') && (sig != 'b')), "approx_equal(): argument 'method' must be \"absdiff\" or \"reldiff\" or \"both\"" ); @@ -408,7 +410,7 @@ approx_equal(const SpBase& A, const SpBase& A, const SpBase::apply(const T1& X) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - const Proxy P(X); - if(P.get_n_elem() != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); return (Proxy::use_at) ? P.at(0,0) : P[0]; } @@ -145,19 +140,14 @@ as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_time const strip_inv strip1(X.A.B); const strip_diagmat strip2(strip1.M); - const bool tmp2_do_inv = strip1.do_inv; + const bool tmp2_do_inv_gen = strip1.do_inv_gen && arma_config::optimise_invexpr; const bool tmp2_do_diagmat = strip2.do_diagmat; if(tmp2_do_diagmat == false) { const Mat tmp(X); - if(tmp.n_elem != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (tmp.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); return tmp[0]; } @@ -197,7 +187,7 @@ as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_time if(B_is_vec) { - if(tmp2_do_inv) + if(tmp2_do_inv_gen) { return val * op_dotext::direct_rowvec_invdiagvec_colvec(A.mem, B, C.mem); } @@ -208,7 +198,7 @@ as_scalar_redirect<3>::apply(const Glue< Glue, T3, glue_time } else { - if(tmp2_do_inv) + if(tmp2_do_inv_gen) { return val * op_dotext::direct_rowvec_invdiagmat_colvec(A.mem, B, C.mem); } @@ -234,12 +224,7 @@ as_scalar_diag(const Base& X) const unwrap tmp(X.get_ref()); const Mat& A = tmp.M; - if(A.n_elem != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); return A.mem[0]; } @@ -309,70 +294,20 @@ as_scalar_diag(const Glue< Glue, T3, glue_times >& X) template arma_warn_unused -arma_inline -typename T1::elem_type -as_scalar(const Glue& X, const typename arma_not_cx::result* junk = 0) - { - arma_extra_debug_sigprint(); - arma_ignore(junk); - - if(is_glue_times_diag::value == false) - { - const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; - - arma_extra_debug_print(arma_str::format("N_mat = %d") % N_mat); - - return as_scalar_redirect::apply(X); - } - else - { - return as_scalar_diag(X); - } - } - - - -template -arma_warn_unused inline typename T1::elem_type -as_scalar(const Base& X) +as_scalar(const Glue& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); + arma_ignore(junk); - typedef typename T1::elem_type eT; - - const Proxy P(X.get_ref()); - - if(P.get_n_elem() != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } - - return (Proxy::use_at) ? P.at(0,0) : P[0]; - } - - -template -arma_warn_unused -inline -typename T1::elem_type -as_scalar(const Gen& X) - { - arma_extra_debug_sigprint(); + if(is_glue_times_diag::value) { return as_scalar_diag(X); } - typedef typename T1::elem_type eT; + constexpr uword N_mat = 1 + depth_lhs< glue_times, Glue >::num; - if( (X.n_rows != 1) || (X.n_cols != 1) ) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_extra_debug_print(arma_str::format("N_mat = %u") % N_mat); - return eT(arma_rng::randu()); + return as_scalar_redirect::apply(X); } @@ -381,24 +316,18 @@ template arma_warn_unused inline typename T1::elem_type -as_scalar(const Gen& X) +as_scalar(const Base& X) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; + const Proxy P(X.get_ref()); - if( (X.n_rows != 1) || (X.n_cols != 1) ) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); - return eT(arma_rng::randn()); + return (Proxy::use_at) ? P.at(0,0) : P[0]; } - template arma_warn_unused inline @@ -407,16 +336,9 @@ as_scalar(const BaseCube& X) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - const ProxyCube P(X.get_ref()); - if(P.get_n_elem() != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (P.get_n_elem() != 1), "as_scalar(): expression must evaluate to exactly one element" ); return (ProxyCube::use_at) ? P.at(0,0,0) : P[0]; } @@ -440,17 +362,14 @@ inline typename T1::elem_type as_scalar(const SpBase& X) { + arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; const unwrap_spmat tmp(X.get_ref()); const SpMat& A = tmp.M; - if(A.n_elem != 1) - { - arma_debug_check(true, "as_scalar(): expression doesn't evaluate to exactly one element"); - - return Datum::nan; - } + arma_debug_check( (A.n_elem != 1), "as_scalar(): expression must evaluate to exactly one element" ); return A.at(0,0); } diff --git a/src/armadillo_bits/fn_chi2rnd.hpp b/src/armadillo_bits/fn_chi2rnd.hpp index 51752e16..9da08f44 100644 --- a/src/armadillo_bits/fn_chi2rnd.hpp +++ b/src/armadillo_bits/fn_chi2rnd.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,19 +28,9 @@ chi2rnd(const double df) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) - { - op_chi2rnd_varying_df generator; - - return generator(df); - } - #else - { - arma_stop_logic_error("chi2rnd(): C++11 compiler required"); - - return double(0); - } - #endif + op_chi2rnd_varying_df generator; + + return generator(df); } @@ -51,19 +43,9 @@ chi2rnd(const eT df) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) - { - op_chi2rnd_varying_df generator; - - return generator(df); - } - #else - { - arma_stop_logic_error("chi2rnd(): C++11 compiler required"); - - return eT(0); - } - #endif + op_chi2rnd_varying_df generator; + + return generator(df); } @@ -109,7 +91,7 @@ chi2rnd(const typename obj_type::elem_type df, const uword n_rows, const uword n arma_debug_check( (n_rows != 1), "chi2rnd(): incompatible size" ); } - obj_type out(n_rows, n_cols); + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); op_chi2rnd::fill_constant_df(out, df); diff --git a/src/armadillo_bits/fn_chol.hpp b/src/armadillo_bits/fn_chol.hpp index e9f497cf..dfd9e6e4 100644 --- a/src/armadillo_bits/fn_chol.hpp +++ b/src/armadillo_bits/fn_chol.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -31,7 +33,7 @@ chol { arma_extra_debug_sigprint(); - const char sig = (layout != NULL) ? layout[0] : char(0); + const char sig = (layout != nullptr) ? layout[0] : char(0); arma_debug_check( ((sig != 'u') && (sig != 'l')), "chol(): layout must be \"upper\" or \"lower\"" ); @@ -52,7 +54,7 @@ chol { arma_extra_debug_sigprint(); - const char sig = (layout != NULL) ? layout[0] : char(0); + const char sig = (layout != nullptr) ? layout[0] : char(0); arma_debug_check( ((sig != 'u') && (sig != 'l')), "chol(): layout must be \"upper\" or \"lower\"" ); @@ -61,7 +63,82 @@ chol if(status == false) { out.soft_reset(); - arma_debug_warn("chol(): decomposition failed"); + arma_debug_warn_level(3, "chol(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +chol + ( + Mat& out, + Mat& P, + const Base& X, + const char* layout = "upper", + const char* P_mode = "matrix" + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const char sig_layout = (layout != nullptr) ? layout[0] : char(0); + const char sig_P_mode = (P_mode != nullptr) ? P_mode[0] : char(0); + + arma_debug_check( ((sig_layout != 'u') && (sig_layout != 'l')), "chol(): argument 'layout' must be \"upper\" or \"lower\"" ); + arma_debug_check( ((sig_P_mode != 'm') && (sig_P_mode != 'v')), "chol(): argument 'P_mode' must be \"vector\" or \"matrix\"" ); + + out = X.get_ref(); + + arma_debug_check( (out.is_square() == false), "chol(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if(out.is_empty()) + { + P.reset(); + return true; + } + + if((arma_config::debug) && (auxlib::rudimentary_sym_check(out) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "chol(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "chol(): given matrix is not hermitian"); } + } + + bool status = false; + + if(sig_P_mode == 'v') + { + status = auxlib::chol_pivot(out, P, ((sig_layout == 'u') ? 0 : 1)); + } + else + if(sig_P_mode == 'm') + { + Mat P_vec; + + status = auxlib::chol_pivot(out, P_vec, ((sig_layout == 'u') ? 0 : 1)); + + if(status) + { + // construct P + + const uword N = P_vec.n_rows; + + P.zeros(N,N); + + for(uword i=0; i < N; ++i) { P.at(P_vec[i], i) = uword(1); } + } + } + + if(status == false) + { + out.soft_reset(); + P.soft_reset(); + arma_debug_warn_level(3, "chol(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_clamp.hpp b/src/armadillo_bits/fn_clamp.hpp index 2aad326c..a7da6d48 100644 --- a/src/armadillo_bits/fn_clamp.hpp +++ b/src/armadillo_bits/fn_clamp.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -32,24 +34,38 @@ clamp(const T1& X, const typename T1::elem_type min_val, const typename T1::elem { arma_extra_debug_sigprint(); - arma_debug_check( (min_val > max_val), "clamp(): min_val has to be smaller than max_val" ); - return mtOp(mtOp_dual_aux_indicator(), X, min_val, max_val); } +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && is_cx::yes, + const mtOp + >::result +clamp(const T1& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + return mtOp(mtOp_dual_aux_indicator(), X, min_val, max_val); + } + + + template arma_warn_unused inline const mtOpCube -clamp(const BaseCube& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val, typename arma_not_cx::result* junk = 0) +clamp(const BaseCube& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val, typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - arma_debug_check( (min_val > max_val), "clamp(): min_val has to be smaller than max_val" ); - return mtOpCube(mtOpCube_dual_aux_indicator(), X.get_ref(), min_val, max_val); } @@ -58,36 +74,40 @@ clamp(const BaseCube& X, const typename T1::elem_type template arma_warn_unused inline -typename -enable_if2 - < - is_cx::no, - SpMat - >::result +const mtOpCube +clamp(const BaseCube& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val, typename arma_cx_only::result* junk = nullptr) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtOpCube(mtOpCube_dual_aux_indicator(), X.get_ref(), min_val, max_val); + } + + + +template +arma_warn_unused +inline +SpMat clamp(const SpBase& X, const typename T1::elem_type min_val, const typename T1::elem_type max_val) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - arma_debug_check( (min_val > max_val), "clamp(): min_val has to be smaller than max_val" ); - - SpMat out = X.get_ref(); - - out.sync(); - - const uword N = out.n_nonzero; - - eT* out_values = access::rwp(out.values); - - for(uword i=0; i::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "clamp(): min_val must be less than max_val" ); + } + else { - eT& out_val = out_values[i]; - - out_val = (out_val < min_val) ? min_val : ( (out_val > max_val) ? max_val : out_val ); + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "clamp(): imag(min_val) must be less than imag(max_val)" ); } - if( (min_val == eT(0)) || (max_val == eT(0)) ) { out.remove_zeros(); } + SpMat out = X.get_ref(); + + out.clamp(min_val, max_val); return out; } diff --git a/src/armadillo_bits/fn_cond.hpp b/src/armadillo_bits/fn_cond_rcond.hpp similarity index 74% rename from src/armadillo_bits/fn_cond.hpp rename to src/armadillo_bits/fn_cond_rcond.hpp index f955310b..fae0a06a 100644 --- a/src/armadillo_bits/fn_cond.hpp +++ b/src/armadillo_bits/fn_cond_rcond.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,7 +29,7 @@ cond(const Base& X) { arma_extra_debug_sigprint(); - return op_cond::cond(X.get_ref()); + return op_cond::apply(X.get_ref()); } @@ -40,9 +42,22 @@ rcond(const Base& X) { arma_extra_debug_sigprint(); - return op_cond::rcond(X.get_ref()); + return op_rcond::apply(X.get_ref()); } +// template +// arma_warn_unused +// inline +// typename enable_if2::value, typename T1::pod_type>::result +// rcond(const SpBase& X) +// { +// arma_extra_debug_sigprint(); +// +// return sp_auxlib::rcond(X.get_ref()); +// } + + + //! @} diff --git a/src/armadillo_bits/fn_conv.hpp b/src/armadillo_bits/fn_conv.hpp index cb196d59..44b7a0e0 100644 --- a/src/armadillo_bits/fn_conv.hpp +++ b/src/armadillo_bits/fn_conv.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -34,7 +36,7 @@ conv(const T1& A, const T2& B, const char* shape = "full") { arma_extra_debug_sigprint(); - const char sig = (shape != NULL) ? shape[0] : char(0); + const char sig = (shape != nullptr) ? shape[0] : char(0); arma_debug_check( ((sig != 'f') && (sig != 's')), "conv(): unsupported value of 'shape' parameter" ); @@ -58,7 +60,7 @@ conv2(const T1& A, const T2& B, const char* shape = "full") { arma_extra_debug_sigprint(); - const char sig = (shape != NULL) ? shape[0] : char(0); + const char sig = (shape != nullptr) ? shape[0] : char(0); arma_debug_check( ((sig != 'f') && (sig != 's')), "conv2(): unsupported value of 'shape' parameter" ); diff --git a/src/armadillo_bits/fn_conv_to.hpp b/src/armadillo_bits/fn_conv_to.hpp index 4be7bf7c..dbfc7fe2 100644 --- a/src/armadillo_bits/fn_conv_to.hpp +++ b/src/armadillo_bits/fn_conv_to.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,16 +29,16 @@ class conv_to public: template - inline static out_eT from(const Base& in, const typename arma_not_cx::result* junk = 0); + inline static out_eT from(const Base& in, const typename arma_not_cx::result* junk = nullptr); template - inline static out_eT from(const Base& in, const typename arma_cx_only::result* junk = 0); + inline static out_eT from(const Base& in, const typename arma_cx_only::result* junk = nullptr); template - inline static out_eT from(const BaseCube& in, const typename arma_not_cx::result* junk = 0); + inline static out_eT from(const BaseCube& in, const typename arma_not_cx::result* junk = nullptr); template - inline static out_eT from(const BaseCube& in, const typename arma_cx_only::result* junk = 0); + inline static out_eT from(const BaseCube& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -55,7 +57,7 @@ conv_to::from(const Base& in, const typename arma_not_cx P(in.get_ref()); - arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object doesn't have exactly one element" ); + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); return out_eT(Proxy::use_at ? P.at(0,0) : P[0]); } @@ -76,7 +78,7 @@ conv_to::from(const Base& in, const typename arma_cx_only P(in.get_ref()); - arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object doesn't have exactly one element" ); + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); out_eT out; @@ -101,7 +103,7 @@ conv_to::from(const BaseCube& in, const typename arma_not_cx< const ProxyCube P(in.get_ref()); - arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object doesn't have exactly one element" ); + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); return out_eT(ProxyCube::use_at ? P.at(0,0,0) : P[0]); } @@ -122,7 +124,7 @@ conv_to::from(const BaseCube& in, const typename arma_cx_only const ProxyCube P(in.get_ref()); - arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object doesn't have exactly one element" ); + arma_debug_check( (P.get_n_elem() != 1), "conv_to(): given object does not have exactly one element" ); out_eT out; @@ -140,10 +142,10 @@ class conv_to< Mat > public: template - inline static Mat from(const Base& in, const typename arma_not_cx::result* junk = 0); + inline static Mat from(const Base& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Mat from(const Base& in, const typename arma_cx_only::result* junk = 0); + inline static Mat from(const Base& in, const typename arma_cx_only::result* junk = nullptr); template inline static Mat from(const SpBase& in); @@ -151,10 +153,10 @@ class conv_to< Mat > template - inline static Mat from(const std::vector& in, const typename arma_not_cx::result* junk = 0); + inline static Mat from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Mat from(const std::vector& in, const typename arma_cx_only::result* junk = 0); + inline static Mat from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -172,7 +174,7 @@ conv_to< Mat >::from(const Base& in, const typename arma_not_ const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - Mat out(X.n_rows, X.n_cols); + Mat out(X.n_rows, X.n_cols, arma_nozeros_indicator()); arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); @@ -194,7 +196,7 @@ conv_to< Mat >::from(const Base& in, const typename arma_cx_o const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - Mat out(X.n_rows, X.n_cols); + Mat out(X.n_rows, X.n_cols, arma_nozeros_indicator()); arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); @@ -229,7 +231,7 @@ conv_to< Mat >::from(const std::vector& in, const typename arma_n const uword N = uword( in.size() ); - Mat out(N, 1); + Mat out(N, 1, arma_nozeros_indicator()); if(N > 0) { @@ -253,7 +255,7 @@ conv_to< Mat >::from(const std::vector& in, const typename arma_c const uword N = uword( in.size() ); - Mat out(N, 1); + Mat out(N, 1, arma_nozeros_indicator()); if(N > 0) { @@ -272,18 +274,18 @@ class conv_to< Row > public: template - inline static Row from(const Base& in, const typename arma_not_cx::result* junk = 0); + inline static Row from(const Base& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Row from(const Base& in, const typename arma_cx_only::result* junk = 0); + inline static Row from(const Base& in, const typename arma_cx_only::result* junk = nullptr); template - inline static Row from(const std::vector& in, const typename arma_not_cx::result* junk = 0); + inline static Row from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Row from(const std::vector& in, const typename arma_cx_only::result* junk = 0); + inline static Row from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -301,9 +303,9 @@ conv_to< Row >::from(const Base& in, const typename arma_not_ const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); - Row out(X.n_elem); + Row out(X.n_elem, arma_nozeros_indicator()); arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); @@ -325,9 +327,9 @@ conv_to< Row >::from(const Base& in, const typename arma_cx_o const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); - Row out(X.n_rows, X.n_cols); + Row out(X.n_rows, X.n_cols, arma_nozeros_indicator()); arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); @@ -348,7 +350,7 @@ conv_to< Row >::from(const std::vector& in, const typename arma_n const uword N = uword( in.size() ); - Row out(N); + Row out(N, arma_nozeros_indicator()); if(N > 0) { @@ -372,7 +374,7 @@ conv_to< Row >::from(const std::vector& in, const typename arma_c const uword N = uword( in.size() ); - Row out(N); + Row out(N, arma_nozeros_indicator()); if(N > 0) { @@ -391,18 +393,18 @@ class conv_to< Col > public: template - inline static Col from(const Base& in, const typename arma_not_cx::result* junk = 0); + inline static Col from(const Base& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Col from(const Base& in, const typename arma_cx_only::result* junk = 0); + inline static Col from(const Base& in, const typename arma_cx_only::result* junk = nullptr); template - inline static Col from(const std::vector& in, const typename arma_not_cx::result* junk = 0); + inline static Col from(const std::vector& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Col from(const std::vector& in, const typename arma_cx_only::result* junk = 0); + inline static Col from(const std::vector& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -420,9 +422,9 @@ conv_to< Col >::from(const Base& in, const typename arma_not_ const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); - Col out(X.n_elem); + Col out(X.n_elem, arma_nozeros_indicator()); arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); @@ -444,9 +446,9 @@ conv_to< Col >::from(const Base& in, const typename arma_cx_o const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); - Col out(X.n_rows, X.n_cols); + Col out(X.n_rows, X.n_cols, arma_nozeros_indicator()); arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); @@ -467,7 +469,7 @@ conv_to< Col >::from(const std::vector& in, const typename arma_n const uword N = uword( in.size() ); - Col out(N); + Col out(N, arma_nozeros_indicator()); if(N > 0) { @@ -491,7 +493,7 @@ conv_to< Col >::from(const std::vector& in, const typename arma_c const uword N = uword( in.size() ); - Col out(N); + Col out(N, arma_nozeros_indicator()); if(N > 0) { @@ -510,10 +512,10 @@ class conv_to< SpMat > public: template - inline static SpMat from(const SpBase& in, const typename arma_not_cx::result* junk = 0); + inline static SpMat from(const SpBase& in, const typename arma_not_cx::result* junk = nullptr); template - inline static SpMat from(const SpBase& in, const typename arma_cx_only::result* junk = 0); + inline static SpMat from(const SpBase& in, const typename arma_cx_only::result* junk = nullptr); template inline static SpMat from(const Base& in); @@ -590,10 +592,10 @@ class conv_to< Cube > public: template - inline static Cube from(const BaseCube& in, const typename arma_not_cx::result* junk = 0); + inline static Cube from(const BaseCube& in, const typename arma_not_cx::result* junk = nullptr); template - inline static Cube from(const BaseCube& in, const typename arma_cx_only::result* junk = 0); + inline static Cube from(const BaseCube& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -611,7 +613,7 @@ conv_to< Cube >::from(const BaseCube& in, const typename arma const unwrap_cube tmp( in.get_ref() ); const Cube& X = tmp.M; - Cube out(X.n_rows, X.n_cols, X.n_slices); + Cube out(X.n_rows, X.n_cols, X.n_slices, arma_nozeros_indicator()); arrayops::convert( out.memptr(), X.memptr(), X.n_elem ); @@ -633,7 +635,7 @@ conv_to< Cube >::from(const BaseCube& in, const typename arma const unwrap_cube tmp( in.get_ref() ); const Cube& X = tmp.M; - Cube out(X.n_rows, X.n_cols, X.n_slices); + Cube out(X.n_rows, X.n_cols, X.n_slices, arma_nozeros_indicator()); arrayops::convert_cx( out.memptr(), X.memptr(), X.n_elem ); @@ -649,10 +651,10 @@ class conv_to< std::vector > public: template - inline static std::vector from(const Base& in, const typename arma_not_cx::result* junk = 0); + inline static std::vector from(const Base& in, const typename arma_not_cx::result* junk = nullptr); template - inline static std::vector from(const Base& in, const typename arma_cx_only::result* junk = 0); + inline static std::vector from(const Base& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -670,7 +672,7 @@ conv_to< std::vector >::from(const Base& in, const typename a const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); const uword N = X.n_elem; @@ -699,7 +701,7 @@ conv_to< std::vector >::from(const Base& in, const typename a const quasi_unwrap tmp(in.get_ref()); const Mat& X = tmp.M; - arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object can't be interpreted as a vector" ); + arma_debug_check( ( (X.is_vec() == false) && (X.is_empty() == false) ), "conv_to(): given object cannot be interpreted as a vector" ); const uword N = X.n_elem; diff --git a/src/armadillo_bits/fn_cor.hpp b/src/armadillo_bits/fn_cor.hpp index 37007504..18cd2fa7 100644 --- a/src/armadillo_bits/fn_cor.hpp +++ b/src/armadillo_bits/fn_cor.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_cov.hpp b/src/armadillo_bits/fn_cov.hpp index 873daeaf..ee61c005 100644 --- a/src/armadillo_bits/fn_cov.hpp +++ b/src/armadillo_bits/fn_cov.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_cross.hpp b/src/armadillo_bits/fn_cross.hpp index 5b6c1981..bb08e179 100644 --- a/src/armadillo_bits/fn_cross.hpp +++ b/src/armadillo_bits/fn_cross.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,12 +25,17 @@ template arma_warn_unused inline -const Glue -cross(const Base& X, const Base& Y) +typename +enable_if2 + < + is_arma_type::value && is_arma_type::value && is_same_type::value, + const Glue + >::result +cross(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); - return Glue(X.get_ref(), Y.get_ref()); + return Glue(X, Y); } diff --git a/src/armadillo_bits/fn_cumprod.hpp b/src/armadillo_bits/fn_cumprod.hpp index ef417115..f6cd1e17 100644 --- a/src/armadillo_bits/fn_cumprod.hpp +++ b/src/armadillo_bits/fn_cumprod.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_cumsum.hpp b/src/armadillo_bits/fn_cumsum.hpp index d77f74aa..ad6c6375 100644 --- a/src/armadillo_bits/fn_cumsum.hpp +++ b/src/armadillo_bits/fn_cumsum.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_det.hpp b/src/armadillo_bits/fn_det.hpp index 972c1e49..3941a85a 100644 --- a/src/armadillo_bits/fn_det.hpp +++ b/src/armadillo_bits/fn_det.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,156 +25,45 @@ template arma_warn_unused inline typename enable_if2< is_supported_blas_type::value, typename T1::elem_type >::result -det - ( - const Base& X - ) - { - arma_extra_debug_sigprint(); - - return auxlib::det(X.get_ref()); - } - - - -template -arma_warn_unused -inline -typename T1::elem_type -det - ( - const Op& X - ) +det(const Base& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const diagmat_proxy A(X.m); - - arma_debug_check( (A.n_rows != A.n_cols), "det(): given matrix must be square sized" ); + eT out_val = eT(0); - const uword N = (std::min)(A.n_rows, A.n_cols); + const bool status = op_det::apply_direct(out_val, X.get_ref()); - eT val1 = eT(1); - eT val2 = eT(1); - - uword i,j; - for(i=0, j=1; j -arma_warn_unused inline -typename T1::elem_type -det - ( - const Op& X - ) +typename enable_if2< is_supported_blas_type::value, bool >::result +det(typename T1::elem_type& out_val, const Base& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(X.m); - - const uword N = P.get_n_rows(); - - arma_debug_check( (N != P.get_n_cols()), "det(): given matrix must be square sized" ); - - eT val1 = eT(1); - eT val2 = eT(1); - - uword i,j; - for(i=0, j=1; j -arma_warn_unused -inline -typename enable_if2< is_supported_blas_type::value, typename T1::elem_type >::result -det - ( - const Op& X - ) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const eT tmp = det(X.m); - - if(tmp == eT(0)) { arma_debug_warn("det(): denominator is zero" ); } - - return eT(1) / tmp; - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, typename T1::elem_type >::result -det - ( - const Base& X, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("det(X,bool) is deprecated and will be removed; change to det(X)"); - - return det(X.get_ref()); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, typename T1::elem_type >::result -det - ( - const Base& X, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("det(X,char*) is deprecated and will be removed; change to det(X)"); - - return det(X.get_ref()); + return status; } diff --git a/src/armadillo_bits/fn_diagmat.hpp b/src/armadillo_bits/fn_diagmat.hpp index 88508119..7d6c5b0d 100644 --- a/src/armadillo_bits/fn_diagmat.hpp +++ b/src/armadillo_bits/fn_diagmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,7 +20,7 @@ //! @{ -//! interpret a matrix or a vector as a diagonal matrix (i.e. off-diagonal entries are zero) +//! interpret a matrix or a vector as a diagonal matrix (ie. off-diagonal entries are zero) template arma_warn_unused arma_inline diff --git a/src/armadillo_bits/fn_diags_spdiags.hpp b/src/armadillo_bits/fn_diags_spdiags.hpp new file mode 100644 index 00000000..ceb8fd0f --- /dev/null +++ b/src/armadillo_bits/fn_diags_spdiags.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_diags_spdiags +//! @{ + + + +template +inline +Mat +diags(const Base& V_expr, const Base& D_expr, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UV(V_expr.get_ref()); + const Mat& V = UV.M; + + const quasi_unwrap UD(D_expr.get_ref()); + const Mat& D = UD.M; + + arma_debug_check( ((D.is_vec() == false) && (D.is_empty() == false)), "D must be a vector" ); + + arma_debug_check( (V.n_cols != D.n_elem), "number of colums in matrix V must match the length of vector D" ); + + Mat out(n_rows, n_cols, fill::zeros); + + for(uword i=0; i < D.n_elem; ++i) + { + const sword diag_id = D[i]; + + const uword row_offset = (diag_id < 0) ? uword(-diag_id) : 0; + const uword col_offset = (diag_id > 0) ? uword( diag_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diags(): requested diagonal out of bounds" + ); + + const uword diag_len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword V_start = (diag_id < 0) ? uword(0) : uword(diag_id); + + const eT* V_colmem = V.colptr(i); + + for(uword j=0; j < diag_len; ++j) + { + const uword V_index = V_start + j; + + if(V_index >= V.n_rows) { break; } + + out.at(j + row_offset, j + col_offset) = V_colmem[V_index]; + } + } + + return out; + } + + + +template +inline +SpMat +spdiags(const Base& V_expr, const Base& D_expr, const uword n_rows, const uword n_cols) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UV(V_expr.get_ref()); + const Mat& V = UV.M; + + const quasi_unwrap UD(D_expr.get_ref()); + const Mat& D = UD.M; + + arma_debug_check( ((D.is_vec() == false) && (D.is_empty() == false)), "D must be a vector" ); + + arma_debug_check( (V.n_cols != D.n_elem), "number of colums in matrix V must match the length of vector D" ); + + MapMat tmp(n_rows, n_cols); + + for(uword i=0; i < D.n_elem; ++i) + { + const sword diag_id = D[i]; + + const uword row_offset = (diag_id < 0) ? uword(-diag_id) : 0; + const uword col_offset = (diag_id > 0) ? uword( diag_id) : 0; + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diags(): requested diagonal out of bounds" + ); + + const uword diag_len = (std::min)(n_rows - row_offset, n_cols - col_offset); + + const uword V_start = (diag_id < 0) ? uword(0) : uword(diag_id); + + const eT* V_colmem = V.colptr(i); + + for(uword j=0; j < diag_len; ++j) + { + const uword V_index = V_start + j; + + if(V_index >= V.n_rows) { break; } + + tmp.at(j + row_offset, j + col_offset) = V_colmem[V_index]; + } + } + + return SpMat(tmp); + } + + + +//! @} diff --git a/src/armadillo_bits/fn_diagvec.hpp b/src/armadillo_bits/fn_diagvec.hpp index 664c9acb..873c3506 100644 --- a/src/armadillo_bits/fn_diagvec.hpp +++ b/src/armadillo_bits/fn_diagvec.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,16 +20,30 @@ //! @{ -//! extract a diagonal from a matrix +//! extract main diagonal from matrix template arma_warn_unused arma_inline const Op -diagvec(const Base& X, const sword diag_id = 0) +diagvec(const Base& X) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +//! extract arbitrary diagonal from matrix +template +arma_warn_unused +arma_inline +const Op +diagvec(const Base& X, const sword diag_id) { arma_extra_debug_sigprint(); - return Op(X.get_ref(), ((diag_id < 0) ? -diag_id : diag_id), ((diag_id < 0) ? 1 : 0) ); + return Op(X.get_ref(), ((diag_id < 0) ? -diag_id : diag_id), ((diag_id < 0) ? 1 : 0) ); } diff --git a/src/armadillo_bits/fn_diff.hpp b/src/armadillo_bits/fn_diff.hpp index 248f06e8..2d7e8d3a 100644 --- a/src/armadillo_bits/fn_diff.hpp +++ b/src/armadillo_bits/fn_diff.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_dot.hpp b/src/armadillo_bits/fn_dot.hpp index dec104e5..d2cbfc81 100644 --- a/src/armadillo_bits/fn_dot.hpp +++ b/src/armadillo_bits/fn_dot.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_eig_gen.hpp b/src/armadillo_bits/fn_eig_gen.hpp index ff1c5f07..ad228217 100644 --- a/src/armadillo_bits/fn_eig_gen.hpp +++ b/src/armadillo_bits/fn_eig_gen.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -33,11 +35,11 @@ eig_gen typedef typename T1::pod_type T; typedef typename std::complex eT; - const char sig = (option != NULL) ? option[0] : char(0); + const char sig = (option != nullptr) ? option[0] : char(0); arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); - if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn( "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } Col eigvals; Mat eigvecs; @@ -70,11 +72,11 @@ eig_gen typedef typename T1::pod_type T; typedef typename std::complex eT; - const char sig = (option != NULL) ? option[0] : char(0); + const char sig = (option != nullptr) ? option[0] : char(0); arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); - if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn( "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } Mat eigvecs; @@ -83,7 +85,7 @@ eig_gen if(status == false) { eigvals.soft_reset(); - arma_debug_warn("eig_gen(): decomposition failed"); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); } return status; @@ -106,11 +108,11 @@ eig_gen arma_debug_check( (void_ptr(&eigvals) == void_ptr(&eigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); - const char sig = (option != NULL) ? option[0] : char(0); + const char sig = (option != nullptr) ? option[0] : char(0); arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); - if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn( "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } const bool status = (sig == 'b') ? auxlib::eig_gen_balance(eigvals, eigvecs, true, expr.get_ref()) : auxlib::eig_gen(eigvals, eigvecs, true, expr.get_ref()); @@ -118,7 +120,46 @@ eig_gen { eigvals.soft_reset(); eigvecs.soft_reset(); - arma_debug_warn("eig_gen(): decomposition failed"); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_gen + ( + Col< std::complex >& eigvals, + Mat< std::complex >& leigvecs, + Mat< std::complex >& reigvecs, + const Base& expr, + const char* option = "nobalance" + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&leigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'leigvec'" ); + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&reigvecs)), "eig_gen(): parameter 'eigval' is an alias of parameter 'reigvec'" ); + arma_debug_check( (void_ptr(&leigvecs) == void_ptr(&reigvecs)), "eig_gen(): parameter 'leigvec' is an alias of parameter 'reigvec'" ); + + const char sig = (option != nullptr) ? option[0] : char(0); + + arma_debug_check( ((sig != 'n') && (sig != 'b')), "eig_gen(): unknown option" ); + + if( auxlib::crippled_lapack(expr) && (sig == 'b') ) { arma_debug_warn_level(1, "eig_gen(): 'balance' option ignored due to linking with crippled lapack"); } + + const bool status = (sig == 'b') ? auxlib::eig_gen_twosided_balance(eigvals, leigvecs, reigvecs, expr.get_ref()) : auxlib::eig_gen_twosided(eigvals, leigvecs, reigvecs, expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + leigvecs.soft_reset(); + reigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_gen(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_eig_pair.hpp b/src/armadillo_bits/fn_eig_pair.hpp index e1afc1ee..cef13894 100644 --- a/src/armadillo_bits/fn_eig_pair.hpp +++ b/src/armadillo_bits/fn_eig_pair.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -69,7 +71,7 @@ eig_pair if(status == false) { eigvals.soft_reset(); - arma_debug_warn("eig_pair(): decomposition failed"); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); } return status; @@ -98,11 +100,45 @@ eig_pair { eigvals.soft_reset(); eigvecs.soft_reset(); - arma_debug_warn("eig_pair(): decomposition failed"); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +eig_pair + ( + Col< std::complex >& eigvals, + Mat< std::complex >& leigvecs, + Mat< std::complex >& reigvecs, + const Base< typename T1::elem_type, T1 >& A_expr, + const Base< typename T1::elem_type, T2 >& B_expr + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&leigvecs)), "eig_pair(): parameter 'eigval' is an alias of parameter 'leigvec'" ); + arma_debug_check( (void_ptr(&eigvals) == void_ptr(&reigvecs)), "eig_pair(): parameter 'eigval' is an alias of parameter 'reigvec'" ); + arma_debug_check( (void_ptr(&leigvecs) == void_ptr(&reigvecs)), "eig_pair(): parameter 'leigvec' is an alias of parameter 'reigvec'" ); + + const bool status = auxlib::eig_pair_twosided(eigvals, leigvecs, reigvecs, A_expr.get_ref(), B_expr.get_ref()); + + if(status == false) + { + eigvals.soft_reset(); + leigvecs.soft_reset(); + reigvecs.soft_reset(); + arma_debug_warn_level(3, "eig_pair(): decomposition failed"); } return status; } + //! @} diff --git a/src/armadillo_bits/fn_eig_sym.hpp b/src/armadillo_bits/fn_eig_sym.hpp index 47b85881..12043b6f 100644 --- a/src/armadillo_bits/fn_eig_sym.hpp +++ b/src/armadillo_bits/fn_eig_sym.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,15 +32,16 @@ eig_sym { arma_extra_debug_sigprint(); - // unwrap_check not used as T1::elem_type and T1::pod_type may not be the same. - // furthermore, it doesn't matter if X is an alias of eigval, as auxlib::eig_sym() makes a copy of X + typedef typename T1::elem_type eT; + + Mat A(X.get_ref()); - const bool status = auxlib::eig_sym(eigval, X); + const bool status = auxlib::eig_sym(eigval, A); if(status == false) { eigval.soft_reset(); - arma_debug_warn("eig_sym(): decomposition failed"); + arma_debug_warn_level(3, "eig_sym(): decomposition failed"); } return status; @@ -58,16 +61,21 @@ eig_sym { arma_extra_debug_sigprint(); - Col out; - const bool status = auxlib::eig_sym(out, X); + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Col< T> eigval; + Mat A(X.get_ref()); + + const bool status = auxlib::eig_sym(eigval, A); if(status == false) { - out.soft_reset(); + eigval.reset(); arma_stop_runtime_error("eig_sym(): decomposition failed"); } - return out; + return eigval; } @@ -87,17 +95,10 @@ eig_sym_helper { arma_extra_debug_sigprint(); - // if(auxlib::rudimentary_sym_check(X) == false) - // { - // if(is_cx::no ) { arma_debug_warn(caller_sig, ": given matrix is not symmetric"); } - // if(is_cx::yes) { arma_debug_warn(caller_sig, ": given matrix is not hermitian"); } - // return false; - // } - if((arma_config::debug) && (auxlib::rudimentary_sym_check(X) == false)) { - if(is_cx::no ) { arma_debug_warn(caller_sig, ": given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn(caller_sig, ": given matrix is not hermitian"); } + if(is_cx::no ) { arma_debug_warn_level(1, caller_sig, ": given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, caller_sig, ": given matrix is not hermitian"); } } bool status = false; @@ -127,7 +128,7 @@ eig_sym typedef typename T1::elem_type eT; - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'd')), "eig_sym(): unknown method specified" ); arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eig_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); @@ -145,7 +146,7 @@ eig_sym { eigval.soft_reset(); eigvec.soft_reset(); - arma_debug_warn("eig_sym(): decomposition failed"); + arma_debug_warn_level(3, "eig_sym(): decomposition failed"); } else { diff --git a/src/armadillo_bits/fn_eigs_gen.hpp b/src/armadillo_bits/fn_eigs_gen.hpp index 945da314..6f1a617c 100644 --- a/src/armadillo_bits/fn_eigs_gen.hpp +++ b/src/armadillo_bits/fn_eigs_gen.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,25 +24,123 @@ template arma_warn_unused inline -Col< std::complex > +typename enable_if2< is_real::value, Col< std::complex > >::result eigs_gen ( const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::pod_type tol = 0.0, - const typename arma_blas_type_only::result* junk = 0 + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + Col< std::complex > eigval; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_gen(): decomposition failed"); + } + + return eigval; + } + + + +//! this form is deprecated; use eigs_gen(X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(X, n_eigvals, form, opts); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + Col< std::complex > eigval; + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_gen(): decomposition failed"); + } + + return eigval; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col< std::complex > >::result +eigs_gen + ( + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); typedef typename T1::pod_type T; Mat< std::complex > eigvec; Col< std::complex > eigval; - const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form, tol); + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); if(status == false) { @@ -56,30 +156,127 @@ eigs_gen //! eigenvalues of general sparse matrix X template inline -bool +typename enable_if2< is_real::value, bool >::result eigs_gen ( Col< std::complex >& eigval, const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::pod_type tol = 0.0, - const typename arma_blas_type_only::result* junk = 0 + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_gen(eigval, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(eigval, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat< std::complex > eigvec; + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); typedef typename T1::pod_type T; Mat< std::complex > eigvec; - const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form, tol); + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); if(status == false) { eigval.soft_reset(); - arma_debug_warn("eigs_gen(): decomposition failed"); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); } return status; @@ -87,10 +284,10 @@ eigs_gen -//! eigenvalues and eigenvectors of general real sparse matrix X +//! eigenvalues and eigenvectors of general sparse matrix X template inline -bool +typename enable_if2< is_real::value, bool >::result eigs_gen ( Col< std::complex >& eigval, @@ -98,22 +295,126 @@ eigs_gen const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::pod_type tol = 0.0, - const typename arma_blas_type_only::result* junk = 0 + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + // typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_gen(eigval, eigvec, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::pod_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_gen(eigval, eigvec, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const std::complex sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + bool status = false; + + // If X is real and sigma is truly complex, treat X as complex. + // The reason is that we are still not able to apply truly complex shifts to real matrices + if( (is_real::yes) && (std::imag(sigma) != T(0)) ) + { + status = sp_auxlib::eigs_gen(eigval, eigvec, conv_to< SpMat< std::complex > >::from(X), n_eigvals, sigma, opts); + } + else + { + status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, sigma, opts); + } + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_gen + ( + Col< std::complex >& eigval, + Mat< std::complex >& eigvec, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); + + typedef typename T1::pod_type T; arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_gen(): parameter 'eigval' is an alias of parameter 'eigvec'" ); - const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, form, tol); + const bool status = sp_auxlib::eigs_gen(eigval, eigvec, X, n_eigvals, std::complex(T(sigma)), opts); if(status == false) { eigval.soft_reset(); eigvec.soft_reset(); - arma_debug_warn("eigs_gen(): decomposition failed"); + arma_debug_warn_level(3, "eigs_gen(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_eigs_sym.hpp b/src/armadillo_bits/fn_eigs_sym.hpp index 5c73c3a1..935aab92 100644 --- a/src/armadillo_bits/fn_eigs_sym.hpp +++ b/src/armadillo_bits/fn_eigs_sym.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,23 +24,78 @@ template arma_warn_unused inline -Col +typename enable_if2< is_real::value, Col >::result eigs_sym ( const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::elem_type tol = 0.0, - const typename arma_real_only::result* junk = 0 + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + Mat eigvec; + Col eigval; + + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + arma_stop_runtime_error("eigs_sym(): decomposition failed"); + } + + return eigval; + } + + + +//! this form is deprecated; use eigs_sym(X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, Col >::result +eigs_sym + ( + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(X, n_eigvals, form, opts); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_real::value, Col >::result +eigs_sym + ( + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); + + typedef typename T1::pod_type T; Mat eigvec; Col eigval; - const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form, tol); + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); if(status == false) { @@ -54,28 +111,83 @@ eigs_sym //! eigenvalues of symmetric real sparse matrix X template inline -bool +typename enable_if2< is_real::value, bool >::result eigs_sym ( Col& eigval, const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::elem_type tol = 0.0, - const typename arma_real_only::result* junk = 0 + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); Mat eigvec; - const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form, tol); + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); if(status == false) { eigval.soft_reset(); - arma_debug_warn("eigs_sym(): decomposition failed"); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_sym(eigval, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(eigval, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + Mat eigvec; + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); + + if(status == false) + { + eigval.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); } return status; @@ -86,7 +198,7 @@ eigs_sym //! eigenvalues and eigenvectors of symmetric real sparse matrix X template inline -bool +typename enable_if2< is_real::value, bool >::result eigs_sym ( Col& eigval, @@ -94,22 +206,80 @@ eigs_sym const SpBase& X, const uword n_eigvals, const char* form = "lm", - const typename T1::elem_type tol = 0.0, - const typename arma_real_only::result* junk = 0 + const eigs_opts opts = eigs_opts() ) { arma_extra_debug_sigprint(); - arma_ignore(junk); arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); - const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form, tol); + sp_auxlib::form_type form_val = sp_auxlib::interpret_form_str(form); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, form_val, opts); + + if(status == false) + { + eigval.soft_reset(); + eigvec.soft_reset(); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); + } + + return status; + } + + + +//! this form is deprecated; use eigs_sym(eigval, eigvec, X, n_eigvals, form, opts) instead +template +arma_deprecated +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + Mat& eigvec, + const SpBase& X, + const uword n_eigvals, + const char* form, + const typename T1::elem_type tol + ) + { + arma_extra_debug_sigprint(); + + eigs_opts opts; + opts.tol = tol; + + return eigs_sym(eigval, eigvec, X, n_eigvals, form, opts); + } + + + +template +inline +typename enable_if2< is_real::value, bool >::result +eigs_sym + ( + Col& eigval, + Mat& eigvec, + const SpBase& X, + const uword n_eigvals, + const double sigma, + const eigs_opts opts = eigs_opts() + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + arma_debug_check( void_ptr(&eigval) == void_ptr(&eigvec), "eigs_sym(): parameter 'eigval' is an alias of parameter 'eigvec'" ); + + const bool status = sp_auxlib::eigs_sym(eigval, eigvec, X, n_eigvals, T(sigma), opts); if(status == false) { eigval.soft_reset(); eigvec.soft_reset(); - arma_debug_warn("eigs_sym(): decomposition failed"); + arma_debug_warn_level(3, "eigs_sym(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_elem.hpp b/src/armadillo_bits/fn_elem.hpp index abbd07c3..917537f4 100644 --- a/src/armadillo_bits/fn_elem.hpp +++ b/src/armadillo_bits/fn_elem.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -436,7 +438,7 @@ template arma_warn_unused arma_inline const eOpCube -abs(const BaseCube& X, const typename arma_not_cx::result* junk = 0) +abs(const BaseCube& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); @@ -464,7 +466,7 @@ template arma_warn_unused inline const mtOpCube -abs(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = 0) +abs(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = nullptr) { arma_extra_debug_sigprint(); @@ -479,7 +481,7 @@ template arma_warn_unused arma_inline const SpOp -abs(const SpBase& X, const typename arma_not_cx::result* junk = 0) +abs(const SpBase& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -493,7 +495,7 @@ template arma_warn_unused arma_inline const mtSpOp -abs(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = 0) +abs(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -524,7 +526,7 @@ template arma_warn_unused arma_inline const eOpCube -arg(const BaseCube& X, const typename arma_not_cx::result* junk = 0) +arg(const BaseCube& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); @@ -552,7 +554,7 @@ template arma_warn_unused inline const mtOpCube -arg(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = 0) +arg(const BaseCube< std::complex,T1>& X, const typename arma_cx_only::result* junk = nullptr) { arma_extra_debug_sigprint(); @@ -567,7 +569,7 @@ template arma_warn_unused arma_inline const SpOp -arg(const SpBase& X, const typename arma_not_cx::result* junk = 0) +arg(const SpBase& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -581,7 +583,7 @@ template arma_warn_unused arma_inline const mtSpOp -arg(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = 0) +arg(const SpBase< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -1126,6 +1128,35 @@ lgamma(const BaseCube& A) +// +// tgamma + +template +arma_warn_unused +arma_inline +typename enable_if2< (is_arma_type::value && is_cx::no), const eOp >::result +tgamma(const T1& A) + { + arma_extra_debug_sigprint(); + + return eOp(A); + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_cx::no, const eOpCube >::result +tgamma(const BaseCube& A) + { + arma_extra_debug_sigprint(); + + return eOpCube(A.get_ref()); + } + + + // the functions below are currently unused; reserved for potential future use template void exp_approx(const T1&) { arma_stop_logic_error("unimplemented"); } diff --git a/src/armadillo_bits/fn_eps.hpp b/src/armadillo_bits/fn_eps.hpp index f8c95b9b..f68ba5d4 100644 --- a/src/armadillo_bits/fn_eps.hpp +++ b/src/armadillo_bits/fn_eps.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -24,7 +26,7 @@ template arma_warn_unused inline const eOp -eps(const Base& X, const typename arma_not_cx::result* junk = 0) +eps(const Base& X, const typename arma_not_cx::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -38,7 +40,7 @@ template arma_warn_unused inline Mat< typename T1::pod_type > -eps(const Base< std::complex, T1>& X, const typename arma_cx_only::result* junk = 0) +eps(const Base< std::complex, T1>& X, const typename arma_cx_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -49,7 +51,7 @@ eps(const Base< std::complex, T1>& X, const typename arma const unwrap tmp(X.get_ref()); const Mat& A = tmp.M; - Mat out(A.n_rows, A.n_cols); + Mat out(A.n_rows, A.n_cols, arma_nozeros_indicator()); T* out_mem = out.memptr(); const eT* A_mem = A.memptr(); diff --git a/src/armadillo_bits/fn_expmat.hpp b/src/armadillo_bits/fn_expmat.hpp index 055d310f..5e5909b8 100644 --- a/src/armadillo_bits/fn_expmat.hpp +++ b/src/armadillo_bits/fn_expmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -52,12 +54,11 @@ expmat(Mat& B, const Base& A) if(status == false) { - arma_debug_warn("expmat(): given matrix appears ill-conditioned"); B.soft_reset(); - return false; + arma_debug_warn_level(3, "expmat(): given matrix appears ill-conditioned"); } - return true; + return status; } @@ -91,7 +92,7 @@ expmat_sym(Mat& Y, const Base if(status == false) { Y.soft_reset(); - arma_debug_warn("expmat_sym(): transformation failed"); + arma_debug_warn_level(3, "expmat_sym(): transformation failed"); } return status; diff --git a/src/armadillo_bits/fn_eye.hpp b/src/armadillo_bits/fn_eye.hpp index 0e007011..4252ffaa 100644 --- a/src/armadillo_bits/fn_eye.hpp +++ b/src/armadillo_bits/fn_eye.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -47,20 +49,13 @@ template arma_warn_unused arma_inline const Gen -eye(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = 0) +eye(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_Col::value) - { - arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); - } - else - if(is_Row::value) - { - arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); - } + if(is_Col::value) { arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); } return Gen(n_rows, n_cols); } @@ -71,7 +66,7 @@ template arma_warn_unused arma_inline const Gen -eye(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = 0) +eye(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -85,20 +80,13 @@ template arma_warn_unused inline obj_type -eye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = NULL) +eye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_SpCol::value) - { - arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); - } - else - if(is_SpRow::value) - { - arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); - } + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "eye(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "eye(): incompatible size" ); } obj_type out; @@ -113,7 +101,7 @@ template arma_warn_unused inline obj_type -eye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = NULL) +eye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_fft.hpp b/src/armadillo_bits/fn_fft.hpp index cabf1d58..d2d11fbf 100644 --- a/src/armadillo_bits/fn_fft.hpp +++ b/src/armadillo_bits/fn_fft.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_fft2.hpp b/src/armadillo_bits/fn_fft2.hpp index 501cd2a7..51ea0daf 100644 --- a/src/armadillo_bits/fn_fft2.hpp +++ b/src/armadillo_bits/fn_fft2.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_find.hpp b/src/armadillo_bits/fn_find.hpp index bd565513..5efb254b 100644 --- a/src/armadillo_bits/fn_find.hpp +++ b/src/armadillo_bits/fn_find.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -45,7 +47,7 @@ find(const Base& X, const uword k, const char* direct { arma_extra_debug_sigprint(); - const char sig = (direction != NULL) ? direction[0] : char(0); + const char sig = (direction != nullptr) ? direction[0] : char(0); arma_debug_check ( @@ -162,7 +164,7 @@ find(const SpBase& X, const uword k = 0) const uword n_rows = P.get_n_rows(); const uword n_nz = P.get_n_nonzero(); - Mat tmp(n_nz,1); + Mat tmp(n_nz, 1, arma_nozeros_indicator()); uword* tmp_mem = tmp.memptr(); @@ -196,6 +198,10 @@ find(const SpBase& X, const uword k, const char* dire { arma_extra_debug_sigprint(); + arma_ignore(X); + arma_ignore(k); + arma_ignore(direction); + arma_check(true, "find(SpBase,k,direction): not implemented yet"); // TODO Col out; @@ -245,6 +251,24 @@ find_nonfinite(const T1& X) +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value, + const mtOp + >::result +find_nan(const T1& X) + { + arma_extra_debug_sigprint(); + + return mtOp(X); + } + + + // @@ -287,6 +311,25 @@ find_nonfinite(const BaseCube& X) +template +arma_warn_unused +inline +uvec +find_nan(const BaseCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube tmp(X.get_ref()); + + const Mat R( const_cast< eT* >(tmp.M.memptr()), tmp.M.n_elem, 1, false ); + + return find_nan(R); + } + + + // @@ -304,7 +347,7 @@ find_finite(const SpBase& X) const uword n_rows = P.get_n_rows(); const uword n_nz = P.get_n_nonzero(); - Mat tmp(n_nz,1); + Mat tmp(n_nz, 1, arma_nozeros_indicator()); uword* tmp_mem = tmp.memptr(); @@ -348,7 +391,7 @@ find_nonfinite(const SpBase& X) const uword n_rows = P.get_n_rows(); const uword n_nz = P.get_n_nonzero(); - Mat tmp(n_nz,1); + Mat tmp(n_nz, 1, arma_nozeros_indicator()); uword* tmp_mem = tmp.memptr(); @@ -379,4 +422,48 @@ find_nonfinite(const SpBase& X) +template +arma_warn_unused +inline +Col +find_nan(const SpBase& X) + { + arma_extra_debug_sigprint(); + + const SpProxy P(X.get_ref()); + + const uword n_rows = P.get_n_rows(); + const uword n_nz = P.get_n_nonzero(); + + Mat tmp(n_nz, 1, arma_nozeros_indicator()); + + uword* tmp_mem = tmp.memptr(); + + typename SpProxy::const_iterator_type it = P.begin(); + + uword count = 0; + + for(uword i=0; i out; + + if(count > 0) { out.steal_mem_col(tmp, count); } + + return out; + } + + + //! @} diff --git a/src/armadillo_bits/fn_find_unique.hpp b/src/armadillo_bits/fn_find_unique.hpp index 51302dc2..4d90ca15 100644 --- a/src/armadillo_bits/fn_find_unique.hpp +++ b/src/armadillo_bits/fn_find_unique.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_flip.hpp b/src/armadillo_bits/fn_flip.hpp index 689bdb35..811a5e02 100644 --- a/src/armadillo_bits/fn_flip.hpp +++ b/src/armadillo_bits/fn_flip.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_hess.hpp b/src/armadillo_bits/fn_hess.hpp index d57a0535..adf31d69 100644 --- a/src/armadillo_bits/fn_hess.hpp +++ b/src/armadillo_bits/fn_hess.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,7 +27,7 @@ hess ( Mat& H, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -48,7 +50,7 @@ hess if(status == false) { H.soft_reset(); - arma_debug_warn("hess(): decomposition failed"); + arma_debug_warn_level(3, "hess(): decomposition failed"); } return status; @@ -63,7 +65,7 @@ Mat hess ( const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -103,7 +105,7 @@ hess Mat& U, Mat& H, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -161,7 +163,7 @@ hess { U.soft_reset(); H.soft_reset(); - arma_debug_warn("hess(): decomposition failed"); + arma_debug_warn_level(3, "hess(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_hist.hpp b/src/armadillo_bits/fn_hist.hpp index e29856e4..018de947 100644 --- a/src/armadillo_bits/fn_hist.hpp +++ b/src/armadillo_bits/fn_hist.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_histc.hpp b/src/armadillo_bits/fn_histc.hpp index 5ee39b57..e99f8969 100644 --- a/src/armadillo_bits/fn_histc.hpp +++ b/src/armadillo_bits/fn_histc.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_index_max.hpp b/src/armadillo_bits/fn_index_max.hpp index 160d0630..aad33f84 100644 --- a/src/armadillo_bits/fn_index_max.hpp +++ b/src/armadillo_bits/fn_index_max.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_index_min.hpp b/src/armadillo_bits/fn_index_min.hpp index ebff82d3..e1b3ce21 100644 --- a/src/armadillo_bits/fn_index_min.hpp +++ b/src/armadillo_bits/fn_index_min.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_inplace_strans.hpp b/src/armadillo_bits/fn_inplace_strans.hpp index 0136cc9b..ad09cf26 100644 --- a/src/armadillo_bits/fn_inplace_strans.hpp +++ b/src/armadillo_bits/fn_inplace_strans.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,7 +32,7 @@ inplace_strans { arma_extra_debug_sigprint(); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_strans(): unknown method specified" ); diff --git a/src/armadillo_bits/fn_inplace_trans.hpp b/src/armadillo_bits/fn_inplace_trans.hpp index fa9ed7f6..0e238481 100644 --- a/src/armadillo_bits/fn_inplace_trans.hpp +++ b/src/armadillo_bits/fn_inplace_trans.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,7 +58,7 @@ inplace_htrans { arma_extra_debug_sigprint(); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_htrans(): unknown method specified" ); @@ -92,7 +94,7 @@ inplace_trans { arma_extra_debug_sigprint(); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_trans(): unknown method specified" ); @@ -117,7 +119,7 @@ inplace_trans { arma_extra_debug_sigprint(); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'l')), "inplace_trans(): unknown method specified" ); diff --git a/src/armadillo_bits/fn_interp1.hpp b/src/armadillo_bits/fn_interp1.hpp index ddb4884e..d1154230 100644 --- a/src/armadillo_bits/fn_interp1.hpp +++ b/src/armadillo_bits/fn_interp1.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -51,6 +53,11 @@ interp1_helper_nearest(const Mat& XG, const Mat& YG, const Mat& XI, { YI_mem[i] = extrap_val; } + else + if(arma_isnan(XI_val)) + { + YI_mem[i] = Datum::nan; + } else { // XG and XI are guaranteed to be sorted in ascending manner, @@ -111,6 +118,11 @@ interp1_helper_linear(const Mat& XG, const Mat& YG, const Mat& XI, M { YI_mem[i] = extrap_val; } + else + if(arma_isnan(XI_val)) + { + YI_mem[i] = Datum::nan; + } else { // XG and XI are guaranteed to be sorted in ascending manner, @@ -198,8 +210,8 @@ interp1_helper(const Mat& X, const Mat& Y, const Mat& XI, Mat& Y arma_debug_check( (N_subset < 2), "interp1(): X must have at least two unique elements" ); - Mat X_sanitised(N_subset,1); - Mat Y_sanitised(N_subset,1); + Mat X_sanitised(N_subset, 1, arma_nozeros_indicator()); + Mat Y_sanitised(N_subset, 1, arma_nozeros_indicator()); eT* X_sanitised_mem = X_sanitised.memptr(); eT* Y_sanitised_mem = Y_sanitised.memptr(); @@ -221,11 +233,11 @@ interp1_helper(const Mat& X, const Mat& Y, const Mat& XI, Mat& Y Mat XI_tmp; uvec XI_indices; - const bool XI_is_sorted = XI.is_sorted(); + const bool XI_is_sorted = XI.is_sorted(); // NOTE: .is_sorted() currently doesn't detect NaN if(XI_is_sorted == false) { - XI_indices = sort_index(XI); + XI_indices = sort_index(XI); // NOTE: sort_index() will throw if XI has NaN const uword N = XI.n_elem; @@ -244,6 +256,8 @@ interp1_helper(const Mat& X, const Mat& Y, const Mat& XI, Mat& Y const Mat& XI_sorted = (XI_is_sorted) ? XI : XI_tmp; + // NOTE: XI_sorted may have NaN + if(sig == 10) { interp1_helper_nearest(X_sanitised, Y_sanitised, XI_sorted, YI, extrap_val); } else if(sig == 20) { interp1_helper_linear (X_sanitised, Y_sanitised, XI_sorted, YI, extrap_val); } @@ -296,7 +310,7 @@ interp1 uword sig = 0; - if(method != NULL ) + if(method != nullptr) if(method[0] != char(0)) if(method[1] != char(0)) { diff --git a/src/armadillo_bits/fn_interp2.hpp b/src/armadillo_bits/fn_interp2.hpp index 408d32e9..b9b6127c 100644 --- a/src/armadillo_bits/fn_interp2.hpp +++ b/src/armadillo_bits/fn_interp2.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -193,7 +195,7 @@ interp2 typedef typename T1::elem_type eT; - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 'n') && (sig != 'l')), "interp2(): unsupported interpolation type" ); diff --git a/src/armadillo_bits/fn_intersect.hpp b/src/armadillo_bits/fn_intersect.hpp index fe8462c7..37afa52a 100644 --- a/src/armadillo_bits/fn_intersect.hpp +++ b/src/armadillo_bits/fn_intersect.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_inv.hpp b/src/armadillo_bits/fn_inv.hpp index 816acc41..65589f74 100644 --- a/src/armadillo_bits/fn_inv.hpp +++ b/src/armadillo_bits/fn_inv.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,7 +24,7 @@ template arma_warn_unused arma_inline -typename enable_if2< is_supported_blas_type::value, const Op >::result +typename enable_if2< is_supported_blas_type::value, const Op >::result inv ( const Base& X @@ -30,103 +32,7 @@ inv { arma_extra_debug_sigprint(); - return Op(X.get_ref()); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv - ( - const Base& X, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("inv(X,bool) is deprecated and will be removed; change to inv(X)"); - - return Op(X.get_ref()); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv - ( - const Base& X, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("inv(X,char*) is deprecated and will be removed; change to inv(X)"); - - return Op(X.get_ref()); - } - - - -template -arma_warn_unused -arma_inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv - ( - const Op& X - ) - { - arma_extra_debug_sigprint(); - - return Op(X.m, X.aux_uword_a, 0); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv - ( - const Op& X, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("inv(X,bool) is deprecated and will be removed; change to inv(X)"); - - return Op(X.m, X.aux_uword_a, 0); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv - ( - const Op& X, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("inv(X,char*) is deprecated and will be removed; change to inv(X)"); - - return Op(X.m, X.aux_uword_a, 0); + return Op(X.get_ref()); } @@ -142,114 +48,57 @@ inv { arma_extra_debug_sigprint(); - try - { - out = inv(X); - } - catch(std::runtime_error&) + const bool status = op_inv_gen_default::apply_direct(out, X.get_ref(), "inv()"); + + if(status == false) { - return false; + out.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); } - return true; + return status; } -//! NOTE: don't use this form: it will be removed template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, bool >::result +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result inv ( - Mat& out, const Base& X, - const bool // argument kept only for compatibility with old user code + const inv_opts::opts& opts ) { arma_extra_debug_sigprint(); - // arma_debug_warn("inv(Y,X,bool) is deprecated and will be removed; change to inv(Y,X)"); - - return inv(out,X); + return Op(X.get_ref(), opts.flags, uword(0)); } -//! NOTE: don't use this form: it will be removed template -arma_deprecated inline typename enable_if2< is_supported_blas_type::value, bool >::result inv ( Mat& out, const Base& X, - const char* // argument kept only for compatibility with old user code + const inv_opts::opts& opts ) { arma_extra_debug_sigprint(); - // arma_debug_warn("inv(Y,X,char*) is deprecated and will be removed; change to inv(Y,X)"); - - return inv(out,X); - } - - - -template -arma_warn_unused -arma_inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv_sympd - ( - const Base& X - ) - { - arma_extra_debug_sigprint(); - - return Op(X.get_ref()); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv_sympd - ( - const Base& X, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("inv_sympd(X,bool) is deprecated and will be removed; change to inv_sympd(X)"); - - return Op(X.get_ref()); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Op >::result -inv_sympd - ( - const Base& X, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); + const bool status = op_inv_gen_full::apply_direct(out, X.get_ref(), "inv()", opts.flags); - // arma_debug_warn("inv_sympd(X,char*) is deprecated and will be removed; change to inv_sympd(X)"); + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); + } - return Op(X.get_ref()); + return status; } @@ -257,66 +106,31 @@ inv_sympd template inline typename enable_if2< is_supported_blas_type::value, bool >::result -inv_sympd +inv ( - Mat& out, + Mat& out_inv, + typename T1::pod_type& out_rcond, const Base& X ) { arma_extra_debug_sigprint(); - try - { - out = inv_sympd(X); - } - catch(std::runtime_error&) - { - return false; - } + typedef typename T1::pod_type T; - return true; - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, bool >::result -inv_sympd - ( - Mat& out, - const Base& X, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); + op_inv_gen_state inv_state; - // arma_debug_warn("inv_sympd(Y,X,bool) is deprecated and will be removed; change to inv_sympd(Y,X)"); + const bool status = op_inv_gen_rcond::apply_direct(out_inv, inv_state, X.get_ref()); - return inv_sympd(out,X); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, bool >::result -inv_sympd - ( - Mat& out, - const Base& X, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); + out_rcond = inv_state.rcond; - // arma_debug_warn("inv_sympd(Y,X,char*) is deprecated and will be removed; change to inv_sympd(Y,X)"); + if(status == false) + { + out_rcond = T(0); + out_inv.soft_reset(); + arma_debug_warn_level(3, "inv(): matrix is singular"); + } - return inv_sympd(out,X); + return status; } diff --git a/src/armadillo_bits/fn_inv_sympd.hpp b/src/armadillo_bits/fn_inv_sympd.hpp new file mode 100644 index 00000000..ffd1d0d8 --- /dev/null +++ b/src/armadillo_bits/fn_inv_sympd.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_inv_sympd +//! @{ + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv_sympd + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_default::apply_direct(out, X.get_ref()); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +template +arma_warn_unused +arma_inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +inv_sympd + ( + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref(), opts.flags, uword(0)); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out, + const Base& X, + const inv_opts::opts& opts + ) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_full::apply_direct(out, X.get_ref(), opts.flags); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +inv_sympd + ( + Mat& out_inv, + typename T1::pod_type& out_rcond, + const Base& X + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + op_inv_spd_state inv_state; + + const bool status = op_inv_spd_rcond::apply_direct(out_inv, inv_state, X.get_ref()); + + out_rcond = inv_state.rcond; + + if(status == false) + { + out_rcond = T(0); + out_inv.soft_reset(); + arma_debug_warn_level(3, "inv_sympd(): matrix is singular or not positive definite"); + } + + return status; + } + + + +//! @} diff --git a/src/armadillo_bits/fn_join.hpp b/src/armadillo_bits/fn_join.hpp index 1819b41b..6d3ed07d 100644 --- a/src/armadillo_bits/fn_join.hpp +++ b/src/armadillo_bits/fn_join.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -258,7 +260,7 @@ join_slices(const Base& A, const Base out(UA.M.n_rows, UA.M.n_cols, 2); + Cube out(UA.M.n_rows, UA.M.n_cols, 2, arma_nozeros_indicator()); arrayops::copy(out.slice_memptr(0), UA.M.memptr(), UA.M.n_elem); arrayops::copy(out.slice_memptr(1), UB.M.memptr(), UB.M.n_elem); diff --git a/src/armadillo_bits/fn_kmeans.hpp b/src/armadillo_bits/fn_kmeans.hpp index 57365aeb..87074352 100644 --- a/src/armadillo_bits/fn_kmeans.hpp +++ b/src/armadillo_bits/fn_kmeans.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -40,7 +42,7 @@ kmeans const bool status = model.kmeans_wrapper(means, data.get_ref(), k, seed_mode, n_iter, print_mode); - if(status == true) + if(status) { means = model.means; } diff --git a/src/armadillo_bits/fn_kron.hpp b/src/armadillo_bits/fn_kron.hpp index 9d659045..61b5be2a 100644 --- a/src/armadillo_bits/fn_kron.hpp +++ b/src/armadillo_bits/fn_kron.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_log_det.hpp b/src/armadillo_bits/fn_log_det.hpp index 49abf7b3..3ea463ec 100644 --- a/src/armadillo_bits/fn_log_det.hpp +++ b/src/armadillo_bits/fn_log_det.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,13 +24,13 @@ //! log determinant of mat template inline -void +bool log_det ( typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -37,28 +39,29 @@ log_det typedef typename T1::elem_type eT; typedef typename T1::pod_type T; - const bool status = auxlib::log_det(out_val, out_sign, X); + const bool status = op_log_det::apply_direct(out_val, out_sign, X.get_ref()); if(status == false) { out_val = eT(Datum::nan); out_sign = T(0); - arma_warn("log_det(): failed to find determinant"); + arma_debug_warn_level(3, "log_det(): failed to find determinant"); } + + return status; } template +arma_warn_unused inline -void +std::complex log_det ( - typename T1::elem_type& out_val, - typename T1::pod_type& out_sign, - const Op& X, - const typename arma_blas_type_only::result* junk = 0 + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -67,61 +70,86 @@ log_det typedef typename T1::elem_type eT; typedef typename T1::pod_type T; - const diagmat_proxy A(X.m); - - arma_debug_check( (A.n_rows != A.n_cols), "log_det(): given matrix must be square sized" ); + eT out_val = eT(0); + T out_sign = T(0); - const uword N = (std::min)(A.n_rows, A.n_cols); + const bool status = op_log_det::apply_direct(out_val, out_sign, X.get_ref()); - if(N == 0) + if(status == false) { - out_val = eT(0); - out_sign = T(1); + out_val = eT(Datum::nan); + out_sign = T(0); - return; + arma_stop_runtime_error("log_det(): failed to find determinant"); } - eT x = A[0]; + return (out_sign >= T(1)) ? std::complex(out_val) : (out_val + std::complex(T(0),Datum::pi)); + } + + + +// + + + +template +inline +bool +log_det_sympd + ( + typename T1::pod_type& out_val, + const Base& X, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::pod_type T; - T sign = (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; - eT val = (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + out_val = T(0); - for(uword i=1; i::nan; - sign *= (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? -1 : +1 ) : +1; - val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + arma_debug_warn_level(3, "log_det_sympd(): given matrix is not symmetric positive definite"); } - out_val = val; - out_sign = sign; + return status; } template -inline arma_warn_unused -std::complex -log_det +inline +typename T1::pod_type +log_det_sympd ( const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - typedef typename T1::elem_type eT; - typedef typename T1::pod_type T; + typedef typename T1::pod_type T; - eT out_val = eT(0); - T out_sign = T(0); + T out_val = T(0); - log_det(out_val, out_sign, X.get_ref()); + const bool status = op_log_det_sympd::apply_direct(out_val, X.get_ref()); - return (out_sign >= T(1)) ? std::complex(out_val) : (out_val + std::complex(T(0),Datum::pi)); + if(status == false) + { + out_val = Datum::nan; + + arma_stop_runtime_error("log_det_sympd(): given matrix is not symmetric positive definite"); + } + + return out_val; } diff --git a/src/armadillo_bits/fn_log_normpdf.hpp b/src/armadillo_bits/fn_log_normpdf.hpp new file mode 100644 index 00000000..cb404db6 --- /dev/null +++ b/src/armadillo_bits/fn_log_normpdf.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_log_normpdf +//! @{ + + + +template +inline +typename enable_if2< (is_real::value), void >::result +log_normpdf_helper(Mat& out, const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(Proxy::use_at || Proxy::use_at || Proxy::use_at) + { + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UM(M_expr.get_ref()); + const quasi_unwrap US(S_expr.get_ref()); + + log_normpdf_helper(out, UX.M, UM.M, US.M); + + return; + } + + const Proxy PX(X_expr.get_ref()); + const Proxy PM(M_expr.get_ref()); + const Proxy PS(S_expr.get_ref()); + + arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "log_normpdf(): size mismatch" ); + + out.set_size(PX.get_n_rows(), PX.get_n_cols()); + + eT* out_mem = out.memptr(); + + const uword N = PX.get_n_elem(); + + typename Proxy::ea_type X_ea = PX.get_ea(); + typename Proxy::ea_type M_ea = PM.get_ea(); + typename Proxy::ea_type S_ea = PS.get_ea(); + + const bool use_mp = arma_config::openmp && mp_gate::eval(N); + + if(use_mp) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i::log_sqrt2pi); + } + } + #endif + } + else + { + for(uword i=0; i::log_sqrt2pi); + } + } + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +log_normpdf(const eT x) + { + const eT out = (eT(-0.5) * (x*x)) - Datum::log_sqrt2pi; + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), eT >::result +log_normpdf(const eT x, const eT mu, const eT sigma) + { + const eT tmp = (x - mu) / sigma; + + const eT out = (eT(-0.5) * (tmp*tmp)) - (std::log(sigma) + Datum::log_sqrt2pi); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const eT x, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + const quasi_unwrap UM(M_expr.get_ref()); + const Mat& M = UM.M; + + Mat out; + + log_normpdf_helper(out, x*ones< Mat >(arma::size(M)), M, S_expr.get_ref()); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + log_normpdf_helper(out, X, zeros< Mat >(arma::size(X)), ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UX(X_expr.get_ref()); + const Mat& X = UX.M; + + Mat out; + + log_normpdf_helper(out, X, mu*ones< Mat >(arma::size(X)), sigma*ones< Mat >(arma::size(X))); + + return out; + } + + + +template +arma_warn_unused +inline +typename enable_if2< (is_real::value), Mat >::result +log_normpdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat out; + + log_normpdf_helper(out, X_expr.get_ref(), M_expr.get_ref(), S_expr.get_ref()); + + return out; + } + + + +//! @} diff --git a/src/armadillo_bits/fn_logmat.hpp b/src/armadillo_bits/fn_logmat.hpp index f32d25d6..e169987a 100644 --- a/src/armadillo_bits/fn_logmat.hpp +++ b/src/armadillo_bits/fn_logmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -57,7 +59,7 @@ logmat(Mat< std::complex >& Y, const Base& Y, const Base& X, if(status == false) { Y.soft_reset(); - arma_debug_warn("logmat(): transformation failed"); + arma_debug_warn_level(3, "logmat(): transformation failed"); } return status; @@ -114,7 +116,7 @@ logmat_sympd(Mat& Y, const Base& L, Mat& U, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - arma_debug_check( (&L == &U), "lu(): L and U are the same object"); + arma_debug_check( (&L == &U), "lu(): L and U are the same object" ); const bool status = auxlib::lu(L, U, X); @@ -42,7 +44,7 @@ lu { L.soft_reset(); U.soft_reset(); - arma_debug_warn("lu(): decomposition failed"); + arma_debug_warn_level(3, "lu(): decomposition failed"); } return status; @@ -60,13 +62,13 @@ lu Mat& U, Mat& P, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - arma_debug_check( ( (&L == &U) || (&L == &P) || (&U == &P) ), "lu(): two or more output objects are the same object"); + arma_debug_check( ( (&L == &U) || (&L == &P) || (&U == &P) ), "lu(): two or more output objects are the same object" ); const bool status = auxlib::lu(L, U, P, X); @@ -75,7 +77,7 @@ lu L.soft_reset(); U.soft_reset(); P.soft_reset(); - arma_debug_warn("lu(): decomposition failed"); + arma_debug_warn_level(3, "lu(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_max.hpp b/src/armadillo_bits/fn_max.hpp index 91b99e36..dcbf1cee 100644 --- a/src/armadillo_bits/fn_max.hpp +++ b/src/armadillo_bits/fn_max.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_mean.hpp b/src/armadillo_bits/fn_mean.hpp index b0c09b58..b1400c19 100644 --- a/src/armadillo_bits/fn_mean.hpp +++ b/src/armadillo_bits/fn_mean.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_median.hpp b/src/armadillo_bits/fn_median.hpp index c529f09e..48ff7565 100644 --- a/src/armadillo_bits/fn_median.hpp +++ b/src/armadillo_bits/fn_median.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_min.hpp b/src/armadillo_bits/fn_min.hpp index 56cd2439..3baa1283 100644 --- a/src/armadillo_bits/fn_min.hpp +++ b/src/armadillo_bits/fn_min.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_misc.hpp b/src/armadillo_bits/fn_misc.hpp index b9c30903..51930e4d 100644 --- a/src/armadillo_bits/fn_misc.hpp +++ b/src/armadillo_bits/fn_misc.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -57,7 +59,7 @@ linspace const uword num_m1 = num - 1; - if(is_non_integral::value == true) + if(is_non_integral::value) { const T delta = (end-start)/T(num_m1); @@ -167,7 +169,7 @@ log_add_exp(eT log_a, eT log_b) } else { - return (log_a + arma_log1p(std::exp(negdelta))); + return (log_a + std::log1p(std::exp(negdelta))); } } @@ -190,7 +192,7 @@ template arma_warn_unused arma_inline bool -is_finite(const eT x, const typename arma_scalar_only::result* junk = 0) +is_finite(const eT x, const typename arma_scalar_only::result* junk = nullptr) { arma_ignore(junk); @@ -241,22 +243,6 @@ is_finite(const BaseCube& X) -//! NOTE: don't use this function: it will be removed -template -arma_deprecated -inline -const T1& -sympd(const Base& X) - { - arma_extra_debug_sigprint(); - - arma_debug_warn("sympd() is deprecated and will be removed; change inv(sympd(X)) to inv_sympd(X)"); - - return X.get_ref(); - } - - - template inline void @@ -295,7 +281,7 @@ ind2sub(const SizeMat& s, const uword i) const uword row = i % s_n_rows; const uword col = i / s_n_rows; - uvec out(2); + uvec out(2, arma_nozeros_indicator()); uword* out_mem = out.memptr(); @@ -329,7 +315,7 @@ ind2sub(const SizeMat& s, const T1& indices) arma_debug_check( ((P_is_empty == false) && (P_is_vec == false)), "ind2sub(): parameter 'indices' must be a vector" ); - umat out(2,P_n_elem); + umat out(2, P_n_elem, arma_nozeros_indicator()); if(Proxy::use_at == false) { @@ -411,7 +397,7 @@ ind2sub(const SizeCube& s, const uword i) const uword row = j % s_n_rows; const uword col = j / s_n_rows; - uvec out(3); + uvec out(3, arma_nozeros_indicator()); uword* out_mem = out.memptr(); @@ -443,7 +429,7 @@ ind2sub(const SizeCube& s, const T1& indices) const uword U_n_elem = U.M.n_elem; const uword* U_mem = U.M.memptr(); - umat out(3,U_n_elem); + umat out(3, U_n_elem, arma_nozeros_indicator()); for(uword count=0; count < U_n_elem; ++count) { @@ -501,7 +487,7 @@ sub2ind(const SizeMat& s, const Base& subscripts) const uword U_M_n_cols = U.M.n_cols; - uvec out(U_M_n_cols); + uvec out(U_M_n_cols, arma_nozeros_indicator()); uword* out_mem = out.memptr(); const uword* U_M_mem = U.M.memptr(); @@ -558,7 +544,7 @@ sub2ind(const SizeCube& s, const Base& subscripts) const uword U_M_n_cols = U.M.n_cols; - uvec out(U_M_n_cols); + uvec out(U_M_n_cols, arma_nozeros_indicator()); uword* out_mem = out.memptr(); const uword* U_M_mem = U.M.memptr(); diff --git a/src/armadillo_bits/fn_mvnrnd.hpp b/src/armadillo_bits/fn_mvnrnd.hpp index 56722316..dd873d0a 100644 --- a/src/armadillo_bits/fn_mvnrnd.hpp +++ b/src/armadillo_bits/fn_mvnrnd.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,7 +58,6 @@ mvnrnd(const Base& M, const Base -arma_warn_unused inline typename enable_if2 @@ -72,17 +73,16 @@ mvnrnd(Mat& out, const Base& if(status == false) { - arma_debug_warn("mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); - return false; + out.soft_reset(); + arma_debug_warn_level(3, "mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); } - return true; + return status; } template -arma_warn_unused inline typename enable_if2 @@ -98,11 +98,11 @@ mvnrnd(Mat& out, const Base& if(status == false) { - arma_debug_warn("mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); - return false; + out.soft_reset(); + arma_debug_warn_level(3, "mvnrnd(): given covariance matrix is not symmetric positive semi-definite"); } - return true; + return status; } diff --git a/src/armadillo_bits/fn_n_unique.hpp b/src/armadillo_bits/fn_n_unique.hpp index d54cbab4..2f00b72a 100644 --- a/src/armadillo_bits/fn_n_unique.hpp +++ b/src/armadillo_bits/fn_n_unique.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_nonzeros.hpp b/src/armadillo_bits/fn_nonzeros.hpp index 4e3f9cb7..202efe17 100644 --- a/src/armadillo_bits/fn_nonzeros.hpp +++ b/src/armadillo_bits/fn_nonzeros.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_norm.hpp b/src/armadillo_bits/fn_norm.hpp index 97ab28a4..a8f05f0f 100644 --- a/src/armadillo_bits/fn_norm.hpp +++ b/src/armadillo_bits/fn_norm.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,14 +22,14 @@ template -inline arma_warn_unused +inline typename enable_if2< is_arma_type::value, typename T1::pod_type >::result norm ( const T1& X, const uword k = uword(2), - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -37,64 +39,43 @@ norm const Proxy P(X); - if(P.get_n_elem() == 0) - { - return T(0); - } + if(P.get_n_elem() == 0) { return T(0); } - const bool is_vec = (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); if(is_vec) { - switch(k) - { - case 1: - return op_norm::vec_norm_1(P); - break; - - case 2: - return op_norm::vec_norm_2(P); - break; - - default: - { - arma_debug_check( (k == 0), "norm(): k must be greater than zero" ); - return op_norm::vec_norm_k(P, int(k)); - } - } + if(k == uword(1)) { return op_norm::vec_norm_1(P); } + if(k == uword(2)) { return op_norm::vec_norm_2(P); } + + arma_debug_check( (k == 0), "norm(): unsupported vector norm type" ); + + return op_norm::vec_norm_k(P, int(k)); } else { - switch(k) - { - case 1: - return op_norm::mat_norm_1(P); - break; - - case 2: - return op_norm::mat_norm_2(P); - break; + const quasi_unwrap::stored_type> U(P.Q); + + if(k == uword(1)) { return op_norm::mat_norm_1(U.M); } + if(k == uword(2)) { return op_norm::mat_norm_2(U.M); } - default: - arma_stop_logic_error("norm(): unsupported matrix norm type"); - return T(0); - } + arma_stop_logic_error("norm(): unsupported matrix norm type"); } - return T(0); // prevent erroneous compiler warnings + return T(0); } template -inline arma_warn_unused +inline typename enable_if2< is_arma_type::value, typename T1::pod_type >::result norm ( const T1& X, const char* method, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -104,53 +85,81 @@ norm const Proxy P(X); - if(P.get_n_elem() == 0) - { - return T(0); - } + if(P.get_n_elem() == 0) { return T(0); } - const char sig = (method != NULL) ? method[0] : char(0); - const bool is_vec = (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); + const char sig = (method != nullptr) ? method[0] : char(0); + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (P.get_n_rows() == 1) || (P.get_n_cols() == 1); if(is_vec) { - if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // max norm - { - return op_norm::vec_norm_max(P); - } - else - if(sig == '-') // min norm - { - return op_norm::vec_norm_min(P); - } - else - if( (sig == 'f') || (sig == 'F') ) - { - return op_norm::vec_norm_2(P); - } - else - { - arma_stop_logic_error("norm(): unsupported vector norm type"); - return T(0); - } + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { return op_norm::vec_norm_max(P); } + if( (sig == '-') ) { return op_norm::vec_norm_min(P); } + if( (sig == 'f') || (sig == 'F') ) { return op_norm::vec_norm_2(P); } + + arma_stop_logic_error("norm(): unsupported vector norm type"); } else { if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // inf norm { - return op_norm::mat_norm_inf(P); + const quasi_unwrap::stored_type> U(P.Q); + + return op_norm::mat_norm_inf(U.M); } else if( (sig == 'f') || (sig == 'F') ) { return op_norm::vec_norm_2(P); } - else - { - arma_stop_logic_error("norm(): unsupported matrix norm type"); - return T(0); - } + + arma_stop_logic_error("norm(): unsupported matrix norm type"); } + + return T(0); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, double >::result +norm + ( + const T1& X, + const uword k = uword(2), + const typename arma_integral_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(resolves_to_colvector::value) { return norm(conv_to< Col >::from(X), k); } + if(resolves_to_rowvector::value) { return norm(conv_to< Row >::from(X), k); } + + return norm(conv_to< Mat >::from(X), k); + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_arma_type::value, double >::result +norm + ( + const T1& X, + const char* method, + const typename arma_integral_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + if(resolves_to_colvector::value) { return norm(conv_to< Col >::from(X), method); } + if(resolves_to_rowvector::value) { return norm(conv_to< Row >::from(X), method); } + + return norm(conv_to< Mat >::from(X), method); } @@ -160,14 +169,14 @@ norm template -inline arma_warn_unused +inline typename enable_if2< is_arma_sparse_type::value, typename T1::pod_type >::result norm ( - const T1& X, + const T1& expr, const uword k = uword(2), - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -176,72 +185,53 @@ norm typedef typename T1::elem_type eT; typedef typename T1::pod_type T; - const SpProxy P(X); - - if(P.get_n_nonzero() == 0) + if(is_SpSubview_col::value) { - return T(0); + const SpSubview_col& sv = reinterpret_cast< const SpSubview_col& >(expr); + + if(sv.n_rows == sv.m.n_rows) + { + const SpMat& m = sv.m; + const uword col = sv.aux_col1; + const eT* mem = &(m.values[ m.col_ptrs[col] ]); + + return spop_norm::vec_norm_k(mem, sv.n_nonzero, k); + } } - const bool is_vec = (P.get_n_rows() == 1) || (P.get_n_cols() == 1); + const unwrap_spmat U(expr); + const SpMat& X = U.M; + + if(X.n_nonzero == 0) { return T(0); } + + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (X.n_rows == 1) || (X.n_cols == 1); if(is_vec) { - const unwrap_spmat::stored_type> tmp(P.Q); - const SpMat& A = tmp.M; - - // create a fake dense vector to allow reuse of code for dense vectors - Col fake_vector( access::rwp(A.values), A.n_nonzero, false ); - - const Proxy< Col > P_fake_vector(fake_vector); - - switch(k) - { - case 1: - return op_norm::vec_norm_1(P_fake_vector); - break; - - case 2: - return op_norm::vec_norm_2(P_fake_vector); - break; - - default: - { - arma_debug_check( (k == 0), "norm(): k must be greater than zero" ); - return op_norm::vec_norm_k(P_fake_vector, int(k)); - } - } + return spop_norm::vec_norm_k(X.values, X.n_nonzero, k); } else { - switch(k) - { - case 1: - return op_norm::mat_norm_1(P); - break; - - case 2: - return op_norm::mat_norm_2(P); - break; - - default: - arma_stop_logic_error("norm(): unsupported or unimplemented norm type for sparse matrices"); - return T(0); - } + if(k == uword(1)) { return spop_norm::mat_norm_1(X); } + if(k == uword(2)) { return spop_norm::mat_norm_2(X); } + + arma_stop_logic_error("norm(): unsupported or unimplemented norm type for sparse matrices"); } + + return T(0); } template -inline arma_warn_unused +inline typename enable_if2< is_arma_sparse_type::value, typename T1::pod_type >::result norm ( - const T1& X, + const T1& expr, const char* method, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -250,25 +240,19 @@ norm typedef typename T1::elem_type eT; typedef typename T1::pod_type T; - const SpProxy P(X); - - if(P.get_n_nonzero() == 0) - { - return T(0); - } + const unwrap_spmat U(expr); + const SpMat& X = U.M; - - const unwrap_spmat::stored_type> tmp(P.Q); - const SpMat& A = tmp.M; + if(X.n_nonzero == 0) { return T(0); } // create a fake dense vector to allow reuse of code for dense vectors - Col fake_vector( access::rwp(A.values), A.n_nonzero, false ); + Col fake_vector( access::rwp(X.values), X.n_nonzero, false ); const Proxy< Col > P_fake_vector(fake_vector); - const char sig = (method != NULL) ? method[0] : char(0); - const bool is_vec = (P.get_n_rows() == 1) || (P.get_n_cols() == 1); // TODO: (T1::is_row) || (T1::is_col) || ... + const char sig = (method != nullptr) ? method[0] : char(0); + const bool is_vec = (T1::is_xvec) || (T1::is_row) || (T1::is_col) || (X.n_rows == 1) || (X.n_cols == 1); if(is_vec) { @@ -281,43 +265,76 @@ norm { const T val = op_norm::vec_norm_min(P_fake_vector); - if( P.get_n_nonzero() < P.get_n_elem() ) - { - return (std::min)(T(0), val); - } - else - { - return val; - } + return (X.n_nonzero < X.n_elem) ? T((std::min)(T(0), val)) : T(val); } else if( (sig == 'f') || (sig == 'F') ) { return op_norm::vec_norm_2(P_fake_vector); } - else - { - arma_stop_logic_error("norm(): unsupported vector norm type"); - return T(0); - } + + arma_stop_logic_error("norm(): unsupported vector norm type"); } else { if( (sig == 'i') || (sig == 'I') || (sig == '+') ) // inf norm { - return op_norm::mat_norm_inf(P); + return spop_norm::mat_norm_inf(X); } else if( (sig == 'f') || (sig == 'F') ) { return op_norm::vec_norm_2(P_fake_vector); } - else - { - arma_stop_logic_error("norm(): unsupported matrix norm type"); - return T(0); - } + + arma_stop_logic_error("norm(): unsupported matrix norm type"); } + + return T(0); + } + + + +// +// approximate norms + + +template +arma_warn_unused +inline +typename T1::pod_type +norm2est + ( + const Base& X, + const typename T1::pod_type tolerance = 0, + const uword max_iter = 100, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return op_norm2est::norm2est(X.get_ref(), tolerance, max_iter); + } + + + +template +arma_warn_unused +inline +typename T1::pod_type +norm2est + ( + const SpBase& X, + const typename T1::pod_type tolerance = 0, + const uword max_iter = 100, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return op_norm2est::norm2est(X.get_ref(), tolerance, max_iter); } diff --git a/src/armadillo_bits/fn_normalise.hpp b/src/armadillo_bits/fn_normalise.hpp index 27f5df5d..ae074300 100644 --- a/src/armadillo_bits/fn_normalise.hpp +++ b/src/armadillo_bits/fn_normalise.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -33,7 +35,7 @@ normalise const T1& X, const uword p = uword(2), const arma_empty_class junk1 = arma_empty_class(), - const typename arma_real_or_cx_only::result* junk2 = 0 + const typename arma_real_or_cx_only::result* junk2 = nullptr ) { arma_extra_debug_sigprint(); @@ -59,7 +61,7 @@ normalise const T1& X, const uword p = uword(2), const uword dim = 0, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -79,7 +81,7 @@ normalise const SpBase& expr, const uword p = uword(2), const uword dim = 0, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -102,7 +104,7 @@ enable_if2 >::result normalise(const T& val) { - Col out(1); + Col out(1, arma_nozeros_indicator()); out[0] = (val != T(0)) ? T(val / (std::abs)(val)) : T(val); diff --git a/src/armadillo_bits/fn_normcdf.hpp b/src/armadillo_bits/fn_normcdf.hpp index 9d1608f0..06ed5cb8 100644 --- a/src/armadillo_bits/fn_normcdf.hpp +++ b/src/armadillo_bits/fn_normcdf.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,122 +28,95 @@ normcdf_helper(Mat& out, const Base::use_at || Proxy::use_at || Proxy::use_at) { - arma_stop_logic_error("normcdf(): C++11 compiler required"); + const quasi_unwrap UX(X_expr.get_ref()); + const quasi_unwrap UM(M_expr.get_ref()); + const quasi_unwrap US(S_expr.get_ref()); + + normcdf_helper(out, UX.M, UM.M, US.M); return; } - #else + + const Proxy PX(X_expr.get_ref()); + const Proxy PM(M_expr.get_ref()); + const Proxy PS(S_expr.get_ref()); + + arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normcdf(): size mismatch" ); + + out.set_size(PX.get_n_rows(), PX.get_n_cols()); + + eT* out_mem = out.memptr(); + + const uword N = PX.get_n_elem(); + + typename Proxy::ea_type X_ea = PX.get_ea(); + typename Proxy::ea_type M_ea = PM.get_ea(); + typename Proxy::ea_type S_ea = PS.get_ea(); + + const bool use_mp = arma_config::openmp && mp_gate::eval(N); + + if(use_mp) { - typedef typename T1::elem_type eT; - - if(Proxy::use_at || Proxy::use_at || Proxy::use_at) - { - const quasi_unwrap UX(X_expr.get_ref()); - const quasi_unwrap UM(M_expr.get_ref()); - const quasi_unwrap US(S_expr.get_ref()); - - normcdf_helper(out, UX.M, UM.M, US.M); - - return; - } - - const Proxy PX(X_expr.get_ref()); - const Proxy PM(M_expr.get_ref()); - const Proxy PS(S_expr.get_ref()); - - arma_debug_check( ( (PX.get_n_rows() != PM.get_n_rows()) || (PX.get_n_cols() != PM.get_n_cols()) || (PM.get_n_rows() != PS.get_n_rows()) || (PM.get_n_cols() != PS.get_n_cols()) ), "normcdf(): size mismatch" ); - - out.set_size(PX.get_n_rows(), PX.get_n_cols()); - - eT* out_mem = out.memptr(); - - const uword N = PX.get_n_elem(); - - typename Proxy::ea_type X_ea = PX.get_ea(); - typename Proxy::ea_type M_ea = PM.get_ea(); - typename Proxy::ea_type S_ea = PS.get_ea(); - - const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate::eval(N); - - if(use_mp) - { - #if defined(ARMA_USE_OPENMP) - { - const int n_threads = mp_thread_limit::get(); - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword i=0; i::sqrt2)); - - out_mem[i] = 0.5 * std::erfc(tmp); - } - } - #endif - } - else + #if defined(ARMA_USE_OPENMP) { + const int n_threads = mp_thread_limit::get(); + #pragma omp parallel for schedule(static) num_threads(n_threads) for(uword i=0; i::sqrt2)); - out_mem[i] = 0.5 * std::erfc(tmp); + out_mem[i] = eT(0.5) * std::erfc(tmp); } } + #endif + } + else + { + for(uword i=0; i::sqrt2)); + + out_mem[i] = eT(0.5) * std::erfc(tmp); + } } - #endif } template -arma_inline +arma_warn_unused +inline typename enable_if2< (is_real::value), eT >::result normcdf(const eT x) { - #if !defined(ARMA_USE_CXX11) - { - arma_stop_logic_error("normcdf(): C++11 compiler required"); - - return eT(0); - } - #else - { - const eT out = 0.5 * std::erfc( x / (-Datum::sqrt2) ); - - return out; - } - #endif + const eT out = eT(0.5) * std::erfc( x / (-Datum::sqrt2) ); + + return out; } template +arma_warn_unused inline typename enable_if2< (is_real::value), eT >::result normcdf(const eT x, const eT mu, const eT sigma) { - #if !defined(ARMA_USE_CXX11) - { - arma_stop_logic_error("normcdf(): C++11 compiler required"); - - return eT(0); - } - #else - { - const eT tmp = (x - mu) / (sigma * (-Datum::sqrt2)); - - const eT out = 0.5 * std::erfc(tmp); - - return out; - } - #endif + const eT tmp = (x - mu) / (sigma * (-Datum::sqrt2)); + + const eT out = eT(0.5) * std::erfc(tmp); + + return out; } template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normcdf(const eT x, const Base& M_expr, const Base& S_expr) @@ -161,6 +136,7 @@ normcdf(const eT x, const Base& M_expr, const Base& S_expr) template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normcdf(const Base& X_expr) @@ -182,6 +158,7 @@ normcdf(const Base& X_expr) template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normcdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) @@ -203,6 +180,7 @@ normcdf(const Base& X_expr, const typename T1::elem_ template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normcdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) diff --git a/src/armadillo_bits/fn_normpdf.hpp b/src/armadillo_bits/fn_normpdf.hpp index 9d00504e..e05af41b 100644 --- a/src/armadillo_bits/fn_normpdf.hpp +++ b/src/armadillo_bits/fn_normpdf.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -55,7 +57,7 @@ normpdf_helper(Mat& out, const Base::ea_type M_ea = PM.get_ea(); typename Proxy::ea_type S_ea = PS.get_ea(); - const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate::eval(N); + const bool use_mp = arma_config::openmp && mp_gate::eval(N); if(use_mp) { @@ -69,7 +71,7 @@ normpdf_helper(Mat& out, const Base::sqrt2pi); + out_mem[i] = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum::sqrt2pi); } } #endif @@ -82,7 +84,7 @@ normpdf_helper(Mat& out, const Base::sqrt2pi); + out_mem[i] = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum::sqrt2pi); } } } @@ -90,11 +92,12 @@ normpdf_helper(Mat& out, const Base -arma_inline +arma_warn_unused +inline typename enable_if2< (is_real::value), eT >::result normpdf(const eT x) { - const eT out = std::exp(-0.5 * (x*x)) / Datum::sqrt2pi; + const eT out = std::exp(eT(-0.5) * (x*x)) / Datum::sqrt2pi; return out; } @@ -102,13 +105,14 @@ normpdf(const eT x) template +arma_warn_unused inline typename enable_if2< (is_real::value), eT >::result normpdf(const eT x, const eT mu, const eT sigma) { const eT tmp = (x - mu) / sigma; - const eT out = std::exp(-0.5 * (tmp*tmp)) / (sigma * Datum::sqrt2pi); + const eT out = std::exp(eT(-0.5) * (tmp*tmp)) / (sigma * Datum::sqrt2pi); return out; } @@ -116,6 +120,7 @@ normpdf(const eT x, const eT mu, const eT sigma) template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normpdf(const eT x, const Base& M_expr, const Base& S_expr) @@ -135,6 +140,7 @@ normpdf(const eT x, const Base& M_expr, const Base& S_expr) template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normpdf(const Base& X_expr) @@ -156,6 +162,7 @@ normpdf(const Base& X_expr) template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normpdf(const Base& X_expr, const typename T1::elem_type mu, const typename T1::elem_type sigma) @@ -177,6 +184,7 @@ normpdf(const Base& X_expr, const typename T1::elem_ template +arma_warn_unused inline typename enable_if2< (is_real::value), Mat >::result normpdf(const Base& X_expr, const Base& M_expr, const Base& S_expr) diff --git a/src/armadillo_bits/fn_numel.hpp b/src/armadillo_bits/fn_numel.hpp index 4b532ea9..fe5c1915 100644 --- a/src/armadillo_bits/fn_numel.hpp +++ b/src/armadillo_bits/fn_numel.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_ones.hpp b/src/armadillo_bits/fn_ones.hpp index 248da732..ae8b6220 100644 --- a/src/armadillo_bits/fn_ones.hpp +++ b/src/armadillo_bits/fn_ones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,20 +37,16 @@ template arma_warn_unused arma_inline const Gen -ones(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +ones(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); - if(is_Row::value) - { - return Gen(1, n_elem); - } - else - { - return Gen(n_elem, 1); - } + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return Gen(n_rows, n_cols); } @@ -81,20 +79,13 @@ template arma_warn_unused inline const Gen -ones(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = 0) +ones(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_Col::value) - { - arma_debug_check( (n_cols != 1), "ones(): incompatible size" ); - } - else - if(is_Row::value) - { - arma_debug_check( (n_rows != 1), "ones(): incompatible size" ); - } + if(is_Col::value) { arma_debug_check( (n_cols != 1), "ones(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "ones(): incompatible size" ); } return Gen(n_rows, n_cols); } @@ -105,7 +96,7 @@ template arma_warn_unused inline const Gen -ones(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = 0) +ones(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -143,7 +134,7 @@ template arma_warn_unused arma_inline const GenCube -ones(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = 0) +ones(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -157,7 +148,7 @@ template arma_warn_unused arma_inline const GenCube -ones(const SizeCube& s, const typename arma_Cube_only::result* junk = 0) +ones(const SizeCube& s, const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_orth_null.hpp b/src/armadillo_bits/fn_orth_null.hpp index c48580ec..fe689065 100644 --- a/src/armadillo_bits/fn_orth_null.hpp +++ b/src/armadillo_bits/fn_orth_null.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -45,7 +47,8 @@ orth(Mat& out, const Base& X if(status == false) { - arma_debug_warn("orth(): svd failed"); + out.soft_reset(); + arma_debug_warn_level(3, "orth(): svd failed"); } return status; @@ -83,7 +86,8 @@ null(Mat& out, const Base& X if(status == false) { - arma_debug_warn("null(): svd failed"); + out.soft_reset(); + arma_debug_warn_level(3, "null(): svd failed"); } return status; diff --git a/src/armadillo_bits/fn_pinv.hpp b/src/armadillo_bits/fn_pinv.hpp index 7424c14f..6a873227 100644 --- a/src/armadillo_bits/fn_pinv.hpp +++ b/src/armadillo_bits/fn_pinv.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,6 +21,22 @@ +template +arma_warn_unused +inline +typename enable_if2< is_real::value, const Op >::result +pinv + ( + const Base& X + ) + { + arma_extra_debug_sigprint(); + + return Op(X.get_ref()); + } + + + template arma_warn_unused inline @@ -26,19 +44,27 @@ typename enable_if2< is_real::value, const Op& X, - const typename T1::pod_type tol = 0.0, - const char* method = "dc" + const typename T1::pod_type tol, + const char* method = nullptr ) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const char sig = (method != NULL) ? method[0] : char(0); + uword method_id = 0; // default setting - arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + if(method != nullptr) + { + const char sig = method[0]; + + arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + + if(sig == 's') { method_id = 1; } + if(sig == 'd') { method_id = 2; } + } - return (sig == 'd') ? Op(X.get_ref(), eT(tol), 1, 0) : Op(X.get_ref(), eT(tol), 0, 0); + return Op(X.get_ref(), eT(tol), method_id, uword(0)); } @@ -51,22 +77,29 @@ pinv Mat& out, const Base& X, const typename T1::pod_type tol = 0.0, - const char* method = "dc" + const char* method = nullptr ) { arma_extra_debug_sigprint(); - const char sig = (method != NULL) ? method[0] : char(0); - - arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + uword method_id = 0; // default setting - const bool use_divide_and_conquer = (sig == 'd'); + if(method != nullptr) + { + const char sig = method[0]; + + arma_debug_check( ((sig != 's') && (sig != 'd')), "pinv(): unknown method specified" ); + + if(sig == 's') { method_id = 1; } + if(sig == 'd') { method_id = 2; } + } - const bool status = op_pinv::apply_direct(out, X.get_ref(), tol, use_divide_and_conquer); + const bool status = op_pinv::apply_direct(out, X.get_ref(), tol, method_id); if(status == false) { - arma_debug_warn("pinv(): svd failed"); + out.soft_reset(); + arma_debug_warn_level(3, "pinv(): svd failed"); } return status; diff --git a/src/armadillo_bits/fn_polyfit.hpp b/src/armadillo_bits/fn_polyfit.hpp index 1ff1be07..e51e37bc 100644 --- a/src/armadillo_bits/fn_polyfit.hpp +++ b/src/armadillo_bits/fn_polyfit.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -36,7 +38,7 @@ polyfit(Mat& out, const Base if(status == false) { out.soft_reset(); - arma_debug_warn("polyfit(): failed"); + arma_debug_warn_level(3, "polyfit(): failed"); } return status; diff --git a/src/armadillo_bits/fn_polyval.hpp b/src/armadillo_bits/fn_polyval.hpp index 47c88093..f22c7280 100644 --- a/src/armadillo_bits/fn_polyval.hpp +++ b/src/armadillo_bits/fn_polyval.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_powext.hpp b/src/armadillo_bits/fn_powext.hpp new file mode 100644 index 00000000..a971219d --- /dev/null +++ b/src/armadillo_bits/fn_powext.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_powext +//! @{ + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value, + const Glue + >::result +pow + ( + const T1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return Glue(X, Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Mat +pow + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext::apply(X,Y); + } + + + +template +arma_warn_unused +arma_inline +const GlueCube +pow + ( + const BaseCube& X, + const BaseCube& Y + ) + { + arma_extra_debug_sigprint(); + + return GlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Cube +pow + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext::apply(X,Y); + } + + + +// + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + ( is_arma_type::value && is_cx::yes ), + const mtGlue + >::result +pow + ( + const T1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, Y.get_ref()); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_cx::yes, + Mat + >::result +pow + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext_cx::apply(X,Y); + } + + + +template +arma_warn_unused +arma_inline +const mtGlueCube +pow + ( + const BaseCube< std::complex, T1>& X, + const BaseCube< typename T1::pod_type , T2>& Y + ) + { + arma_extra_debug_sigprint(); + + return mtGlueCube(X.get_ref(), Y.get_ref()); + } + + + +template +arma_warn_unused +inline +Cube< std::complex > +pow + ( + const subview_cube_each1< std::complex >& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + return glue_powext_cx::apply(X,Y); + } + + + +//! @} diff --git a/src/armadillo_bits/fn_powmat.hpp b/src/armadillo_bits/fn_powmat.hpp new file mode 100644 index 00000000..17d0293a --- /dev/null +++ b/src/armadillo_bits/fn_powmat.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_powmat +//! @{ + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const Op >::result +powmat(const Base& X, const int y) + { + arma_extra_debug_sigprint(); + + const uword aux_a = (y < int(0)) ? uword(-y) : uword(y); + const uword aux_b = (y < int(0)) ? uword(1) : uword(0); + + return Op(X.get_ref(), aux_a, aux_b); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +powmat + ( + Mat& out, + const Base& X, + const int y + ) + { + arma_extra_debug_sigprint(); + + const uword y_val = (y < int(0)) ? uword(-y) : uword(y); + const bool y_neg = (y < int(0)); + + const bool status = op_powmat::apply_direct(out, X.get_ref(), y_val, y_neg); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "powmat(): transformation failed"); + } + + return status; + } + + + +template +arma_warn_unused +inline +typename enable_if2< is_supported_blas_type::value, const mtOp,T1,op_powmat_cx> >::result +powmat(const Base& X, const double y) + { + arma_extra_debug_sigprint(); + + typedef std::complex out_eT; + + return mtOp('j', X.get_ref(), out_eT(y)); + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +powmat + ( + Mat< std::complex >& out, + const Base& X, + const double y + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + + const bool status = op_powmat_cx::apply_direct(out, X.get_ref(), T(y)); + + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "powmat(): transformation failed"); + } + + return status; + } + + +//! @} diff --git a/src/armadillo_bits/fn_princomp.hpp b/src/armadillo_bits/fn_princomp.hpp index 2d382f24..0a251b51 100644 --- a/src/armadillo_bits/fn_princomp.hpp +++ b/src/armadillo_bits/fn_princomp.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,7 +37,7 @@ princomp Col& latent_out, Col& tsquared_out, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -50,7 +52,7 @@ princomp latent_out.soft_reset(); tsquared_out.soft_reset(); - arma_debug_warn("princomp(): decomposition failed"); + arma_debug_warn_level(3, "princomp(): decomposition failed"); } return status; @@ -72,7 +74,7 @@ princomp Mat& score_out, Col& latent_out, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -86,7 +88,7 @@ princomp score_out.soft_reset(); latent_out.soft_reset(); - arma_debug_warn("princomp(): decomposition failed"); + arma_debug_warn_level(3, "princomp(): decomposition failed"); } return status; @@ -106,7 +108,7 @@ princomp Mat& coeff_out, Mat& score_out, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -119,7 +121,7 @@ princomp coeff_out.soft_reset(); score_out.soft_reset(); - arma_debug_warn("princomp(): decomposition failed"); + arma_debug_warn_level(3, "princomp(): decomposition failed"); } return status; @@ -137,7 +139,7 @@ princomp ( Mat& coeff_out, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -149,7 +151,7 @@ princomp { coeff_out.soft_reset(); - arma_debug_warn("princomp(): decomposition failed"); + arma_debug_warn_level(3, "princomp(): decomposition failed"); } return status; @@ -164,7 +166,7 @@ const Op princomp ( const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/fn_prod.hpp b/src/armadillo_bits/fn_prod.hpp index c4005093..c15110ff 100644 --- a/src/armadillo_bits/fn_prod.hpp +++ b/src/armadillo_bits/fn_prod.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,8 +23,8 @@ //! \brief //! Delayed product of elements of a matrix along a specified dimension (either rows or columns). //! The result is stored in a dense matrix that has either one column or one row. -//! For dim = 0, find the sum of each column (i.e. traverse across rows) -//! For dim = 1, find the sum of each row (i.e. traverse across columns) +//! For dim = 0, find the sum of each column (ie. traverse across rows) +//! For dim = 1, find the sum of each row (ie. traverse across columns) //! The default is dim = 0. //! NOTE: this function works differently than in Matlab/Octave. diff --git a/src/armadillo_bits/fn_qr.hpp b/src/armadillo_bits/fn_qr.hpp index d2dd9311..3d49a1b3 100644 --- a/src/armadillo_bits/fn_qr.hpp +++ b/src/armadillo_bits/fn_qr.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,13 +30,13 @@ qr Mat& Q, Mat& R, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - arma_debug_check( (&Q == &R), "qr(): Q and R are the same object"); + arma_debug_check( (&Q == &R), "qr(): Q and R are the same object" ); const bool status = auxlib::qr(Q, R, X); @@ -42,7 +44,7 @@ qr { Q.soft_reset(); R.soft_reset(); - arma_debug_warn("qr(): decomposition failed"); + arma_debug_warn_level(3, "qr(): decomposition failed"); } return status; @@ -59,13 +61,13 @@ qr_econ Mat& Q, Mat& R, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - arma_debug_check( (&Q == &R), "qr_econ(): Q and R are the same object"); + arma_debug_check( (&Q == &R), "qr_econ(): Q and R are the same object" ); const bool status = auxlib::qr_econ(Q, R, X); @@ -73,7 +75,66 @@ qr_econ { Q.soft_reset(); R.soft_reset(); - arma_debug_warn("qr_econ(): decomposition failed"); + arma_debug_warn_level(3, "qr_econ(): decomposition failed"); + } + + return status; + } + + + +//! QR decomposition with pivoting +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +qr + ( + Mat& Q, + Mat& R, + Mat& P, + const Base& X, + const char* P_mode = "matrix" + ) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (&Q == &R), "qr(): Q and R are the same object" ); + + const char sig = (P_mode != nullptr) ? P_mode[0] : char(0); + + arma_debug_check( ((sig != 'm') && (sig != 'v')), "qr(): argument 'P_mode' must be \"vector\" or \"matrix\"" ); + + bool status = false; + + if(sig == 'v') + { + status = auxlib::qr_pivot(Q, R, P, X); + } + else + if(sig == 'm') + { + Mat P_vec; + + status = auxlib::qr_pivot(Q, R, P_vec, X); + + if(status) + { + // construct P + + const uword N = P_vec.n_rows; + + P.zeros(N,N); + + for(uword row=0; row < N; ++row) { P.at(P_vec[row], row) = uword(1); } + } + } + + if(status == false) + { + Q.soft_reset(); + R.soft_reset(); + P.soft_reset(); + arma_debug_warn_level(3, "qr(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_quantile.hpp b/src/armadillo_bits/fn_quantile.hpp new file mode 100644 index 00000000..6c1ea2b0 --- /dev/null +++ b/src/armadillo_bits/fn_quantile.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_quantile +//! @{ + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no && is_real::value, + const mtGlue + >::result +quantile(const T1& X, const Base& P) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, P.get_ref()); + } + + + +template +arma_warn_unused +arma_inline +typename +enable_if2 + < + is_arma_type::value && is_cx::no && is_real::value, + const mtGlue + >::result +quantile(const T1& X, const Base& P, const uword dim) + { + arma_extra_debug_sigprint(); + + return mtGlue(X, P.get_ref(), dim); + } + + +//! @} diff --git a/src/armadillo_bits/fn_qz.hpp b/src/armadillo_bits/fn_qz.hpp index d6d1da95..9979dfad 100644 --- a/src/armadillo_bits/fn_qz.hpp +++ b/src/armadillo_bits/fn_qz.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -41,7 +43,7 @@ qz { arma_extra_debug_sigprint(); - const char sig = (select != NULL) ? select[0] : char(0); + const char sig = (select != nullptr) ? select[0] : char(0); arma_debug_check( ( (sig != 'n') && (sig != 'l') && (sig != 'r') && (sig != 'i') && (sig != 'o') ), "qz(): unknown select form" ); @@ -53,7 +55,7 @@ qz BB.soft_reset(); Q.soft_reset(); Z.soft_reset(); - arma_debug_warn("qz(): decomposition failed"); + arma_debug_warn_level(3, "qz(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_randg.hpp b/src/armadillo_bits/fn_randg.hpp index 8ef32713..a0e998a8 100644 --- a/src/armadillo_bits/fn_randg.hpp +++ b/src/armadillo_bits/fn_randg.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,76 +25,35 @@ template arma_warn_unused inline obj_type -randg(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = 0) +randg(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - #if defined(ARMA_USE_CXX11) + typedef typename obj_type::elem_type eT; + + if(is_Col::value) { - if(is_Col::value) - { - arma_debug_check( (n_cols != 1), "randg(): incompatible size" ); - } - else - if(is_Row::value) - { - arma_debug_check( (n_rows != 1), "randg(): incompatible size" ); - } - - obj_type out(n_rows, n_cols); - - double a; - double b; - - if(param.state == 0) - { - a = double(1); - b = double(1); - } - else - if(param.state == 1) - { - a = double(param.a_int); - b = double(param.b_int); - } - else - { - a = param.a_double; - b = param.b_double; - } - - arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): a and b must be greater than zero" ); - - #if defined(ARMA_USE_EXTERN_CXX11_RNG) - { - arma_rng_cxx11_instance.randg_fill(out.memptr(), out.n_elem, a, b); - } - #else - { - arma_rng_cxx11 local_arma_rng_cxx11_instance; - - typedef typename arma_rng_cxx11::seed_type seed_type; - - local_arma_rng_cxx11_instance.set_seed( seed_type(arma_rng::randi()) ); - - local_arma_rng_cxx11_instance.randg_fill(out.memptr(), out.n_elem, a, b); - } - #endif - - return out; + arma_debug_check( (n_cols != 1), "randg(): incompatible size" ); } - #else + else + if(is_Row::value) { - arma_ignore(n_rows); - arma_ignore(n_cols); - arma_ignore(param); - - arma_stop_logic_error("randg(): C++11 compiler required"); - - return obj_type(); + arma_debug_check( (n_rows != 1), "randg(): incompatible size" ); } - #endif + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + arma_rng::randg::fill(out.memptr(), out.n_elem, a, b); + + return out; } @@ -101,7 +62,7 @@ template arma_warn_unused inline obj_type -randg(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = 0) +randg(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -115,20 +76,16 @@ template arma_warn_unused inline obj_type -randg(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +randg(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); - if(is_Row::value) - { - return randg(1, n_elem, param); - } - else - { - return randg(n_elem, 1, param); - } + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return randg(n_rows, n_cols, param); } @@ -176,7 +133,18 @@ randg(const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return as_scalar( randg(uword(1), uword(1), param) ); + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + double out_val = double(0); + + arma_rng::randg::fill(&out_val, uword(1), a, b); + + return out_val; } @@ -187,7 +155,20 @@ inline typename arma_real_or_cx_only::result randg(const distr_param& param = distr_param()) { - return eT( as_scalar( randg< Col >(uword(1), uword(1), param) ) ); + arma_extra_debug_sigprint(); + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + eT out_val = eT(0); + + arma_rng::randg::fill(&out_val, uword(1), a, b); + + return out_val; } @@ -196,67 +177,25 @@ template arma_warn_unused inline cube_type -randg(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = 0) +randg(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - #if defined(ARMA_USE_CXX11) - { - cube_type out(n_rows, n_cols, n_slices); - - double a; - double b; - - if(param.state == 0) - { - a = double(1); - b = double(1); - } - else - if(param.state == 1) - { - a = double(param.a_int); - b = double(param.b_int); - } - else - { - a = param.a_double; - b = param.b_double; - } - - arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): a and b must be greater than zero" ); - - #if defined(ARMA_USE_EXTERN_CXX11_RNG) - { - arma_rng_cxx11_instance.randg_fill(out.memptr(), out.n_elem, a, b); - } - #else - { - arma_rng_cxx11 local_arma_rng_cxx11_instance; - - typedef typename arma_rng_cxx11::seed_type seed_type; - - local_arma_rng_cxx11_instance.set_seed( seed_type(arma_rng::randi()) ); - - local_arma_rng_cxx11_instance.randg_fill(out.memptr(), out.n_elem, a, b); - } - #endif - - return out; - } - #else - { - arma_ignore(n_rows); - arma_ignore(n_cols); - arma_ignore(n_slices); - arma_ignore(param); - - arma_stop_logic_error("randg(): C++11 compiler required"); - - return cube_type(); - } - #endif + typedef typename cube_type::elem_type eT; + + double a = double(1); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( ((a <= double(0)) || (b <= double(0))), "randg(): incorrect distribution parameters; a and b must be greater than zero" ); + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + arma_rng::randg::fill(out.memptr(), out.n_elem, a, b); + + return out; } @@ -265,7 +204,7 @@ template arma_warn_unused inline cube_type -randg(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = 0) +randg(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_randi.hpp b/src/armadillo_bits/fn_randi.hpp index d2c914b1..2aae9b5e 100644 --- a/src/armadillo_bits/fn_randi.hpp +++ b/src/armadillo_bits/fn_randi.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,7 +25,7 @@ template arma_warn_unused inline obj_type -randi(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = 0) +randi(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -40,29 +42,14 @@ randi(const uword n_rows, const uword n_cols, const distr_param& param = distr_p arma_debug_check( (n_rows != 1), "randi(): incompatible size" ); } - obj_type out(n_rows, n_cols); + int a = 0; + int b = arma_rng::randi::max_val(); - int a; - int b; + param.get_int_vals(a,b); - if(param.state == 0) - { - a = 0; - b = arma_rng::randi::max_val(); - } - else - if(param.state == 1) - { - a = param.a_int; - b = param.b_int; - } - else - { - a = int(param.a_double); - b = int(param.b_double); - } + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); - arma_debug_check( (a > b), "randi(): incorrect distribution parameters: a must be less than b" ); + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); arma_rng::randi::fill(out.memptr(), out.n_elem, a, b); @@ -75,7 +62,7 @@ template arma_warn_unused inline obj_type -randi(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = 0) +randi(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -89,7 +76,7 @@ template arma_warn_unused inline obj_type -randi(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +randi(const uword n_elem, const distr_param& param = distr_param(), const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk1); @@ -148,7 +135,20 @@ inline sword randi(const distr_param& param) { - return as_scalar( randi(uword(1), uword(1), param) ); + arma_extra_debug_sigprint(); + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + sword out_val = sword(0); + + arma_rng::randi::fill(&out_val, uword(1), a, b); + + return out_val; } @@ -159,7 +159,20 @@ inline typename arma_scalar_only::result randi(const distr_param& param) { - return eT( as_scalar( randi< Col >(uword(1), uword(1), param) ) ); + arma_extra_debug_sigprint(); + + int a = 0; + int b = arma_rng::randi::max_val(); + + param.get_int_vals(a,b); + + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); + + eT out_val = eT(0); + + arma_rng::randi::fill(&out_val, uword(1), a, b); + + return out_val; } @@ -169,6 +182,8 @@ inline sword randi() { + arma_extra_debug_sigprint(); + return sword( arma_rng::randi() ); } @@ -180,6 +195,8 @@ inline typename arma_scalar_only::result randi() { + arma_extra_debug_sigprint(); + return eT( arma_rng::randi() ); } @@ -189,36 +206,21 @@ template arma_warn_unused inline cube_type -randi(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = 0) +randi(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); typedef typename cube_type::elem_type eT; - cube_type out(n_rows, n_cols, n_slices); + int a = 0; + int b = arma_rng::randi::max_val(); - int a; - int b; + param.get_int_vals(a,b); - if(param.state == 0) - { - a = 0; - b = arma_rng::randi::max_val(); - } - else - if(param.state == 1) - { - a = param.a_int; - b = param.b_int; - } - else - { - a = int(param.a_double); - b = int(param.b_double); - } + arma_debug_check( (a > b), "randi(): incorrect distribution parameters; a must be less than b" ); - arma_debug_check( (a > b), "randi(): incorrect distribution parameters: a must be less than b" ); + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); arma_rng::randi::fill(out.memptr(), out.n_elem, a, b); @@ -231,7 +233,7 @@ template arma_warn_unused inline cube_type -randi(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = 0) +randi(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_randn.hpp b/src/armadillo_bits/fn_randn.hpp index e0b39358..37cb3db9 100644 --- a/src/armadillo_bits/fn_randn.hpp +++ b/src/armadillo_bits/fn_randn.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,11 +21,15 @@ +// scalars + arma_warn_unused inline double randn() { + arma_extra_debug_sigprint(); + return double(arma_rng::randn()); } @@ -35,157 +41,315 @@ inline typename arma_real_or_cx_only::result randn() { + arma_extra_debug_sigprint(); + return eT(arma_rng::randn()); } -//! Generate a vector with all elements set to random values with a gaussian distribution (zero mean, unit variance) arma_warn_unused -arma_inline -const Gen -randn(const uword n_elem) +inline +double +randn(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return double(arma_rng::randn()); } + + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + const double val = double(arma_rng::randn()); + + return ((val * sd) + mu); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randn(const distr_param& param) { arma_extra_debug_sigprint(); - return Gen(n_elem, 1); + if(param.state == 0) { return eT(arma_rng::randn()); } + + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + eT val = eT(0); + + arma_rng::randn::fill(&val, 1, mu, sd); // using fill() as eT can be complex + + return val; + } + + + +// vectors + +arma_warn_unused +inline +vec +randn(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + vec out(n_elem, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), n_elem, mu, sd); + } + + return out; } template arma_warn_unused -arma_inline -const Gen -randn(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +inline +obj_type +randn(const uword n_elem, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); - arma_ignore(junk1); - arma_ignore(junk2); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); - if(is_Row::value) + if(param.state == 0) { - return Gen(1, n_elem); + arma_rng::randn::fill(out.memptr(), out.n_elem); } else { - return Gen(n_elem, 1); + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); } + + return out; } -//! Generate a dense matrix with all elements set to random values with a gaussian distribution (zero mean, unit variance) +// matrices + arma_warn_unused -arma_inline -const Gen -randn(const uword n_rows, const uword n_cols) +inline +mat +randn(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return Gen(n_rows, n_cols); + mat out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; } arma_warn_unused -arma_inline -const Gen -randn(const SizeMat& s) +inline +mat +randn(const SizeMat& s, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return Gen(s.n_rows, s.n_cols); + return randn(s.n_rows, s.n_cols, param); } template arma_warn_unused -arma_inline -const Gen -randn(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = 0) +inline +obj_type +randn(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_Col::value) + typedef typename obj_type::elem_type eT; + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "randn(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "randn(): incompatible size" ); } + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) { - arma_debug_check( (n_cols != 1), "randn(): incompatible size" ); + arma_rng::randn::fill(out.memptr(), out.n_elem); } else - if(is_Row::value) { - arma_debug_check( (n_rows != 1), "randn(): incompatible size" ); + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); } - return Gen(n_rows, n_cols); + return out; } template arma_warn_unused -arma_inline -const Gen -randn(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = 0) +inline +obj_type +randn(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - return randn(s.n_rows, s.n_cols); + return randn(s.n_rows, s.n_cols, param); } +// cubes + + arma_warn_unused -arma_inline -const GenCube -randn(const uword n_rows, const uword n_cols, const uword n_slices) +inline +cube +randn(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return GenCube(n_rows, n_cols, n_slices); + cube out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; } arma_warn_unused -arma_inline -const GenCube -randn(const SizeCube& s) +inline +cube +randn(const SizeCube& s, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return GenCube(s.n_rows, s.n_cols, s.n_slices); + return randn(s.n_rows, s.n_cols, s.n_slices, param); } template arma_warn_unused -arma_inline -const GenCube -randn(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = 0) +inline +cube_type +randn(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint(); arma_ignore(junk); - return GenCube(n_rows, n_cols, n_slices); + typedef typename cube_type::elem_type eT; + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randn::fill(out.memptr(), out.n_elem); + } + else + { + double mu = double(0); + double sd = double(1); + + param.get_double_vals(mu,sd); + + arma_debug_check( (sd <= double(0)), "randn(): incorrect distribution parameters; standard deviation must be > 0" ); + + arma_rng::randn::fill(out.memptr(), out.n_elem, mu, sd); + } + + return out; } template arma_warn_unused -arma_inline -const GenCube -randn(const SizeCube& s, const typename arma_Cube_only::result* junk = 0) +inline +cube_type +randn(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint(); arma_ignore(junk); - return GenCube(s.n_rows, s.n_cols, s.n_slices); + return randn(s.n_rows, s.n_cols, s.n_slices, param); } diff --git a/src/armadillo_bits/fn_randperm.hpp b/src/armadillo_bits/fn_randperm.hpp index a0bacdee..19623a75 100644 --- a/src/armadillo_bits/fn_randperm.hpp +++ b/src/armadillo_bits/fn_randperm.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,6 +31,7 @@ internal_randperm_helper(obj_type& x, const uword N, const uword N_keep) typedef typename obj_type::elem_type eT; // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet + // and the associated comparison functor typedef arma_sort_index_packet packet; @@ -118,7 +121,7 @@ randperm(const uword N, const uword M) { arma_extra_debug_sigprint(); - arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'"); + arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'" ); obj_type x; @@ -136,7 +139,7 @@ randperm(const uword N, const uword M) { arma_extra_debug_sigprint(); - arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'"); + arma_debug_check( (M > N), "randperm(): 'M' must be less than or equal to 'N'" ); uvec x; diff --git a/src/armadillo_bits/fn_randu.hpp b/src/armadillo_bits/fn_randu.hpp index 52eb4bd8..432c1716 100644 --- a/src/armadillo_bits/fn_randu.hpp +++ b/src/armadillo_bits/fn_randu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,11 +21,15 @@ +// scalars + arma_warn_unused inline double randu() { + arma_extra_debug_sigprint(); + return double(arma_rng::randu()); } @@ -35,157 +41,315 @@ inline typename arma_real_or_cx_only::result randu() { + arma_extra_debug_sigprint(); + return eT(arma_rng::randu()); } -//! Generate a vector with all elements set to random values in the [0,1] interval (uniform distribution) arma_warn_unused -arma_inline -const Gen -randu(const uword n_elem) +inline +double +randu(const distr_param& param) { arma_extra_debug_sigprint(); - return Gen(n_elem, 1); + if(param.state == 0) { return double(arma_rng::randu()); } + + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + const double val = double(arma_rng::randu()); + + return ((val * (b - a)) + a); + } + + + +template +arma_warn_unused +inline +typename arma_real_or_cx_only::result +randu(const distr_param& param) + { + arma_extra_debug_sigprint(); + + if(param.state == 0) { return eT(arma_rng::randu()); } + + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + eT val = eT(0); + + arma_rng::randu::fill(&val, 1, a, b); // using fill() as eT can be complex + + return val; + } + + + +// vectors + +arma_warn_unused +inline +vec +randu(const uword n_elem, const distr_param& param = distr_param()) + { + arma_extra_debug_sigprint(); + + vec out(n_elem, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), n_elem, a, b); + } + + return out; } template arma_warn_unused -arma_inline -const Gen -randu(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +inline +obj_type +randu(const uword n_elem, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); - arma_ignore(junk1); - arma_ignore(junk2); + arma_ignore(junk); + + typedef typename obj_type::elem_type eT; + + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); - if(is_Row::value) + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) { - return Gen(1, n_elem); + arma_rng::randu::fill(out.memptr(), out.n_elem); } else { - return Gen(n_elem, 1); + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); } + + return out; } -//! Generate a dense matrix with all elements set to random values in the [0,1] interval (uniform distribution) +// matrices + arma_warn_unused -arma_inline -const Gen -randu(const uword n_rows, const uword n_cols) +inline +mat +randu(const uword n_rows, const uword n_cols, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return Gen(n_rows, n_cols); + mat out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; } arma_warn_unused -arma_inline -const Gen -randu(const SizeMat& s) +inline +mat +randu(const SizeMat& s, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return Gen(s.n_rows, s.n_cols); + return randu(s.n_rows, s.n_cols, param); } template arma_warn_unused -arma_inline -const Gen -randu(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = 0) +inline +obj_type +randu(const uword n_rows, const uword n_cols, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_Col::value) + typedef typename obj_type::elem_type eT; + + if(is_Col::value) { arma_debug_check( (n_cols != 1), "randu(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "randu(): incompatible size" ); } + + obj_type out(n_rows, n_cols, arma_nozeros_indicator()); + + if(param.state == 0) { - arma_debug_check( (n_cols != 1), "randu(): incompatible size" ); + arma_rng::randu::fill(out.memptr(), out.n_elem); } else - if(is_Row::value) { - arma_debug_check( (n_rows != 1), "randu(): incompatible size" ); + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); } - return Gen(n_rows, n_cols); + return out; } template arma_warn_unused -arma_inline -const Gen -randu(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = 0) +inline +obj_type +randu(const SizeMat& s, const distr_param& param = distr_param(), const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - return randu(s.n_rows, s.n_cols); + return randu(s.n_rows, s.n_cols, param); } +// cubes + + arma_warn_unused -arma_inline -const GenCube -randu(const uword n_rows, const uword n_cols, const uword n_slices) +inline +cube +randu(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return GenCube(n_rows, n_cols, n_slices); + cube out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; } arma_warn_unused -arma_inline -const GenCube -randu(const SizeCube& s) +inline +cube +randu(const SizeCube& s, const distr_param& param = distr_param()) { arma_extra_debug_sigprint(); - return GenCube(s.n_rows, s.n_cols, s.n_slices); + return randu(s.n_rows, s.n_cols, s.n_slices, param); } template arma_warn_unused -arma_inline -const GenCube -randu(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = 0) +inline +cube_type +randu(const uword n_rows, const uword n_cols, const uword n_slices, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - return GenCube(n_rows, n_cols, n_slices); + typedef typename cube_type::elem_type eT; + + cube_type out(n_rows, n_cols, n_slices, arma_nozeros_indicator()); + + if(param.state == 0) + { + arma_rng::randu::fill(out.memptr(), out.n_elem); + } + else + { + double a = double(0); + double b = double(1); + + param.get_double_vals(a,b); + + arma_debug_check( (a >= b), "randu(): incorrect distribution parameters; a must be less than b" ); + + arma_rng::randu::fill(out.memptr(), out.n_elem, a, b); + } + + return out; } template arma_warn_unused -arma_inline -const GenCube -randu(const SizeCube& s, const typename arma_Cube_only::result* junk = 0) +inline +cube_type +randu(const SizeCube& s, const distr_param& param = distr_param(), const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - return GenCube(s.n_rows, s.n_cols, s.n_slices); + return randu(s.n_rows, s.n_cols, s.n_slices, param); } diff --git a/src/armadillo_bits/fn_range.hpp b/src/armadillo_bits/fn_range.hpp index b60c04a7..3a280945 100644 --- a/src/armadillo_bits/fn_range.hpp +++ b/src/armadillo_bits/fn_range.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_rank.hpp b/src/armadillo_bits/fn_rank.hpp index 65c833bd..7701a049 100644 --- a/src/armadillo_bits/fn_rank.hpp +++ b/src/armadillo_bits/fn_rank.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,46 +24,32 @@ template arma_warn_unused inline -uword -rank - ( - const Base& X, - typename T1::pod_type tol = 0.0, - const typename arma_blas_type_only::result* junk = 0 - ) +typename enable_if2< is_supported_blas_type::value, uword >::result +rank(const Base& expr, const typename T1::pod_type tol = 0) { arma_extra_debug_sigprint(); - arma_ignore(junk); - - typedef typename T1::pod_type T; - - uword X_n_rows; - uword X_n_cols; - Col s; - const bool status = auxlib::svd_dc(s, X, X_n_rows, X_n_cols); + uword out = uword(0); - if(status == false) - { - arma_stop_runtime_error("rank(): svd failed"); - - return uword(0); - } + const bool status = op_rank::apply(out, expr.get_ref(), tol); - const uword s_n_elem = s.n_elem; - const T* s_mem = s.memptr(); + if(status == false) { arma_stop_runtime_error("rank(): failed"); return uword(0); } - // set tolerance to default if it hasn't been specified - if( (tol == T(0)) && (s_n_elem > 0) ) - { - tol = (std::max)(X_n_rows, X_n_cols) * s_mem[0] * std::numeric_limits::epsilon(); - } - - uword count = 0; + return out; + } + + + +template +inline +typename enable_if2< is_supported_blas_type::value, bool >::result +rank(uword& out, const Base& expr, const typename T1::pod_type tol = 0) + { + arma_extra_debug_sigprint(); - for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } + out = uword(0); - return count; + return op_rank::apply(out, expr.get_ref(), tol); } diff --git a/src/armadillo_bits/fn_regspace.hpp b/src/armadillo_bits/fn_regspace.hpp index be6db357..83e7de2c 100644 --- a/src/armadillo_bits/fn_regspace.hpp +++ b/src/armadillo_bits/fn_regspace.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_repelem.hpp b/src/armadillo_bits/fn_repelem.hpp index 6d286324..5d1e8177 100644 --- a/src/armadillo_bits/fn_repelem.hpp +++ b/src/armadillo_bits/fn_repelem.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_repmat.hpp b/src/armadillo_bits/fn_repmat.hpp index cd708c75..113bfb39 100644 --- a/src/armadillo_bits/fn_repmat.hpp +++ b/src/armadillo_bits/fn_repmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_reshape.hpp b/src/armadillo_bits/fn_reshape.hpp index d0f282d9..35bef091 100644 --- a/src/armadillo_bits/fn_reshape.hpp +++ b/src/armadillo_bits/fn_reshape.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,11 +25,11 @@ template arma_warn_unused inline typename enable_if2< is_arma_type::value, const Op >::result -reshape(const T1& X, const uword in_n_rows, const uword in_n_cols) +reshape(const T1& X, const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - return Op(X, in_n_rows, in_n_cols); + return Op(X, new_n_rows, new_n_cols); } @@ -45,51 +47,38 @@ reshape(const T1& X, const SizeMat& s) -//! NOTE: don't use this form: it will be removed template -arma_deprecated +arma_frown("don't use this form: it will be removed") inline -const Op -reshape(const Base& X, const uword in_n_rows, const uword in_n_cols, const uword dim) //!< NOTE: don't use this form: it will be removed +Mat +reshape(const Base& X, const uword new_n_rows, const uword new_n_cols, const uword dim) { arma_extra_debug_sigprint(); - arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); - - // arma_debug_warn("this form of reshape() is deprecated and will be removed"); + typedef typename T1::elem_type eT; - return Op(X.get_ref(), in_n_rows, in_n_cols, dim, 'j'); - } - - - -template -arma_warn_unused -inline -const OpCube -reshape(const BaseCube& X, const uword in_n_rows, const uword in_n_cols, const uword in_n_slices) - { - arma_extra_debug_sigprint(); + arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); - return OpCube(X.get_ref(), in_n_rows, in_n_cols, in_n_slices, uword(0), 'j'); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -const OpCube -reshape(const BaseCube& X, const uword in_n_rows, const uword in_n_cols, const uword in_n_slices, const uword dim) //!< NOTE: don't use this form: it will be removed - { - arma_extra_debug_sigprint(); + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; - arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); + Mat out; - // arma_debug_warn("this form of reshape() is deprecated and will be removed"); + if(dim == 0) + { + op_reshape::apply_mat_noalias(out, A, new_n_rows, new_n_cols); + } + else + if(dim == 1) + { + Mat tmp; + + op_strans::apply_mat_noalias(tmp, A); + + op_reshape::apply_mat_noalias(out, tmp, new_n_rows, new_n_cols); + } - return OpCube(X.get_ref(), in_n_rows, in_n_cols, in_n_slices, dim, 'j'); + return out; } @@ -97,30 +86,25 @@ reshape(const BaseCube& X, const uword in_n_rows, con template arma_warn_unused inline -const OpCube -reshape(const BaseCube& X, const SizeCube& s) +const OpCube +reshape(const BaseCube& X, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - return OpCube(X.get_ref(), s.n_rows, s.n_cols, s.n_slices, uword(0), 'j'); + return OpCube(X.get_ref(), new_n_rows, new_n_cols, new_n_slices); } -//! NOTE: don't use this form: it will be removed template -arma_deprecated +arma_warn_unused inline -const OpCube -reshape(const BaseCube& X, const SizeCube& s, const uword dim) //!< NOTE: don't use this form: it will be removed +const OpCube +reshape(const BaseCube& X, const SizeCube& s) { arma_extra_debug_sigprint(); - arma_debug_check( (dim > 1), "reshape(): parameter 'dim' must be 0 or 1" ); - - // arma_debug_warn("this form of reshape() is deprecated and will be removed"); - - return OpCube(X.get_ref(), s.n_rows, s.n_cols, s.n_slices, dim, 'j'); + return OpCube(X.get_ref(), s.n_rows, s.n_cols, s.n_slices); } @@ -129,11 +113,11 @@ template arma_warn_unused inline const SpOp -reshape(const SpBase& X, const uword in_n_rows, const uword in_n_cols) +reshape(const SpBase& X, const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); - return SpOp(X.get_ref(), in_n_rows, in_n_cols); + return SpOp(X.get_ref(), new_n_rows, new_n_cols); } diff --git a/src/armadillo_bits/fn_resize.hpp b/src/armadillo_bits/fn_resize.hpp index 9fab807b..7088290d 100644 --- a/src/armadillo_bits/fn_resize.hpp +++ b/src/armadillo_bits/fn_resize.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_reverse.hpp b/src/armadillo_bits/fn_reverse.hpp index 17500ee4..284c80dd 100644 --- a/src/armadillo_bits/fn_reverse.hpp +++ b/src/armadillo_bits/fn_reverse.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_roots.hpp b/src/armadillo_bits/fn_roots.hpp index b6fa537a..80fe240c 100644 --- a/src/armadillo_bits/fn_roots.hpp +++ b/src/armadillo_bits/fn_roots.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -51,7 +53,11 @@ roots(Mat< std::complex >& out, const Base& S, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -40,7 +42,7 @@ schur if(status == false) { S.soft_reset(); - arma_debug_warn("schur(): decomposition failed"); + arma_debug_warn_level(3, "schur(): decomposition failed"); } return status; @@ -55,7 +57,7 @@ Mat schur ( const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -87,7 +89,7 @@ schur Mat& U, Mat& S, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -101,7 +103,7 @@ schur { U.soft_reset(); S.soft_reset(); - arma_debug_warn("schur(): decomposition failed"); + arma_debug_warn_level(3, "schur(): decomposition failed"); } return status; diff --git a/src/armadillo_bits/fn_shift.hpp b/src/armadillo_bits/fn_shift.hpp index 7b1da38e..d3de6a7c 100644 --- a/src/armadillo_bits/fn_shift.hpp +++ b/src/armadillo_bits/fn_shift.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -39,19 +41,19 @@ shift const uword len = (N < 0) ? uword(-N) : uword(N); const uword neg = (N < 0) ? uword( 1) : uword(0); - return Op(X, len, neg, uword(0), 'j'); + return Op(X, len, neg); } template arma_warn_unused -arma_inline +inline typename enable_if2 < is_arma_type::value && resolves_to_vector::no, - const Op + Mat >::result shift ( @@ -61,22 +63,30 @@ shift { arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; + const uword len = (N < 0) ? uword(-N) : uword(N); const uword neg = (N < 0) ? uword( 1) : uword(0); - return Op(X, len, neg, uword(0), 'j'); + quasi_unwrap U(X); + + Mat out; + + op_shift::apply_noalias(out, U.M, len, neg, 0); + + return out; } template arma_warn_unused -arma_inline +inline typename enable_if2 < (is_arma_type::value), - const Op + Mat >::result shift ( @@ -87,10 +97,20 @@ shift { arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; + + arma_debug_check( (dim > 1), "shift(): parameter 'dim' must be 0 or 1" ); + const uword len = (N < 0) ? uword(-N) : uword(N); const uword neg = (N < 0) ? uword( 1) : uword(0); - return Op(X, len, neg, dim, 'j'); + quasi_unwrap U(X); + + Mat out; + + op_shift::apply_noalias(out, U.M, len, neg, dim); + + return out; } diff --git a/src/armadillo_bits/fn_shuffle.hpp b/src/armadillo_bits/fn_shuffle.hpp index 7b1eec64..a0e0f652 100644 --- a/src/armadillo_bits/fn_shuffle.hpp +++ b/src/armadillo_bits/fn_shuffle.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_size.hpp b/src/armadillo_bits/fn_size.hpp index 507bc6bb..b6ac80ec 100644 --- a/src/armadillo_bits/fn_size.hpp +++ b/src/armadillo_bits/fn_size.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -88,6 +90,38 @@ size(const Col& X) +arma_warn_unused +inline +const SizeMat +size(const arma::span& row_span, const arma::span& col_span) + { + arma_extra_debug_sigprint(); + + uword n_rows = 0; + uword n_cols = 0; + + if(row_span.whole || col_span.whole) + { + arma_debug_check(true, "size(): span::all not supported"); + } + else + { + if((row_span.a > row_span.b) || (col_span.a > col_span.b)) + { + arma_debug_check_bounds(true, "size(): span indices incorrectly used"); + } + else + { + n_rows = row_span.b - row_span.a + 1; + n_cols = col_span.b - col_span.a + 1; + } + } + + return SizeMat(n_rows, n_cols); + } + + + template arma_warn_unused inline @@ -159,6 +193,40 @@ size(const BaseCube& X, const uword dim) +arma_warn_unused +inline +const SizeCube +size(const arma::span& row_span, const arma::span& col_span, const arma::span& slice_span) + { + arma_extra_debug_sigprint(); + + uword n_rows = 0; + uword n_cols = 0; + uword n_slices = 0; + + if(row_span.whole || col_span.whole || slice_span.whole) + { + arma_debug_check(true, "size(): span::all not supported"); + } + else + { + if((row_span.a > row_span.b) || (col_span.a > col_span.b) || (slice_span.a > slice_span.b)) + { + arma_debug_check_bounds(true, "size(): span indices incorrectly used"); + } + else + { + n_rows = row_span.b - row_span.a + 1; + n_cols = col_span.b - col_span.a + 1; + n_slices = slice_span.b - slice_span.a + 1; + } + } + + return SizeCube(n_rows, n_cols, n_slices); + } + + + template arma_warn_unused inline diff --git a/src/armadillo_bits/fn_solve.hpp b/src/armadillo_bits/fn_solve.hpp index d5aed957..12ca693b 100644 --- a/src/armadillo_bits/fn_solve.hpp +++ b/src/armadillo_bits/fn_solve.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,84 +28,64 @@ template arma_warn_unused inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result +typename enable_if2< is_supported_blas_type::value, const Glue >::result solve ( const Base& A, - const Base& B, - const solve_opts::opts& opts = solve_opts::none + const Base& B ) { arma_extra_debug_sigprint(); - return Glue(A.get_ref(), B.get_ref(), opts.flags); + return Glue(A.get_ref(), B.get_ref()); } -//! NOTE: don't use this form: it will be removed template -arma_deprecated inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result +typename enable_if2< is_supported_blas_type::value, bool >::result solve ( + Mat& out, const Base& A, - const Base& B, - const bool // argument kept only for compatibility with old user code + const Base& B ) { arma_extra_debug_sigprint(); - // arma_debug_warn("solve(A,B,bool) is deprecated and will be removed; change to solve(A,B)"); + const bool status = glue_solve_gen_default::apply(out, A.get_ref(), B.get_ref()); - return Glue(A.get_ref(), B.get_ref(), solve_opts::flag_none); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result -solve - ( - const Base& A, - const Base& B, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("solve(A,B,char*) is deprecated and will be removed; change to solve(A,B)"); + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } - return Glue(A.get_ref(), B.get_ref(), solve_opts::flag_none); + return status; } template +arma_warn_unused inline -typename enable_if2< is_supported_blas_type::value, bool >::result +typename enable_if2< is_supported_blas_type::value, const Glue >::result solve ( - Mat& out, const Base& A, const Base& B, - const solve_opts::opts& opts = solve_opts::none + const solve_opts::opts& opts ) { arma_extra_debug_sigprint(); - return glue_solve_gen::apply(out, A.get_ref(), B.get_ref(), opts.flags); + return Glue(A.get_ref(), B.get_ref(), opts.flags); } -//! NOTE: don't use this form: it will be removed template -arma_deprecated inline typename enable_if2< is_supported_blas_type::value, bool >::result solve @@ -111,36 +93,20 @@ solve Mat& out, const Base& A, const Base& B, - const bool // argument kept only for compatibility with old user code + const solve_opts::opts& opts ) { arma_extra_debug_sigprint(); - // arma_debug_warn("solve(X,A,B,bool) is deprecated and will be removed; change to solve(X,A,B)"); + const bool status = glue_solve_gen_full::apply(out, A.get_ref(), B.get_ref(), opts.flags); - return glue_solve_gen::apply(out, A.get_ref(), B.get_ref(), solve_opts::flag_none); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, bool >::result -solve - ( - Mat& out, - const Base& A, - const Base& B, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } - // arma_debug_warn("solve(X,A,B,char*) is deprecated and will be removed; change to solve(X,A,B)"); - - return glue_solve_gen::apply(out, A.get_ref(), B.get_ref(), solve_opts::flag_none); + return status; } @@ -174,7 +140,7 @@ solve template arma_warn_unused inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result +typename enable_if2< is_supported_blas_type::value, const Glue >::result solve ( const Op& A, @@ -189,59 +155,7 @@ solve if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } - return Glue(A.m, B.get_ref(), flags); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result -solve - ( - const Op& A, - const Base& B, - const bool // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("solve(A,B,bool) is deprecated and will be removed; change to solve(A,B)"); - - uword flags = uword(0); - - if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } - if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } - - return Glue(A.m, B.get_ref(), flags); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, const Glue >::result -solve - ( - const Op& A, - const Base& B, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("solve(A,B,char*) is deprecated and will be removed; change to solve(A,B)"); - - uword flags = uword(0); - - if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } - if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } - - return Glue(A.m, B.get_ref(), flags); + return Glue(A.m, B.get_ref(), flags); } @@ -263,37 +177,20 @@ solve if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } - return glue_solve_tri_default::apply(out, A.m, B.get_ref(), flags); - } - - - -template -inline -typename enable_if2< is_supported_blas_type::value, bool >::result -solve - ( - Mat& out, - const Op& A, - const Base& B, - const solve_opts::opts& opts - ) - { - arma_extra_debug_sigprint(); + const bool status = glue_solve_tri_default::apply(out, A.m, B.get_ref(), flags); - uword flags = opts.flags; - - if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } - if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } - return glue_solve_tri::apply(out, A.m, B.get_ref(), flags); + return status; } -//! NOTE: don't use this form: it will be removed template -arma_deprecated inline typename enable_if2< is_supported_blas_type::value, bool >::result solve @@ -301,46 +198,25 @@ solve Mat& out, const Op& A, const Base& B, - const bool // argument kept only for compatibility with old user code + const solve_opts::opts& opts ) { arma_extra_debug_sigprint(); - // arma_debug_warn("solve(X,A,B,bool) is deprecated and will be removed; change to solve(X,A,B)"); - - uword flags = uword(0); + uword flags = opts.flags; if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } - return glue_solve_tri_default::apply(out, A.m, B.get_ref(), flags); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename enable_if2< is_supported_blas_type::value, bool >::result -solve - ( - Mat& out, - const Op& A, - const Base& B, - const char* // argument kept only for compatibility with old user code - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("solve(X,A,B,char*) is deprecated and will be removed; change to solve(X,A,B)"); + const bool status = glue_solve_tri_full::apply(out, A.m, B.get_ref(), flags); - uword flags = uword(0); - - if(A.aux_uword_a == 0) { flags |= solve_opts::flag_triu; } - if(A.aux_uword_a == 1) { flags |= solve_opts::flag_tril; } + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "solve(): solution not found"); + } - return glue_solve_tri_default::apply(out, A.m, B.get_ref(), flags); + return status; } diff --git a/src/armadillo_bits/fn_sort.hpp b/src/armadillo_bits/fn_sort.hpp index 215b80e5..01b45f2b 100644 --- a/src/armadillo_bits/fn_sort.hpp +++ b/src/armadillo_bits/fn_sort.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -61,82 +63,6 @@ sort -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename -enable_if2 - < - is_arma_type::value && resolves_to_vector::yes, - const Op - >::result -sort - ( - const T1& X, - const uword sort_type - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("sort(X,uword) is deprecated and will be removed; change to sort(X,sort_direction)"); - - return Op(X, sort_type, 0); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename -enable_if2 - < - is_arma_type::value && resolves_to_vector::no, - const Op - >::result -sort - ( - const T1& X, - const uword sort_type - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("sort(X,uword) is deprecated and will be removed; change to sort(X,sort_direction)"); - - return Op(X, sort_type, 0); - } - - - -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -typename -enable_if2 - < - (is_arma_type::value), - const Op - >::result -sort - ( - const T1& X, - const uword sort_type, - const uword dim - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("sort(X,uword,uword) is deprecated and will be removed; change to sort(X,sort_direction,dim)"); - - return Op(X, sort_type, dim); - } - - - template arma_warn_unused inline @@ -154,9 +80,9 @@ sort { arma_extra_debug_sigprint(); - const char sig = (sort_direction != NULL) ? sort_direction[0] : char(0); + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); - arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction"); + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); const uword sort_type = (sig == 'a') ? 0 : 1; @@ -182,9 +108,9 @@ sort { arma_extra_debug_sigprint(); - const char sig = (sort_direction != NULL) ? sort_direction[0] : char(0); + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); - arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction"); + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); const uword sort_type = (sig == 'a') ? 0 : 1; @@ -211,9 +137,9 @@ sort { arma_extra_debug_sigprint(); - const char sig = (sort_direction != NULL) ? sort_direction[0] : char(0); + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); - arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction"); + arma_debug_check( (sig != 'a') && (sig != 'd'), "sort(): unknown sort direction" ); const uword sort_type = (sig == 'a') ? 0 : 1; diff --git a/src/armadillo_bits/fn_sort_index.hpp b/src/armadillo_bits/fn_sort_index.hpp index 4ffbba10..1df3693c 100644 --- a/src/armadillo_bits/fn_sort_index.hpp +++ b/src/armadillo_bits/fn_sort_index.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,35 +37,13 @@ sort_index -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -const mtOp -sort_index - ( - const Base& X, - const uword sort_type - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("sort_index(X,uword) is deprecated and will be removed; change to sort_index(X,sort_direction)"); - - arma_debug_check( (sort_type > 1), "sort_index(): parameter 'sort_type' must be 0 or 1" ); - - return mtOp(X.get_ref(), sort_type, uword(0)); - } - - - template arma_warn_unused inline typename enable_if2 < - ( (is_arma_type::value == true) && (is_same_type::value == true) ), + ( (is_arma_type::value) && (is_same_type::value) ), const mtOp >::result sort_index @@ -74,7 +54,7 @@ sort_index { arma_extra_debug_sigprint(); - const char sig = (sort_direction != NULL) ? sort_direction[0] : char(0); + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); arma_debug_check( ((sig != 'a') && (sig != 'd')), "sort_index(): unknown sort direction" ); @@ -103,35 +83,13 @@ stable_sort_index -//! NOTE: don't use this form: it will be removed -template -arma_deprecated -inline -const mtOp -stable_sort_index - ( - const Base& X, - const uword sort_type - ) - { - arma_extra_debug_sigprint(); - - // arma_debug_warn("stable_sort_index(X,uword) is deprecated and will be removed; change to stable_sort_index(X,sort_direction)"); - - arma_debug_check( (sort_type > 1), "stable_sort_index(): parameter 'sort_type' must be 0 or 1" ); - - return mtOp(X.get_ref(), sort_type, uword(0)); - } - - - template arma_warn_unused inline typename enable_if2 < - ( (is_arma_type::value == true) && (is_same_type::value == true) ), + ( (is_arma_type::value) && (is_same_type::value) ), const mtOp >::result stable_sort_index @@ -142,7 +100,7 @@ stable_sort_index { arma_extra_debug_sigprint(); - const char sig = (sort_direction != NULL) ? sort_direction[0] : char(0); + const char sig = (sort_direction != nullptr) ? sort_direction[0] : char(0); arma_debug_check( ((sig != 'a') && (sig != 'd')), "stable_sort_index(): unknown sort direction" ); diff --git a/src/armadillo_bits/fn_speye.hpp b/src/armadillo_bits/fn_speye.hpp index 71a851da..48570be2 100644 --- a/src/armadillo_bits/fn_speye.hpp +++ b/src/armadillo_bits/fn_speye.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -24,20 +26,13 @@ template arma_warn_unused inline obj_type -speye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = NULL) +speye(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_SpCol::value) - { - arma_debug_check( (n_cols != 1), "speye(): incompatible size" ); - } - else - if(is_SpRow::value) - { - arma_debug_check( (n_rows != 1), "speye(): incompatible size" ); - } + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "speye(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "speye(): incompatible size" ); } obj_type out; @@ -52,7 +47,7 @@ template arma_warn_unused inline obj_type -speye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = NULL) +speye(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_spones.hpp b/src/armadillo_bits/fn_spones.hpp index 0697622b..ff45b218 100644 --- a/src/armadillo_bits/fn_spones.hpp +++ b/src/armadillo_bits/fn_spones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_sprandn.hpp b/src/armadillo_bits/fn_sprandn.hpp index 0f380159..1798224a 100644 --- a/src/armadillo_bits/fn_sprandn.hpp +++ b/src/armadillo_bits/fn_sprandn.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,7 +32,7 @@ sprandn const uword n_rows, const uword n_cols, const double density, - const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0 + const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -59,7 +61,7 @@ template arma_warn_unused inline obj_type -sprandn(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0) +sprandn(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_sprandu.hpp b/src/armadillo_bits/fn_sprandu.hpp index 96905e51..846e75b0 100644 --- a/src/armadillo_bits/fn_sprandu.hpp +++ b/src/armadillo_bits/fn_sprandu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,7 +32,7 @@ sprandu const uword n_rows, const uword n_cols, const double density, - const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0 + const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -59,7 +61,7 @@ template arma_warn_unused inline obj_type -sprandu(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0) +sprandu(const SizeMat& s, const double density, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/fn_spsolve.hpp b/src/armadillo_bits/fn_spsolve.hpp index e116617b..3eaf333b 100644 --- a/src/armadillo_bits/fn_spsolve.hpp +++ b/src/armadillo_bits/fn_spsolve.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -17,8 +19,7 @@ //! \addtogroup fn_spsolve //! @{ -//! Solve a system of linear equations, i.e., A*X = B, where X is unknown, -//! A is sparse, and B is dense. X will be dense too. + template inline @@ -28,9 +29,9 @@ spsolve_helper Mat& out, const SpBase& A, const Base& B, - const char* solver, - const spsolve_opts_base& settings, - const typename arma_blas_type_only::result* junk = 0 + const char* solver, + const spsolve_opts_base& settings, + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -39,7 +40,7 @@ spsolve_helper typedef typename T1::pod_type T; typedef typename T1::elem_type eT; - const char sig = (solver != NULL) ? solver[0] : char(0); + const char sig = (solver != nullptr) ? solver[0] : char(0); arma_debug_check( ((sig != 'l') && (sig != 's')), "spsolve(): unknown solver" ); @@ -54,7 +55,7 @@ spsolve_helper const superlu_opts& opts = (settings.id == 1) ? static_cast(settings) : superlu_opts_default; - arma_debug_check( ( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ), "spsolve(): pivot_thresh out of bounds" ); + arma_debug_check( ( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ), "spsolve(): pivot_thresh must be in the [0,1] interval" ); if(sig == 's') // SuperLU solver { @@ -72,7 +73,7 @@ spsolve_helper { if( (settings.id != 0) && ((opts.symmetric) || (opts.pivot_thresh != double(1))) ) { - arma_debug_warn("spsolve(): ignoring settings not applicable to LAPACK based solver"); + arma_debug_warn_level(1, "spsolve(): ignoring settings not applicable to LAPACK based solver"); } Mat AA; @@ -87,9 +88,9 @@ spsolve_helper conversion_ok = true; } - catch(std::bad_alloc&) + catch(...) { - arma_debug_warn("spsolve(): not enough memory to use LAPACK based solver"); + arma_debug_warn_level(1, "spsolve(): not enough memory to use LAPACK based solver"); } if(conversion_ok) @@ -102,22 +103,19 @@ spsolve_helper if(opts.equilibrate == true ) { flags |= solve_opts::flag_equilibrate; } if(opts.allow_ugly == true ) { flags |= solve_opts::flag_allow_ugly; } - status = glue_solve_gen::apply(out, AA, B.get_ref(), flags); + status = glue_solve_gen_full::apply(out, AA, B.get_ref(), flags); } } - if(status == false) + if( (status == false) && (rcond > T(0)) ) { - if(rcond > T(0)) { arma_debug_warn("spsolve(): system seems singular (rcond: ", rcond, ")"); } - else { arma_debug_warn("spsolve(): system seems singular"); } - - out.soft_reset(); + arma_debug_warn_level(2, "spsolve(): system is singular (rcond: ", rcond, ")"); } - if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(out)) ) + if( (status == true) && (rcond > T(0)) && (rcond < std::numeric_limits::epsilon()) ) { - arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")"); + arma_debug_warn_level(2, "solve(): solution computed, but system is singular to working precision (rcond: ", rcond, ")"); } return status; @@ -125,6 +123,10 @@ spsolve_helper +// + + + template inline bool @@ -133,9 +135,9 @@ spsolve Mat& out, const SpBase& A, const Base& B, - const char* solver = "superlu", - const spsolve_opts_base& settings = spsolve_opts_none(), - const typename arma_blas_type_only::result* junk = 0 + const char* solver = "superlu", + const spsolve_opts_base& settings = spsolve_opts_none(), + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -143,6 +145,12 @@ spsolve const bool status = spsolve_helper(out, A.get_ref(), B.get_ref(), solver, settings); + if(status == false) + { + out.soft_reset(); + arma_debug_warn_level(3, "spsolve(): solution not found"); + } + return status; } @@ -156,9 +164,9 @@ spsolve ( const SpBase& A, const Base& B, - const char* solver = "superlu", - const spsolve_opts_base& settings = spsolve_opts_none(), - const typename arma_blas_type_only::result* junk = 0 + const char* solver = "superlu", + const spsolve_opts_base& settings = spsolve_opts_none(), + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -172,6 +180,7 @@ spsolve if(status == false) { + out.soft_reset(); arma_stop_runtime_error("spsolve(): solution not found"); } diff --git a/src/armadillo_bits/fn_sqrtmat.hpp b/src/armadillo_bits/fn_sqrtmat.hpp index 9da3d5a5..882aa15c 100644 --- a/src/armadillo_bits/fn_sqrtmat.hpp +++ b/src/armadillo_bits/fn_sqrtmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,7 +58,7 @@ sqrtmat(Mat< std::complex >& Y, const Base& Y, const Base& X if(status == false) { - arma_debug_warn("sqrtmat(): given matrix seems singular; may not have a square root"); + arma_debug_warn_level(3, "sqrtmat(): given matrix is singular; may not have a square root"); } return status; @@ -112,7 +114,7 @@ sqrtmat_sympd(Mat& Y, const Base(X, norm_type, 0); } @@ -67,7 +69,7 @@ enable_if2 stddev(const T1& X, const uword norm_type, const uword dim) { arma_extra_debug_sigprint(); - + return mtOp(X, norm_type, dim); } @@ -75,7 +77,7 @@ stddev(const T1& X, const uword norm_type, const uword dim) template arma_warn_unused -arma_inline +inline typename arma_scalar_only::result stddev(const T&) { diff --git a/src/armadillo_bits/fn_strans.hpp b/src/armadillo_bits/fn_strans.hpp index a2ad82ae..de81e198 100644 --- a/src/armadillo_bits/fn_strans.hpp +++ b/src/armadillo_bits/fn_strans.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,8 +28,8 @@ const Op strans ( const T1& X, - const typename enable_if< is_arma_type::value == true >::result* junk1 = 0, - const typename arma_cx_only::result* junk2 = 0 + const typename enable_if< is_arma_type::value >::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr ) { arma_extra_debug_sigprint(); @@ -48,8 +50,8 @@ const Op strans ( const T1& X, - const typename enable_if< is_arma_type::value == true >::result* junk1 = 0, - const typename arma_not_cx::result* junk2 = 0 + const typename enable_if< is_arma_type::value >::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr ) { arma_extra_debug_sigprint(); @@ -72,8 +74,8 @@ const SpOp strans ( const T1& X, - const typename enable_if< is_arma_sparse_type::value == true >::result* junk1 = 0, - const typename arma_cx_only::result* junk2 = 0 + const typename enable_if< is_arma_sparse_type::value >::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr ) { arma_extra_debug_sigprint(); @@ -92,8 +94,8 @@ const SpOp strans ( const T1& X, - const typename enable_if< is_arma_sparse_type::value == true >::result* junk1 = 0, - const typename arma_not_cx::result* junk2 = 0 + const typename enable_if< is_arma_sparse_type::value >::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/fn_sum.hpp b/src/armadillo_bits/fn_sum.hpp index 01382820..0fa89369 100644 --- a/src/armadillo_bits/fn_sum.hpp +++ b/src/armadillo_bits/fn_sum.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_svd.hpp b/src/armadillo_bits/fn_svd.hpp index ee8dde63..ff987bb1 100644 --- a/src/armadillo_bits/fn_svd.hpp +++ b/src/armadillo_bits/fn_svd.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,20 +28,22 @@ svd ( Col& S, const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - // it doesn't matter if X is an alias of S, as auxlib::svd() makes a copy of X + typedef typename T1::elem_type eT; + + Mat A(X.get_ref()); - const bool status = auxlib::svd_dc(S, X); + const bool status = auxlib::svd_dc(S, A); if(status == false) { S.soft_reset(); - arma_debug_warn("svd(): decomposition failed"); + arma_debug_warn_level(3, "svd(): decomposition failed"); } return status; @@ -54,15 +58,20 @@ Col svd ( const Base& X, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - Col out; + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Col out; + + Mat A(X.get_ref()); - const bool status = auxlib::svd_dc(out, X); + const bool status = auxlib::svd_dc(out, A); if(status == false) { @@ -85,31 +94,34 @@ svd Mat& V, const Base& X, const char* method = "dc", - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); + typedef typename T1::elem_type eT; + arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), "svd(): two or more output objects are the same object" ); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'd')), "svd(): unknown method specified" ); - // auxlib::svd() makes an internal copy of X - const bool status = (sig == 'd') ? auxlib::svd_dc(U, S, V, X) : auxlib::svd(U, S, V, X); + Mat A(X.get_ref()); + + const bool status = (sig == 'd') ? auxlib::svd_dc(U, S, V, A) : auxlib::svd(U, S, V, A); if(status == false) { U.soft_reset(); S.soft_reset(); V.soft_reset(); - arma_debug_warn("svd(): decomposition failed"); + arma_debug_warn_level(3, "svd(): decomposition failed"); } return status; @@ -128,12 +140,14 @@ svd_econ const Base& X, const char mode, const char* method = "dc", - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); + typedef typename T1::elem_type eT; + arma_debug_check ( ( ((void*)(&U) == (void*)(&S)) || (&U == &V) || ((void*)(&S) == (void*)(&V)) ), @@ -146,18 +160,20 @@ svd_econ "svd_econ(): parameter 'mode' is incorrect" ); - const char sig = (method != NULL) ? method[0] : char(0); + const char sig = (method != nullptr) ? method[0] : char(0); arma_debug_check( ((sig != 's') && (sig != 'd')), "svd_econ(): unknown method specified" ); - const bool status = ((mode == 'b') && (sig == 'd')) ? auxlib::svd_dc_econ(U, S, V, X) : auxlib::svd_econ(U, S, V, X, mode); + Mat A(X.get_ref()); + + const bool status = ((mode == 'b') && (sig == 'd')) ? auxlib::svd_dc_econ(U, S, V, A) : auxlib::svd_econ(U, S, V, A, mode); if(status == false) { U.soft_reset(); S.soft_reset(); V.soft_reset(); - arma_debug_warn("svd(): decomposition failed"); + arma_debug_warn_level(3, "svd_econ(): decomposition failed"); } return status; @@ -176,13 +192,13 @@ svd_econ const Base& X, const char* mode = "both", const char* method = "dc", - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); arma_ignore(junk); - return svd_econ(U, S, V, X, ((mode != NULL) ? mode[0] : char(0)), method); + return svd_econ(U, S, V, X, ((mode != nullptr) ? mode[0] : char(0)), method); } diff --git a/src/armadillo_bits/fn_svds.hpp b/src/armadillo_bits/fn_svds.hpp index 6851486a..26c8c50e 100644 --- a/src/armadillo_bits/fn_svds.hpp +++ b/src/armadillo_bits/fn_svds.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,7 +32,7 @@ svds_helper const uword k, const typename T1::pod_type tol, const bool calc_UV, - const typename arma_real_only::result* junk = 0 + const typename arma_real_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -81,7 +83,10 @@ svds_helper Col eigval; Mat eigvec; - const bool status = sp_auxlib::eigs_sym(eigval, eigvec, C, kk, "la", (tol / Datum::sqrt2)); + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_sym(eigval, eigvec, C, kk, "la", opts); if(status == false) { @@ -118,15 +123,15 @@ svds_helper if(calc_UV) { - uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } - uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } - if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); } + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } return true; } @@ -145,7 +150,7 @@ svds_helper const uword k, const typename T1::pod_type tol, const bool calc_UV, - const typename arma_cx_only::result* junk = 0 + const typename arma_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -202,7 +207,10 @@ svds_helper Col eigval_tmp; Mat eigvec; - const bool status = sp_auxlib::eigs_gen(eigval_tmp, eigvec, C, kk, "lr", (tol / Datum::sqrt2)); + eigs_opts opts; + opts.tol = (tol / Datum::sqrt2); + + const bool status = eigs_gen(eigval_tmp, eigvec, C, kk, "lr", opts); if(status == false) { @@ -241,15 +249,15 @@ svds_helper if(calc_UV) { - uvec U_row_indices(A.n_rows); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } - uvec V_row_indices(A.n_cols); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } + uvec U_row_indices(A.n_rows, arma_nozeros_indicator()); for(uword i=0; i < A.n_rows; ++i) { U_row_indices[i] = i; } + uvec V_row_indices(A.n_cols, arma_nozeros_indicator()); for(uword i=0; i < A.n_cols; ++i) { V_row_indices[i] = i + A.n_rows; } U = Datum::sqrt2 * eigvec(U_row_indices, sorted_indices); V = Datum::sqrt2 * eigvec(V_row_indices, sorted_indices); } } - if(S.n_elem < k) { arma_debug_warn("svds(): found fewer singular values than specified"); } + if(S.n_elem < k) { arma_debug_warn_level(1, "svds(): found fewer singular values than specified"); } return true; } @@ -268,7 +276,7 @@ svds const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -276,7 +284,7 @@ svds const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, true); - if(status == false) { arma_debug_warn("svds(): decomposition failed"); } + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } return status; } @@ -293,7 +301,7 @@ svds const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -304,7 +312,7 @@ svds const bool status = svds_helper(U, S, V, X.get_ref(), k, tol, false); - if(status == false) { arma_debug_warn("svds(): decomposition failed"); } + if(status == false) { arma_debug_warn_level(3, "svds(): decomposition failed"); } return status; } @@ -321,7 +329,7 @@ svds const SpBase& X, const uword k, const typename T1::pod_type tol = 0.0, - const typename arma_real_or_cx_only::result* junk = 0 + const typename arma_real_or_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/fn_syl_lyap.hpp b/src/armadillo_bits/fn_sylvester.hpp similarity index 72% rename from src/armadillo_bits/fn_syl_lyap.hpp rename to src/armadillo_bits/fn_sylvester.hpp index b0405054..a5b81654 100644 --- a/src/armadillo_bits/fn_syl_lyap.hpp +++ b/src/armadillo_bits/fn_sylvester.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,7 +30,7 @@ syl const Base& in_A, const Base& in_B, const Base& in_C, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -49,7 +51,7 @@ syl if(status == false) { out.soft_reset(); - arma_debug_warn("syl(): solution not found"); + arma_debug_warn_level(3, "syl(): solution not found"); } return status; @@ -57,6 +59,24 @@ syl +template +inline +bool +sylvester + ( + Mat & out, + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_ignore(junk); + return syl(out, in_A, in_B, in_C); + } + + + template arma_warn_unused inline @@ -66,7 +86,7 @@ syl const Base& in_A, const Base& in_B, const Base& in_C, - const typename arma_blas_type_only::result* junk = 0 + const typename arma_blas_type_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -97,4 +117,21 @@ syl +template +arma_warn_unused +inline +Mat +sylvester + ( + const Base& in_A, + const Base& in_B, + const Base& in_C, + const typename arma_blas_type_only::result* junk = nullptr + ) + { + arma_ignore(junk); + return syl(in_A, in_B, in_C); + } + + //! @} diff --git a/src/armadillo_bits/fn_symmat.hpp b/src/armadillo_bits/fn_symmat.hpp index 255f0068..4bee64df 100644 --- a/src/armadillo_bits/fn_symmat.hpp +++ b/src/armadillo_bits/fn_symmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,13 +23,13 @@ template arma_warn_unused arma_inline -typename enable_if2< is_cx::no, const Op >::result +typename enable_if2< is_cx::no, const Op >::result symmatu(const Base& X, const bool do_conj = false) { arma_extra_debug_sigprint(); arma_ignore(do_conj); - return Op(X.get_ref(), 0, 0); + return Op(X.get_ref()); } @@ -35,13 +37,13 @@ symmatu(const Base& X, const bool do_conj = false) template arma_warn_unused arma_inline -typename enable_if2< is_cx::no, const Op >::result +typename enable_if2< is_cx::no, const Op >::result symmatl(const Base& X, const bool do_conj = false) { arma_extra_debug_sigprint(); arma_ignore(do_conj); - return Op(X.get_ref(), 1, 0); + return Op(X.get_ref()); } @@ -49,12 +51,12 @@ symmatl(const Base& X, const bool do_conj = false) template arma_warn_unused arma_inline -typename enable_if2< is_cx::yes, const Op >::result +typename enable_if2< is_cx::yes, const Op >::result symmatu(const Base& X, const bool do_conj = true) { arma_extra_debug_sigprint(); - return Op(X.get_ref(), 0, (do_conj ? 1 : 0)); + return Op(X.get_ref(), 0, (do_conj ? 1 : 0)); } @@ -62,12 +64,12 @@ symmatu(const Base& X, const bool do_conj = true) template arma_warn_unused arma_inline -typename enable_if2< is_cx::yes, const Op >::result +typename enable_if2< is_cx::yes, const Op >::result symmatl(const Base& X, const bool do_conj = true) { arma_extra_debug_sigprint(); - return Op(X.get_ref(), 1, (do_conj ? 1 : 0)); + return Op(X.get_ref(), 0, (do_conj ? 1 : 0)); } diff --git a/src/armadillo_bits/fn_toeplitz.hpp b/src/armadillo_bits/fn_toeplitz.hpp index 9ce2d8ba..660541b5 100644 --- a/src/armadillo_bits/fn_toeplitz.hpp +++ b/src/armadillo_bits/fn_toeplitz.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_trace.hpp b/src/armadillo_bits/fn_trace.hpp index 9ec70a3c..8a15bac2 100644 --- a/src/armadillo_bits/fn_trace.hpp +++ b/src/armadillo_bits/fn_trace.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -100,10 +102,7 @@ trace(const Glue& X) arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); - if( (A.n_elem == 0) || (B.n_elem == 0) ) - { - return eT(0); - } + if( (A.n_elem == 0) || (B.n_elem == 0) ) { return eT(0); } const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; @@ -222,10 +221,7 @@ trace(const Glue& X) arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); - if( (A.n_elem == 0) || (B.n_elem == 0) ) - { - return eT(0); - } + if( (A.n_elem == 0) || (B.n_elem == 0) ) { return eT(0); } const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; @@ -514,10 +510,7 @@ trace(const SpGlue& expr) arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); - if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) - { - return eT(0); - } + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } const uword N = (std::min)(A.n_rows, B.n_cols); @@ -574,10 +567,7 @@ trace(const SpGlue, T2, spglue_times>& expr) // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); - if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) - { - return eT(0); - } + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } const uword N = (std::min)(A.n_cols, B.n_cols); @@ -633,10 +623,7 @@ trace(const SpGlue, T2, spglue_times>& expr) // NOTE: deliberately swapped A.n_rows and A.n_cols to take into account the requested transpose operation arma_debug_assert_mul_size(A.n_cols, A.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); - if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) - { - return eT(0); - } + if( (A.n_nonzero == 0) || (B.n_nonzero == 0) ) { return eT(0); } const uword N = (std::min)(A.n_cols, B.n_cols); diff --git a/src/armadillo_bits/fn_trans.hpp b/src/armadillo_bits/fn_trans.hpp index 798c8d46..f558e106 100644 --- a/src/armadillo_bits/fn_trans.hpp +++ b/src/armadillo_bits/fn_trans.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,7 +27,7 @@ const Op trans ( const T1& X, - const typename enable_if< is_arma_type::value == true >::result* junk = 0 + const typename enable_if< is_arma_type::value >::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -43,7 +45,7 @@ const Op htrans ( const T1& X, - const typename enable_if< is_arma_type::value == true >::result* junk = 0 + const typename enable_if< is_arma_type::value >::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -65,7 +67,7 @@ const SpOp trans ( const T1& X, - const typename enable_if< is_arma_sparse_type::value == true >::result* junk = 0 + const typename enable_if< is_arma_sparse_type::value >::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -83,7 +85,7 @@ const SpOp htrans ( const T1& X, - const typename enable_if< is_arma_sparse_type::value == true >::result* junk = 0 + const typename enable_if< is_arma_sparse_type::value >::result* junk = nullptr ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/fn_trapz.hpp b/src/armadillo_bits/fn_trapz.hpp index 243ec55a..72646b78 100644 --- a/src/armadillo_bits/fn_trapz.hpp +++ b/src/armadillo_bits/fn_trapz.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_trig.hpp b/src/armadillo_bits/fn_trig.hpp index 66de22d8..d73947b6 100644 --- a/src/armadillo_bits/fn_trig.hpp +++ b/src/armadillo_bits/fn_trig.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_trimat.hpp b/src/armadillo_bits/fn_trimat.hpp index e5827d2e..24c95f20 100644 --- a/src/armadillo_bits/fn_trimat.hpp +++ b/src/armadillo_bits/fn_trimat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -106,37 +108,35 @@ trimatu(const Base& X, const sword k) -// // TODO: implement for sparse matrices -// template -// arma_warn_unused -// arma_inline -// const SpOp -// trimatu(const SpBase& X, const sword k) -// { -// arma_extra_debug_sigprint(); -// -// const uword row_offset = (k < 0) ? uword(-k) : uword(0); -// const uword col_offset = (k > 0) ? uword( k) : uword(0); -// -// return SpOp(X.get_ref(), row_offset, col_offset); -// } -// -// -// -// // TODO: implement for sparse matrices -// template -// arma_warn_unused -// arma_inline -// const SpOp -// trimatl(const SpBase& X, const sword k) -// { -// arma_extra_debug_sigprint(); -// -// const uword row_offset = (k < 0) ? uword(-k) : uword(0); -// const uword col_offset = (k > 0) ? uword( k) : uword(0); -// -// return SpOp(X.get_ref(), row_offset, col_offset); -// } +template +arma_warn_unused +arma_inline +const SpOp +trimatu(const SpBase& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return SpOp(X.get_ref(), row_offset, col_offset); + } + + + +template +arma_warn_unused +arma_inline +const SpOp +trimatl(const SpBase& X, const sword k) + { + arma_extra_debug_sigprint(); + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + return SpOp(X.get_ref(), row_offset, col_offset); + } diff --git a/src/armadillo_bits/fn_trimat_ind.hpp b/src/armadillo_bits/fn_trimat_ind.hpp new file mode 100644 index 00000000..4e657052 --- /dev/null +++ b/src/armadillo_bits/fn_trimat_ind.hpp @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_trimat_ind +//! @{ + + +arma_warn_unused +inline +uvec +trimatu_ind(const SizeMat& s, const sword k = 0) + { + arma_extra_debug_sigprint(); + + const uword n_rows = s.n_rows; + const uword n_cols = s.n_cols; + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu_ind(): requested diagonal is out of bounds" ); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + uvec tmp(n_rows * n_cols, arma_nozeros_indicator()); // worst case scenario + uword* tmp_mem = tmp.memptr(); + uword count = 0; + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + const uword end_row = i + row_offset; + + const uword index_offset = (n_rows * col); + + for(uword row=0; row <= end_row; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + else + { + if(col < n_cols) + { + const uword index_offset = (n_rows * col); + + for(uword row=0; row < n_rows; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + } + } + + uvec out; + + out.steal_mem_col(tmp, count); + + return out; + } + + + +arma_warn_unused +inline +uvec +trimatl_ind(const SizeMat& s, const sword k = 0) + { + arma_extra_debug_sigprint(); + + const uword n_rows = s.n_rows; + const uword n_cols = s.n_cols; + + const uword row_offset = (k < 0) ? uword(-k) : uword(0); + const uword col_offset = (k > 0) ? uword( k) : uword(0); + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl_ind(): requested diagonal is out of bounds" ); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + uvec tmp(n_rows * n_cols, arma_nozeros_indicator()); // worst case scenario + uword* tmp_mem = tmp.memptr(); + uword count = 0; + + for(uword col=0; col < col_offset; ++col) + { + const uword index_offset = (n_rows * col); + + for(uword row=0; row < n_rows; ++row) + { + tmp_mem[count] = index_offset + row; + ++count; + } + } + + for(uword i=0; i(X, norm_type, 0); } @@ -67,7 +69,7 @@ enable_if2 var(const T1& X, const uword norm_type, const uword dim) { arma_extra_debug_sigprint(); - + return mtOp(X, norm_type, dim); } @@ -75,7 +77,7 @@ var(const T1& X, const uword norm_type, const uword dim) template arma_warn_unused -arma_inline +inline typename arma_scalar_only::result var(const T&) { @@ -96,7 +98,7 @@ enable_if2 var(const T1& X, const uword norm_type = 0) { arma_extra_debug_sigprint(); - + return spop_var::var_vec(X, norm_type); } @@ -114,7 +116,7 @@ enable_if2 var(const T1& X, const uword norm_type = 0) { arma_extra_debug_sigprint(); - + return mtSpOp(X, norm_type, 0); } @@ -132,7 +134,7 @@ enable_if2 var(const T1& X, const uword norm_type, const uword dim) { arma_extra_debug_sigprint(); - + return mtSpOp(X, norm_type, dim); } diff --git a/src/armadillo_bits/fn_vecnorm.hpp b/src/armadillo_bits/fn_vecnorm.hpp new file mode 100644 index 00000000..0fa88aa9 --- /dev/null +++ b/src/armadillo_bits/fn_vecnorm.hpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup fn_vecnorm +//! @{ + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + if(k == uword(1)) { return op_norm::vec_norm_1(P); } + if(k == uword(2)) { return op_norm::vec_norm_2(P); } + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + + return op_norm::vec_norm_k(P, int(k)); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword dim = 0; + + return mtOp(X, k, dim); + } + + + +template +arma_warn_unused +inline +const mtOp +vecnorm + ( + const Base& X, + const uword k, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtOp(X.get_ref(), k, dim); + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + typedef typename T1::pod_type T; + + const Proxy P(X); + + if(P.get_n_elem() == 0) { return T(0); } + + const char sig = (method != nullptr) ? method[0] : char(0); + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { return op_norm::vec_norm_max(P); } + if( (sig == '-') ) { return op_norm::vec_norm_min(P); } + + arma_stop_logic_error("vecnorm(): unsupported vector norm type"); + + return T(0); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_type::value && resolves_to_vector::no, + const mtOp + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + const uword dim = 0; + + return mtOp(X, method_id, dim); + } + + + +template +arma_warn_unused +inline +const mtOp +vecnorm + ( + const Base& X, + const char* method, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + return mtOp(X.get_ref(), method_id, dim); + } + + + +// +// norms for sparse matrices + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return arma::norm(X, k); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const mtSpOp + >::result +vecnorm + ( + const T1& X, + const uword k = uword(2), + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const uword dim = 0; + + return mtSpOp(X, k, dim); + } + + + +template +arma_warn_unused +inline +const mtSpOp +vecnorm + ( + const SpBase& X, + const uword k, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + return mtSpOp(X.get_ref(), k, dim); + } + + + +// + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::yes, + typename T1::pod_type + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + return arma::norm(X, method); + } + + + +template +arma_warn_unused +inline +typename +enable_if2 + < + is_arma_sparse_type::value && resolves_to_sparse_vector::no, + const mtSpOp + >::result +vecnorm + ( + const T1& X, + const char* method, + const arma_empty_class junk1 = arma_empty_class(), + const typename arma_real_or_cx_only::result* junk2 = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk1); + arma_ignore(junk2); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + const uword dim = 0; + + return mtSpOp(X, method_id, dim); + } + + + +template +arma_warn_unused +inline +const mtSpOp +vecnorm + ( + const SpBase& X, + const char* method, + const uword dim, + const typename arma_real_or_cx_only::result* junk = nullptr + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const char sig = (method != nullptr) ? method[0] : char(0); + + uword method_id = 0; + + if( (sig == 'i') || (sig == 'I') || (sig == '+') ) { method_id = 1; } + if( (sig == '-') ) { method_id = 2; } + + return mtSpOp(X.get_ref(), method_id, dim); + } + + + +//! @} diff --git a/src/armadillo_bits/fn_vectorise.hpp b/src/armadillo_bits/fn_vectorise.hpp index e4535eed..ff210069 100644 --- a/src/armadillo_bits/fn_vectorise.hpp +++ b/src/armadillo_bits/fn_vectorise.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/fn_wishrnd.hpp b/src/armadillo_bits/fn_wishrnd.hpp index 4405edb8..3f05b77b 100644 --- a/src/armadillo_bits/fn_wishrnd.hpp +++ b/src/armadillo_bits/fn_wishrnd.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -72,17 +74,16 @@ wishrnd(Mat& W, const Base& if(status == false) { - arma_debug_warn("wishrnd(): given matrix is not symmetric positive definite"); - return false; + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): given matrix is not symmetric positive definite"); } - return true; + return status; } template -arma_warn_unused inline typename enable_if2 @@ -95,7 +96,15 @@ wishrnd(Mat& W, const Base& arma_extra_debug_sigprint(); arma_ignore(S); - return op_wishrnd::apply_direct(W, D.get_ref(), df, uword(2)); + const bool status = op_wishrnd::apply_direct(W, D.get_ref(), df, uword(2)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): problem with given 'D' matrix"); + } + + return status; } @@ -157,17 +166,16 @@ iwishrnd(Mat& W, const Base& if(status == false) { - arma_debug_warn("iwishrnd(): given matrix is not symmetric positive definite and/or df is too low"); - return false; + W.soft_reset(); + arma_debug_warn_level(3, "iwishrnd(): given matrix is not symmetric positive definite and/or df is too low"); } - return true; + return status; } template -arma_warn_unused inline typename enable_if2 @@ -180,7 +188,15 @@ iwishrnd(Mat& W, const Base& arma_extra_debug_sigprint(); arma_ignore(T); - return op_iwishrnd::apply_direct(W, Dinv.get_ref(), df, uword(2)); + const bool status = op_iwishrnd::apply_direct(W, Dinv.get_ref(), df, uword(2)); + + if(status == false) + { + W.soft_reset(); + arma_debug_warn_level(3, "wishrnd(): problem with given 'Dinv' matrix and/or df is too low"); + } + + return status; } diff --git a/src/armadillo_bits/fn_zeros.hpp b/src/armadillo_bits/fn_zeros.hpp index 204561e9..5f069223 100644 --- a/src/armadillo_bits/fn_zeros.hpp +++ b/src/armadillo_bits/fn_zeros.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,20 +37,16 @@ template arma_warn_unused arma_inline const Gen -zeros(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = 0) +zeros(const uword n_elem, const arma_empty_class junk1 = arma_empty_class(), const typename arma_Mat_Col_Row_only::result* junk2 = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk1); arma_ignore(junk2); - if(is_Row::value) - { - return Gen(1, n_elem); - } - else - { - return Gen(n_elem, 1); - } + const uword n_rows = (is_Row::value) ? uword(1) : n_elem; + const uword n_cols = (is_Row::value) ? n_elem : uword(1); + + return Gen(n_rows, n_cols); } @@ -81,20 +79,13 @@ template arma_warn_unused arma_inline const Gen -zeros(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = 0) +zeros(const uword n_rows, const uword n_cols, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_Col::value) - { - arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); - } - else - if(is_Row::value) - { - arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); - } + if(is_Col::value) { arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); } + if(is_Row::value) { arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); } return Gen(n_rows, n_cols); } @@ -105,7 +96,7 @@ template arma_warn_unused arma_inline const Gen -zeros(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = 0) +zeros(const SizeMat& s, const typename arma_Mat_Col_Row_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -143,7 +134,7 @@ template arma_warn_unused arma_inline const GenCube -zeros(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = 0) +zeros(const uword n_rows, const uword n_cols, const uword n_slices, const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -157,7 +148,7 @@ template arma_warn_unused arma_inline const GenCube -zeros(const SizeCube& s, const typename arma_Cube_only::result* junk = 0) +zeros(const SizeCube& s, const typename arma_Cube_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); @@ -171,20 +162,13 @@ template arma_warn_unused inline sp_obj_type -zeros(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0) +zeros(const uword n_rows, const uword n_cols, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); - if(is_SpCol::value == true) - { - arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); - } - else - if(is_SpRow::value == true) - { - arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); - } + if(is_SpCol::value) { arma_debug_check( (n_cols != 1), "zeros(): incompatible size" ); } + if(is_SpRow::value) { arma_debug_check( (n_rows != 1), "zeros(): incompatible size" ); } return sp_obj_type(n_rows, n_cols); } @@ -195,7 +179,7 @@ template arma_warn_unused inline sp_obj_type -zeros(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = 0) +zeros(const SizeMat& s, const typename arma_SpMat_SpCol_SpRow_only::result* junk = nullptr) { arma_extra_debug_sigprint(); arma_ignore(junk); diff --git a/src/armadillo_bits/glue_affmul_bones.hpp b/src/armadillo_bits/glue_affmul_bones.hpp index f34ebc03..5284b6cc 100644 --- a/src/armadillo_bits/glue_affmul_bones.hpp +++ b/src/armadillo_bits/glue_affmul_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class glue_affmul template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_affmul_meat.hpp b/src/armadillo_bits/glue_affmul_meat.hpp index e38a01f5..19c3799a 100644 --- a/src/armadillo_bits/glue_affmul_meat.hpp +++ b/src/armadillo_bits/glue_affmul_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -240,7 +242,7 @@ glue_affmul::apply_noalias_square(Mat& out, const T1& A, { if(B_n_cols == 1) { - Col tmp(N); + Col tmp(N, arma_nozeros_indicator()); eT* tmp_mem = tmp.memptr(); arrayops::copy(tmp_mem, B.memptr(), N-1); @@ -251,7 +253,7 @@ glue_affmul::apply_noalias_square(Mat& out, const T1& A, } else { - Mat tmp(N, B_n_cols); + Mat tmp(N, B_n_cols, arma_nozeros_indicator()); for(uword col=0; col < B_n_cols; ++col) { @@ -423,7 +425,7 @@ glue_affmul::apply_noalias_rectangle(Mat& out, const T1& if(B_n_cols == 1) { - Col tmp(A_n_cols); + Col tmp(A_n_cols, arma_nozeros_indicator()); eT* tmp_mem = tmp.memptr(); arrayops::copy(tmp_mem, B.memptr(), A_n_cols-1); @@ -434,7 +436,7 @@ glue_affmul::apply_noalias_rectangle(Mat& out, const T1& } else { - Mat tmp(A_n_cols, B_n_cols); + Mat tmp(A_n_cols, B_n_cols, arma_nozeros_indicator()); for(uword col=0; col < B_n_cols; ++col) { @@ -468,7 +470,7 @@ glue_affmul::apply_noalias_generic(Mat& out, const T1& A const uword B_n_rows = B.n_rows; const uword B_n_cols = B.n_cols; - Mat tmp(B_n_rows+1, B_n_cols); + Mat tmp(B_n_rows+1, B_n_cols, arma_nozeros_indicator()); for(uword col=0; col < B_n_cols; ++col) { diff --git a/src/armadillo_bits/glue_atan2_bones.hpp b/src/armadillo_bits/glue_atan2_bones.hpp index 730596d1..f60e7830 100644 --- a/src/armadillo_bits/glue_atan2_bones.hpp +++ b/src/armadillo_bits/glue_atan2_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_atan2_meat.hpp b/src/armadillo_bits/glue_atan2_meat.hpp index 633ef851..38469ed8 100644 --- a/src/armadillo_bits/glue_atan2_meat.hpp +++ b/src/armadillo_bits/glue_atan2_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -68,7 +70,7 @@ glue_atan2::apply_noalias(Mat& out, const Proxy& P1, eT* out_mem = out.memptr(); - const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate::use_mp || Proxy::use_mp)>::eval(n_elem); + const bool use_mp = arma_config::openmp && mp_gate::use_mp || Proxy::use_mp)>::eval(n_elem); const bool use_at = Proxy::use_at || Proxy::use_at; if(use_at == false) @@ -170,7 +172,7 @@ glue_atan2::apply_noalias(Cube& out, const ProxyCube eT* out_mem = out.memptr(); - const bool use_mp = arma_config::cxx11 && arma_config::openmp && mp_gate::use_mp || ProxyCube::use_mp)>::eval(n_elem); + const bool use_mp = arma_config::openmp && mp_gate::use_mp || ProxyCube::use_mp)>::eval(n_elem); const bool use_at = ProxyCube::use_at || ProxyCube::use_at; if(use_at == false) diff --git a/src/armadillo_bits/glue_conv_bones.hpp b/src/armadillo_bits/glue_conv_bones.hpp index 66b574cb..5382f84c 100644 --- a/src/armadillo_bits/glue_conv_bones.hpp +++ b/src/armadillo_bits/glue_conv_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_conv template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; }; template inline static void apply(Mat& out, const Mat& A, const Mat& B, const bool A_is_col); diff --git a/src/armadillo_bits/glue_conv_meat.hpp b/src/armadillo_bits/glue_conv_meat.hpp index ba5022b9..e722ff48 100644 --- a/src/armadillo_bits/glue_conv_meat.hpp +++ b/src/armadillo_bits/glue_conv_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -38,7 +40,7 @@ glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_ if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; } - Col hh(h_n_elem); // flipped version of h + Col hh(h_n_elem, arma_nozeros_indicator()); // flipped version of h const eT* h_mem = h.memptr(); eT* hh_mem = hh.memptr(); @@ -49,7 +51,7 @@ glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_ } - Col xx( (x_n_elem + 2*h_n_elem_m1), fill::zeros ); // zero padded version of x + Col xx( (x_n_elem + 2*h_n_elem_m1), arma_zeros_indicator() ); // zero padded version of x const eT* x_mem = x.memptr(); eT* xx_mem = xx.memptr(); @@ -61,11 +63,28 @@ glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_ eT* out_mem = out.memptr(); - for(uword i=0; i < out_n_elem; ++i) + if( (arma_config::openmp) && (x_n_elem >= 128) && (h_n_elem >= 64) && (mp_thread_limit::in_parallel() == false) ) { - // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) ); - - out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < out_n_elem; ++i) + { + out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); + } + } + #endif + } + else + { + for(uword i=0; i < out_n_elem; ++i) + { + // out_mem[i] = dot( hh, xx.subvec(i, (i + h_n_elem_m1)) ); + + out_mem[i] = op_dot::direct_dot( h_n_elem, hh_mem, &(xx_mem[i]) ); + } } } @@ -90,7 +109,7 @@ glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_ // if( (h_n_elem == 0) || (x_n_elem == 0) ) { out.zeros(); return; } // // -// Col hh(h_n_elem); // flipped version of h +// Col hh(h_n_elem, arma_nozeros_indicator()); // flipped version of h // // const eT* h_mem = h.memptr(); // eT* hh_mem = hh.memptr(); @@ -106,7 +125,7 @@ glue_conv::apply(Mat& out, const Mat& A, const Mat& B, const bool A_ // // const uword HH_n_rows = h_n_elem + (N_copies-1); // -// Mat HH(HH_n_rows, N_copies, fill::zeros); +// Mat HH(HH_n_rows, N_copies, arma_zeros_indicator()); // // for(uword i=0; i& out, const Mat& A, const Mat& B, const bool A_ // // // -// Col xx( (x_n_elem + 2*h_n_elem_m1), fill::zeros ); // zero padded version of x +// Col xx( (x_n_elem + 2*h_n_elem_m1), arma_zeros_indicator() ); // zero padded version of x // // const eT* x_mem = x.memptr(); // eT* xx_mem = xx.memptr(); @@ -181,7 +200,7 @@ glue_conv::apply(Mat& out, const Glue& arma_debug_check ( ( ((A.is_vec() == false) && (A.is_empty() == false)) || ((B.is_vec() == false) && (B.is_empty() == false)) ), - "conv(): given object is not a vector" + "conv(): given object must be a vector" ); const bool A_is_col = ((T1::is_col) || (A.n_cols == 1)); @@ -235,7 +254,7 @@ glue_conv2::apply(Mat& out, const Mat& A, const Mat& B) if(G.is_empty() || W.is_empty()) { out.zeros(); return; } - Mat H(G.n_rows, G.n_cols); // flipped filter coefficients + Mat H(G.n_rows, G.n_cols, arma_nozeros_indicator()); // flipped filter coefficients const uword H_n_rows = H.n_rows; const uword H_n_cols = H.n_cols; @@ -255,31 +274,64 @@ glue_conv2::apply(Mat& out, const Mat& A, const Mat& B) } - Mat X( (W.n_rows + 2*H_n_rows_m1), (W.n_cols + 2*H_n_cols_m1), fill::zeros ); + Mat X( (W.n_rows + 2*H_n_rows_m1), (W.n_cols + 2*H_n_cols_m1), arma_zeros_indicator() ); X( H_n_rows_m1, H_n_cols_m1, arma::size(W) ) = W; // zero padded version of 2D image out.set_size( out_n_rows, out_n_cols ); - for(uword col=0; col < out_n_cols; ++col) + if( (arma_config::openmp) && (out_n_cols >= 2) && (mp_thread_limit::in_parallel() == false) ) { - eT* out_colptr = out.colptr(col); - - for(uword row=0; row < out_n_rows; ++row) + #if defined(ARMA_USE_OPENMP) { - // out.at(row, col) = accu( H % X(row, col, size(H)) ); - - eT acc = eT(0); + const int n_threads = mp_thread_limit::get(); - for(uword H_col = 0; H_col < H_n_cols; ++H_col) + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < out_n_cols; ++col) { - const eT* X_colptr = X.colptr(col + H_col); + eT* out_colptr = out.colptr(col); - acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) ); + for(uword row=0; row < out_n_rows; ++row) + { + // out.at(row, col) = accu( H % X(row, col, size(H)) ); + + eT acc = eT(0); + + for(uword H_col = 0; H_col < H_n_cols; ++H_col) + { + const eT* X_colptr = X.colptr(col + H_col); + + acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) ); + } + + out_colptr[row] = acc; + } } + } + #endif + } + else + { + for(uword col=0; col < out_n_cols; ++col) + { + eT* out_colptr = out.colptr(col); - out_colptr[row] = acc; + for(uword row=0; row < out_n_rows; ++row) + { + // out.at(row, col) = accu( H % X(row, col, size(H)) ); + + eT acc = eT(0); + + for(uword H_col = 0; H_col < H_n_cols; ++H_col) + { + const eT* X_colptr = X.colptr(col + H_col); + + acc += op_dot::direct_dot( H_n_rows, H.colptr(H_col), &(X_colptr[row]) ); + } + + out_colptr[row] = acc; + } } } } diff --git a/src/armadillo_bits/glue_cor_bones.hpp b/src/armadillo_bits/glue_cor_bones.hpp index 4bda90ab..eabb8977 100644 --- a/src/armadillo_bits/glue_cor_bones.hpp +++ b/src/armadillo_bits/glue_cor_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_cor template struct traits { - static const bool is_row = false; // T1::is_col; // TODO: check - static const bool is_col = false; // T2::is_col; // TODO: check - static const bool is_xvec = false; + static constexpr bool is_row = false; // T1::is_col; // TODO: check + static constexpr bool is_col = false; // T2::is_col; // TODO: check + static constexpr bool is_xvec = false; }; template inline static void apply(Mat& out, const Glue& X); diff --git a/src/armadillo_bits/glue_cor_meat.hpp b/src/armadillo_bits/glue_cor_meat.hpp index 083eff2a..8f937975 100644 --- a/src/armadillo_bits/glue_cor_meat.hpp +++ b/src/armadillo_bits/glue_cor_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_cov_bones.hpp b/src/armadillo_bits/glue_cov_bones.hpp index d0991c12..385dd7a0 100644 --- a/src/armadillo_bits/glue_cov_bones.hpp +++ b/src/armadillo_bits/glue_cov_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_cov template struct traits { - static const bool is_row = false; // T1::is_col; // TODO: check - static const bool is_col = false; // T2::is_col; // TODO: check - static const bool is_xvec = false; + static constexpr bool is_row = false; // T1::is_col; // TODO: check + static constexpr bool is_col = false; // T2::is_col; // TODO: check + static constexpr bool is_xvec = false; }; template inline static void apply(Mat& out, const Glue& X); diff --git a/src/armadillo_bits/glue_cov_meat.hpp b/src/armadillo_bits/glue_cov_meat.hpp index a3fa3082..d5768e2d 100644 --- a/src/armadillo_bits/glue_cov_meat.hpp +++ b/src/armadillo_bits/glue_cov_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_cross_bones.hpp b/src/armadillo_bits/glue_cross_bones.hpp index 80ceabe9..469e2e7c 100644 --- a/src/armadillo_bits/glue_cross_bones.hpp +++ b/src/armadillo_bits/glue_cross_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_cross template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = true; }; template inline static void apply(Mat& out, const Glue& X); diff --git a/src/armadillo_bits/glue_cross_meat.hpp b/src/armadillo_bits/glue_cross_meat.hpp index 3f8ebc90..bdf38d1b 100644 --- a/src/armadillo_bits/glue_cross_meat.hpp +++ b/src/armadillo_bits/glue_cross_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -32,12 +34,9 @@ glue_cross::apply(Mat& out, const Glue PA(X.A); const Proxy PB(X.B); - arma_debug_check( ((PA.get_n_elem() != 3) || (PB.get_n_elem() != 3)), "cross(): input vectors must have 3 elements" ); - - const uword PA_n_rows = Proxy::is_row ? 1 : PA.get_n_rows(); - const uword PA_n_cols = Proxy::is_col ? 1 : PA.get_n_cols(); + arma_debug_check( ((PA.get_n_elem() != 3) || (PB.get_n_elem() != 3)), "cross(): each vector must have 3 elements" ); - out.set_size(PA_n_rows, PA_n_cols); + out.set_size(PA.get_n_rows(), PA.get_n_cols()); eT* out_mem = out.memptr(); @@ -60,7 +59,7 @@ glue_cross::apply(Mat& out, const Glue::is_col ? true : (PA_n_cols == 1); + const bool PA_is_col = Proxy::is_col ? true : (PA.get_n_cols() == 1); const bool PB_is_col = Proxy::is_col ? true : (PB.get_n_cols() == 1); const eT ax = PA.at(0,0); diff --git a/src/armadillo_bits/glue_hist_bones.hpp b/src/armadillo_bits/glue_hist_bones.hpp index 4f3a9890..2d053588 100644 --- a/src/armadillo_bits/glue_hist_bones.hpp +++ b/src/armadillo_bits/glue_hist_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -39,9 +41,9 @@ class glue_hist_default template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; }; template diff --git a/src/armadillo_bits/glue_hist_meat.hpp b/src/armadillo_bits/glue_hist_meat.hpp index decc6e52..ec5b4a32 100644 --- a/src/armadillo_bits/glue_hist_meat.hpp +++ b/src/armadillo_bits/glue_hist_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_histc_bones.hpp b/src/armadillo_bits/glue_histc_bones.hpp index f933d4f3..c1cc6875 100644 --- a/src/armadillo_bits/glue_histc_bones.hpp +++ b/src/armadillo_bits/glue_histc_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -39,9 +41,9 @@ class glue_histc_default template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T1::is_col; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; }; template diff --git a/src/armadillo_bits/glue_histc_meat.hpp b/src/armadillo_bits/glue_histc_meat.hpp index 6f348b1a..6e79175b 100644 --- a/src/armadillo_bits/glue_histc_meat.hpp +++ b/src/armadillo_bits/glue_histc_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,7 +27,7 @@ glue_histc::apply_noalias(Mat& C, const Mat& A, const Mat& B, con { arma_extra_debug_sigprint(); - arma_debug_check( ((B.is_vec() == false) && (B.is_empty() == false)), "histc(): parameter 'edges' is not a vector" ); + arma_debug_check( ((B.is_vec() == false) && (B.is_empty() == false)), "histc(): parameter 'edges' must be a vector" ); const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; diff --git a/src/armadillo_bits/glue_hypot_bones.hpp b/src/armadillo_bits/glue_hypot_bones.hpp index 4aba03d9..53985cc3 100644 --- a/src/armadillo_bits/glue_hypot_bones.hpp +++ b/src/armadillo_bits/glue_hypot_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_hypot_meat.hpp b/src/armadillo_bits/glue_hypot_meat.hpp index 4afe67c4..f773b333 100644 --- a/src/armadillo_bits/glue_hypot_meat.hpp +++ b/src/armadillo_bits/glue_hypot_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_intersect_bones.hpp b/src/armadillo_bits/glue_intersect_bones.hpp index 079cbe16..6a3bcad9 100644 --- a/src/armadillo_bits/glue_intersect_bones.hpp +++ b/src/armadillo_bits/glue_intersect_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class glue_intersect template struct traits { - static const bool is_row = (T1::is_row && T2::is_row); - static const bool is_col = (T1::is_col || T2::is_col); - static const bool is_xvec = false; + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col || T2::is_col); + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_intersect_meat.hpp b/src/armadillo_bits/glue_intersect_meat.hpp index 0af8164b..21a2b6d6 100644 --- a/src/armadillo_bits/glue_intersect_meat.hpp +++ b/src/armadillo_bits/glue_intersect_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -76,7 +78,7 @@ glue_intersect::apply(Mat& out, uvec& iA, uvec& iB, cons const uword C_n_elem = A_uniq.n_elem + B_uniq.n_elem; - Col C(C_n_elem); + Col C(C_n_elem, arma_nozeros_indicator()); arrayops::copy(C.memptr(), A_uniq.memptr(), A_uniq.n_elem); arrayops::copy(C.memptr() + A_uniq.n_elem, B_uniq.memptr(), B_uniq.n_elem); @@ -96,7 +98,7 @@ glue_intersect::apply(Mat& out, uvec& iA, uvec& iB, cons const eT* C_sorted_mem = C_sorted.memptr(); - uvec jj(C_n_elem); // worst case length + uvec jj(C_n_elem, arma_nozeros_indicator()); // worst case length uword* jj_mem = jj.memptr(); uword jj_count = 0; diff --git a/src/armadillo_bits/glue_join_bones.hpp b/src/armadillo_bits/glue_join_bones.hpp index 84de5bd5..b84116a0 100644 --- a/src/armadillo_bits/glue_join_bones.hpp +++ b/src/armadillo_bits/glue_join_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_join_cols template struct traits { - static const bool is_row = false; - static const bool is_col = (T1::is_col && T2::is_col); - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; }; template @@ -54,9 +56,9 @@ class glue_join_rows template struct traits { - static const bool is_row = (T1::is_row && T2::is_row); - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_join_meat.hpp b/src/armadillo_bits/glue_join_meat.hpp index 637df5d8..1ffd3b1e 100644 --- a/src/armadillo_bits/glue_join_meat.hpp +++ b/src/armadillo_bits/glue_join_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -362,7 +364,7 @@ glue_join_slices::apply(Cube& out, const GlueCube C(A.n_rows, A.n_cols, A.n_slices + B.n_slices); + Cube C(A.n_rows, A.n_cols, A.n_slices + B.n_slices, arma_nozeros_indicator()); C.slices(0, A.n_slices-1) = A; C.slices(A.n_slices, C.n_slices-1) = B; diff --git a/src/armadillo_bits/glue_kron_bones.hpp b/src/armadillo_bits/glue_kron_bones.hpp index 73595003..84c93476 100644 --- a/src/armadillo_bits/glue_kron_bones.hpp +++ b/src/armadillo_bits/glue_kron_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class glue_kron template struct traits { - static const bool is_row = (T1::is_row && T2::is_row); - static const bool is_col = (T1::is_col && T2::is_col); - static const bool is_xvec = false; + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; }; template inline static void direct_kron(Mat& out, const Mat& A, const Mat& B); diff --git a/src/armadillo_bits/glue_kron_meat.hpp b/src/armadillo_bits/glue_kron_meat.hpp index 3616cfba..c7c4ff63 100644 --- a/src/armadillo_bits/glue_kron_meat.hpp +++ b/src/armadillo_bits/glue_kron_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -123,24 +125,21 @@ glue_kron::apply(Mat& out, const Glue& typedef typename T1::elem_type eT; - const unwrap A_tmp(X.A); - const unwrap B_tmp(X.B); - - const Mat& A = A_tmp.M; - const Mat& B = B_tmp.M; + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); - if( (&out != &A) && (&out != &B) ) - { - glue_kron::direct_kron(out, A, B); - } - else + if(UA.is_alias(out) || UB.is_alias(out)) { Mat tmp; - glue_kron::direct_kron(tmp, A, B); + glue_kron::direct_kron(tmp, UA.M, UB.M); out.steal_mem(tmp); } + else + { + glue_kron::direct_kron(out, UA.M, UB.M); + } } diff --git a/src/armadillo_bits/glue_max_bones.hpp b/src/armadillo_bits/glue_max_bones.hpp index ea4f679d..149988e9 100644 --- a/src/armadillo_bits/glue_max_bones.hpp +++ b/src/armadillo_bits/glue_max_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,18 +31,14 @@ class glue_max template inline static void apply(Mat& out, const Glue& X); - template inline static void apply(Mat< eT >& out, const Proxy& PA, const Proxy& PB); - - template inline static void apply(Mat< std::complex >& out, const Proxy& PA, const Proxy& PB); + template inline static void apply(Mat& out, const Proxy& PA, const Proxy& PB); // cubes template inline static void apply(Cube& out, const GlueCube& X); - template inline static void apply(Cube< eT >& out, const ProxyCube& PA, const ProxyCube& PB); - - template inline static void apply(Cube< std::complex >& out, const ProxyCube& PA, const ProxyCube& PB); + template inline static void apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB); }; diff --git a/src/armadillo_bits/glue_max_meat.hpp b/src/armadillo_bits/glue_max_meat.hpp index 44e12d0d..b1e52c28 100644 --- a/src/armadillo_bits/glue_max_meat.hpp +++ b/src/armadillo_bits/glue_max_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,51 +60,9 @@ glue_max::apply(Mat& out, const Proxy& PA, const Proxy& PB) const uword n_rows = PA.get_n_rows(); const uword n_cols = PA.get_n_cols(); - arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise maximum"); + arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise max()"); - out.set_size(n_rows, n_cols); - - eT* out_mem = out.memptr(); - - if( (Proxy::use_at == false) && (Proxy::use_at == false) ) - { - typename Proxy::ea_type A = PA.get_ea(); - typename Proxy::ea_type B = PB.get_ea(); - - const uword N = PA.get_n_elem(); - - for(uword i=0; i -inline -void -glue_max::apply(Mat< std::complex >& out, const Proxy& PA, const Proxy& PB) - { - arma_extra_debug_sigprint(); - - typedef typename std::complex eT; - - const uword n_rows = PA.get_n_rows(); - const uword n_cols = PA.get_n_cols(); - - arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise maximum"); + const arma_gt_comparator comparator; out.set_size(n_rows, n_cols); @@ -117,10 +77,10 @@ glue_max::apply(Mat< std::complex >& out, const Proxy& PA, const Proxy std::abs(B_val) ) ? A_val : B_val; + out_mem[i] = comparator(Ai,Bi) ? Ai : Bi; } } else @@ -128,10 +88,10 @@ glue_max::apply(Mat< std::complex >& out, const Proxy& PA, const Proxy std::abs(B_val) ) ? A_val : B_val; + *out_mem = comparator(Ai,Bi) ? Ai : Bi; ++out_mem; } @@ -179,53 +139,9 @@ glue_max::apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB) const uword n_cols = PA.get_n_cols(); const uword n_slices = PA.get_n_slices(); - arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise maximum"); - - out.set_size(n_rows, n_cols, n_slices); - - eT* out_mem = out.memptr(); - - if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) - { - typename ProxyCube::ea_type A = PA.get_ea(); - typename ProxyCube::ea_type B = PB.get_ea(); - - const uword N = PA.get_n_elem(); - - for(uword i=0; i -inline -void -glue_max::apply(Cube< std::complex >& out, const ProxyCube& PA, const ProxyCube& PB) - { - arma_extra_debug_sigprint(); - - typedef typename std::complex eT; - - const uword n_rows = PA.get_n_rows(); - const uword n_cols = PA.get_n_cols(); - const uword n_slices = PA.get_n_slices(); + arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise max()"); - arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise maximum"); + const arma_gt_comparator comparator; out.set_size(n_rows, n_cols, n_slices); @@ -240,10 +156,10 @@ glue_max::apply(Cube< std::complex >& out, const ProxyCube& PA, const Pro for(uword i=0; i std::abs(B_val) ) ? A_val : B_val; + out_mem[i] = comparator(Ai,Bi) ? Ai : Bi; } } else @@ -252,10 +168,10 @@ glue_max::apply(Cube< std::complex >& out, const ProxyCube& PA, const Pro for(uword col=0; col < n_cols; ++col ) for(uword row=0; row < n_rows; ++row ) { - const eT A_val = PA.at(row,col,slice); - const eT B_val = PB.at(row,col,slice); + const eT Ai = PA.at(row,col,slice); + const eT Bi = PB.at(row,col,slice); - *out_mem = ( std::abs(A_val) > std::abs(B_val) ) ? A_val : B_val; + *out_mem = comparator(Ai,Bi) ? Ai : Bi; ++out_mem; } diff --git a/src/armadillo_bits/glue_min_bones.hpp b/src/armadillo_bits/glue_min_bones.hpp index 3b8e2628..bf1cbc35 100644 --- a/src/armadillo_bits/glue_min_bones.hpp +++ b/src/armadillo_bits/glue_min_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,18 +31,14 @@ class glue_min template inline static void apply(Mat& out, const Glue& X); - template inline static void apply(Mat< eT >& out, const Proxy& PA, const Proxy& PB); - - template inline static void apply(Mat< std::complex >& out, const Proxy& PA, const Proxy& PB); + template inline static void apply(Mat& out, const Proxy& PA, const Proxy& PB); // cubes template inline static void apply(Cube& out, const GlueCube& X); - template inline static void apply(Cube< eT >& out, const ProxyCube& PA, const ProxyCube& PB); - - template inline static void apply(Cube< std::complex >& out, const ProxyCube& PA, const ProxyCube& PB); + template inline static void apply(Cube& out, const ProxyCube& PA, const ProxyCube& PB); }; diff --git a/src/armadillo_bits/glue_min_meat.hpp b/src/armadillo_bits/glue_min_meat.hpp index d0619d79..0fc6e3f0 100644 --- a/src/armadillo_bits/glue_min_meat.hpp +++ b/src/armadillo_bits/glue_min_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,51 +60,9 @@ glue_min::apply(Mat& out, const Proxy& PA, const Proxy& PB) const uword n_rows = PA.get_n_rows(); const uword n_cols = PA.get_n_cols(); - arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise minimum"); + arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise min()"); - out.set_size(n_rows, n_cols); - - eT* out_mem = out.memptr(); - - if( (Proxy::use_at == false) && (Proxy::use_at == false) ) - { - typename Proxy::ea_type A = PA.get_ea(); - typename Proxy::ea_type B = PB.get_ea(); - - const uword N = PA.get_n_elem(); - - for(uword i=0; i -inline -void -glue_min::apply(Mat< std::complex >& out, const Proxy& PA, const Proxy& PB) - { - arma_extra_debug_sigprint(); - - typedef typename std::complex eT; - - const uword n_rows = PA.get_n_rows(); - const uword n_cols = PA.get_n_cols(); - - arma_debug_assert_same_size(n_rows, n_cols, PB.get_n_rows(), PB.get_n_cols(), "element-wise minimum"); + const arma_lt_comparator comparator; out.set_size(n_rows, n_cols); @@ -117,10 +77,10 @@ glue_min::apply(Mat< std::complex >& out, const Proxy& PA, const Proxy >& out, const Proxy& PA, const Proxy& out, const ProxyCube& PA, const ProxyCube& PB) const uword n_cols = PA.get_n_cols(); const uword n_slices = PA.get_n_slices(); - arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise minimum"); - - out.set_size(n_rows, n_cols, n_slices); - - eT* out_mem = out.memptr(); - - if( (ProxyCube::use_at == false) && (ProxyCube::use_at == false) ) - { - typename ProxyCube::ea_type A = PA.get_ea(); - typename ProxyCube::ea_type B = PB.get_ea(); - - const uword N = PA.get_n_elem(); - - for(uword i=0; i -inline -void -glue_min::apply(Cube< std::complex >& out, const ProxyCube& PA, const ProxyCube& PB) - { - arma_extra_debug_sigprint(); - - typedef typename std::complex eT; - - const uword n_rows = PA.get_n_rows(); - const uword n_cols = PA.get_n_cols(); - const uword n_slices = PA.get_n_slices(); + arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise min()"); - arma_debug_assert_same_size(n_rows, n_cols, n_slices, PB.get_n_rows(), PB.get_n_cols(), PB.get_n_slices(), "element-wise minimum"); + const arma_lt_comparator comparator; out.set_size(n_rows, n_cols, n_slices); @@ -240,10 +156,10 @@ glue_min::apply(Cube< std::complex >& out, const ProxyCube& PA, const Pro for(uword i=0; i >& out, const ProxyCube& PA, const Pro for(uword col=0; col < n_cols; ++col ) for(uword row=0; row < n_rows; ++row ) { - const eT A_val = PA.at(row,col,slice); - const eT B_val = PB.at(row,col,slice); + const eT Ai = PA.at(row,col,slice); + const eT Bi = PB.at(row,col,slice); - *out_mem = ( std::abs(A_val) < std::abs(B_val) ) ? A_val : B_val; + *out_mem = comparator(Ai,Bi) ? Ai : Bi; ++out_mem; } diff --git a/src/armadillo_bits/glue_mixed_bones.hpp b/src/armadillo_bits/glue_mixed_bones.hpp index 0806dc67..bdc3806d 100644 --- a/src/armadillo_bits/glue_mixed_bones.hpp +++ b/src/armadillo_bits/glue_mixed_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class glue_mixed_times template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_mixed_meat.hpp b/src/armadillo_bits/glue_mixed_meat.hpp index 5c70cab8..21b6dc4b 100644 --- a/src/armadillo_bits/glue_mixed_meat.hpp +++ b/src/armadillo_bits/glue_mixed_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -59,7 +61,7 @@ glue_mixed_times::apply(Mat::eT>& out, const mtGlue< } else { - Mat tmp(out_n_rows, out_n_cols); + Mat tmp(out_n_rows, out_n_cols, arma_nozeros_indicator()); gemm_mixed::apply(tmp, A, B, alpha); diff --git a/src/armadillo_bits/glue_mvnrnd_bones.hpp b/src/armadillo_bits/glue_mvnrnd_bones.hpp index 7ee438f5..ab1c437a 100644 --- a/src/armadillo_bits/glue_mvnrnd_bones.hpp +++ b/src/armadillo_bits/glue_mvnrnd_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,9 +27,9 @@ class glue_mvnrnd_vec template struct traits { - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_mvnrnd_meat.hpp b/src/armadillo_bits/glue_mvnrnd_meat.hpp index 5e802bef..3c3019fb 100644 --- a/src/armadillo_bits/glue_mvnrnd_meat.hpp +++ b/src/armadillo_bits/glue_mvnrnd_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -36,6 +38,7 @@ glue_mvnrnd_vec::apply(Mat& out, const Glue& out, const Glue& out, const Base& out, const Base struct traits { - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; }; template inline static bool apply_noalias(Mat& out, const Col& X, const Col& Y, const uword N); diff --git a/src/armadillo_bits/glue_polyfit_meat.hpp b/src/armadillo_bits/glue_polyfit_meat.hpp index ddaaaf0d..3969c75f 100644 --- a/src/armadillo_bits/glue_polyfit_meat.hpp +++ b/src/armadillo_bits/glue_polyfit_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,7 +30,7 @@ glue_polyfit::apply_noalias(Mat& out, const Col& X, const Col& Y, co // create Vandermonde matrix - Mat V(X.n_elem, N+1); + Mat V(X.n_elem, N+1, arma_nozeros_indicator()); V.tail_cols(1).ones(); @@ -76,7 +78,7 @@ glue_polyfit::apply_direct(Mat& out, const Base struct traits { - static const bool is_row = T2::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = T2::is_xvec; + static constexpr bool is_row = T2::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = T2::is_xvec; }; template inline static void apply_noalias(Mat& out, const Mat& P, const Mat& X); diff --git a/src/armadillo_bits/glue_polyval_meat.hpp b/src/armadillo_bits/glue_polyval_meat.hpp index 4086615d..2c2a59cb 100644 --- a/src/armadillo_bits/glue_polyval_meat.hpp +++ b/src/armadillo_bits/glue_polyval_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_powext_bones.hpp b/src/armadillo_bits/glue_powext_bones.hpp new file mode 100644 index 00000000..d5698c5c --- /dev/null +++ b/src/armadillo_bits/glue_powext_bones.hpp @@ -0,0 +1,70 @@ + +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_powext +//! @{ + + + +class glue_powext + : public traits_glue_or + { + public: + + template inline static void apply(Mat& out, const Glue& X); + + template inline static void apply(Mat& out, const Mat& A, const Mat& B); + + template inline static Mat apply(const subview_each1& X, const Base& Y); + + // + + template inline static void apply(Cube& out, const GlueCube& X); + + template inline static void apply(Cube& out, const Cube& A, const Cube& B); + + template inline static Cube apply(const subview_cube_each1& X, const Base& Y); + }; + + + +class glue_powext_cx + : public traits_glue_or + { + public: + + template inline static void apply(Mat& out, const mtGlue& X); + + template inline static void apply(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B); + + template inline static Mat apply(const subview_each1& X, const Base& Y); + + // + + template inline static void apply(Cube& out, const mtGlueCube& X); + + template inline static void apply(Cube< std::complex >& out, const Cube< std::complex >& A, const Cube& B); + + template inline static Cube< std::complex > apply(const subview_cube_each1< std::complex >& X, const Base& Y); + }; + + + +//! @} diff --git a/src/armadillo_bits/glue_powext_meat.hpp b/src/armadillo_bits/glue_powext_meat.hpp new file mode 100644 index 00000000..700a2cfd --- /dev/null +++ b/src/armadillo_bits/glue_powext_meat.hpp @@ -0,0 +1,674 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup glue_powext +//! @{ + + +template +inline +void +glue_powext::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + const bool UA_bad_alias = UA.is_alias(out) && (UA.has_subview); // allow inplace operation + const bool UB_bad_alias = UB.is_alias(out); + + if(UA_bad_alias || UB_bad_alias) + { + Mat tmp; + + glue_powext::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext::apply(Mat& out, const Mat& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + out.set_size(A.n_rows, A.n_cols); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const eT* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Mat +glue_powext::apply + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + + const parent& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + Mat out(A_n_rows, A_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const eT* B_mem = B.memptr(); + + if(mode == 0) // each column + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_mem[row]); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_mem[row]); + } + } + } + } + + if(mode == 1) // each row + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_val); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = eop_aux::pow(A_mem[row], B_val); + } + } + } + } + + return out; + } + + + +template +inline +void +glue_powext::apply(Cube& out, const GlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube UA(X.A); + const unwrap_cube UB(X.B); + + const Cube& A = UA.M; + const Cube& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + if(UB.is_alias(out)) + { + Cube tmp; + + glue_powext::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext::apply(Cube& out, const Cube& A, const Cube& B) + { + arma_extra_debug_sigprint(); + + out.set_size(A.n_rows, A.n_cols, A.n_slices); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const eT* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Cube +glue_powext::apply + ( + const subview_cube_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + const Cube& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + Cube out(A_n_rows, A_n_cols, A_n_slices, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const eT* B_mem = B.memptr(); + const uword B_n_elem = B.n_elem; + + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_slices) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = eop_aux::pow(A_slice_mem[i], B_mem[i]); + } + } + } + #endif + } + else + { + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = eop_aux::pow(A_slice_mem[i], B_mem[i]); + } + } + } + + return out; + } + + + +// + + + +template +inline +void +glue_powext_cx::apply(Mat& out, const mtGlue& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const quasi_unwrap UA(X.A); + const quasi_unwrap UB(X.B); + + const Mat& A = UA.M; + const Mat< T>& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + if(UA.is_alias(out) && (UA.has_subview)) + { + Mat tmp; + + glue_powext_cx::apply(tmp, A, B); + + out.steal_mem(tmp); + } + else + { + glue_powext_cx::apply(out, A, B); + } + } + + + +template +inline +void +glue_powext_cx::apply(Mat< std::complex >& out, const Mat< std::complex >& A, const Mat& B) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + out.set_size(A.n_rows, A.n_cols); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const T* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Mat +glue_powext_cx::apply + ( + const subview_each1& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename parent::elem_type eT; + typedef typename parent::pod_type T; + + const parent& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + Mat out(A_n_rows, A_n_cols, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const T* B_mem = B.memptr(); + + if(mode == 0) // each column + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_mem[row]); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_mem[row]); + } + } + } + } + + if(mode == 1) // each row + { + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_cols) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_val); + } + } + } + #endif + } + else + { + for(uword i=0; i < A_n_cols; ++i) + { + const eT* A_mem = A.colptr(i); + eT* out_mem = out.colptr(i); + + const eT B_val = B_mem[i]; + + for(uword row=0; row < A_n_rows; ++row) + { + out_mem[row] = std::pow(A_mem[row], B_val); + } + } + } + } + + return out; + } + + + +template +inline +void +glue_powext_cx::apply(Cube& out, const mtGlueCube& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + typedef typename get_pod_type::result T; + + const unwrap_cube UA(X.A); + const unwrap_cube UB(X.B); + + const Cube& A = UA.M; + const Cube< T>& B = UB.M; + + arma_debug_assert_same_size(A, B, "element-wise pow()"); + + glue_powext_cx::apply(out, A, B); + } + + + +template +inline +void +glue_powext_cx::apply(Cube< std::complex >& out, const Cube< std::complex >& A, const Cube& B) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + out.set_size(A.n_rows, A.n_cols, A.n_slices); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + const T* B_mem = B.memptr(); + + if( arma_config::openmp && mp_gate::eval(N) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i +inline +Cube< std::complex > +glue_powext_cx::apply + ( + const subview_cube_each1< std::complex >& X, + const Base& Y + ) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + const Cube& A = X.P; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + const uword A_n_slices = A.n_slices; + + Cube out(A_n_rows, A_n_cols, A_n_slices, arma_nozeros_indicator()); + + const quasi_unwrap tmp(Y.get_ref()); + const Mat& B = tmp.M; + + X.check_size(B); + + const T* B_mem = B.memptr(); + const uword B_n_elem = B.n_elem; + + if( arma_config::openmp && mp_gate::eval(A.n_elem) ) + { + #if defined(ARMA_USE_OPENMP) + { + const int n_threads = int( (std::min)(uword(mp_thread_limit::get()), A_n_slices) ); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = std::pow(A_slice_mem[i], B_mem[i]); + } + } + } + #endif + } + else + { + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_slice_mem = A.slice_memptr(s); + eT* out_slice_mem = out.slice_memptr(s); + + for(uword i=0; i < B_n_elem; ++i) + { + out_slice_mem[i] = std::pow(A_slice_mem[i], B_mem[i]); + } + } + } + + return out; + } + + + +//! @} diff --git a/src/armadillo_bits/glue_quantile_bones.hpp b/src/armadillo_bits/glue_quantile_bones.hpp new file mode 100644 index 00000000..cd7fcf12 --- /dev/null +++ b/src/armadillo_bits/glue_quantile_bones.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_quantile +//! @{ + + +class glue_quantile + : public traits_glue_default + { + public: + + template + inline static void worker(eTb* out_mem, Col& Y, const Mat& P); + + + template + inline static void apply_noalias(Mat& out, const Mat& X, const Mat& P, const uword dim); + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + + +class glue_quantile_default + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T1::is_col; + static constexpr bool is_xvec = T1::is_xvec; + }; + + template + inline static void apply(Mat& out, const mtGlue& expr); + }; + + +//! @} diff --git a/src/armadillo_bits/glue_quantile_meat.hpp b/src/armadillo_bits/glue_quantile_meat.hpp new file mode 100644 index 00000000..370e432a --- /dev/null +++ b/src/armadillo_bits/glue_quantile_meat.hpp @@ -0,0 +1,230 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_quantile +//! @{ + + +template +inline +void +glue_quantile::worker(eTb* out_mem, Col& Y, const Mat& P) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming out_mem is an array with P.n_elem elements + + // TODO: ignore non-finite values ? + + // algorithm based on "Definition 5" in: + // Rob J. Hyndman and Yanan Fan. + // Sample Quantiles in Statistical Packages. + // The American Statistician, Vol. 50, No. 4, pp. 361-365, 1996. + // http://doi.org/10.2307/2684934 + + const eTb* P_mem = P.memptr(); + const uword P_n_elem = P.n_elem; + + const eTb alpha = 0.5; + const eTb N = eTb(Y.n_elem); + const eTb P_min = (eTb(1) - alpha) / N; + const eTb P_max = (N - alpha) / N; + + for(uword i=0; i < P_n_elem; ++i) + { + const eTb P_i = P_mem[i]; + + eTb out_val = eTb(0); + + if(P_i < P_min) + { + out_val = (P_i < eTb(0)) ? eTb(-std::numeric_limits::infinity()) : eTb(Y.min()); + } + else + if(P_i > P_max) + { + out_val = (P_i > eTb(1)) ? eTb( std::numeric_limits::infinity()) : eTb(Y.max()); + } + else + { + const uword k = uword(std::floor(N * P_i + alpha)); + const eTb P_k = (eTb(k) - alpha) / N; + + const eTb w = (P_i - P_k) * N; + + eTa* Y_k_ptr = Y.begin() + uword(k); + std::nth_element( Y.begin(), Y_k_ptr, Y.end() ); + const eTa Y_k_val = (*Y_k_ptr); + + eTa* Y_km1_ptr = Y.begin() + uword(k-1); + // std::nth_element( Y.begin(), Y_km1_ptr, Y.end() ); + std::nth_element( Y.begin(), Y_km1_ptr, Y_k_ptr ); + const eTa Y_km1_val = (*Y_km1_ptr); + + out_val = ((eTb(1) - w) * Y_km1_val) + (w * Y_k_val); + } + + out_mem[i] = out_val; + } + } + + + +template +inline +void +glue_quantile::apply_noalias(Mat& out, const Mat& X, const Mat& P, const uword dim) + { + arma_extra_debug_sigprint(); + + arma_debug_check( ((P.is_vec() == false) && (P.is_empty() == false)), "quantile(): parameter 'P' must be a vector" ); + + if(X.is_empty()) { out.reset(); return; } + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + const uword P_n_elem = P.n_elem; + + if(dim == 0) + { + out.set_size(P_n_elem, X_n_cols); + + if(out.is_empty()) { return; } + + Col Y(X_n_rows, arma_nozeros_indicator()); + + if(X_n_cols == 1) + { + arrayops::copy(Y.memptr(), X.memptr(), X_n_rows); + + glue_quantile::worker(out.memptr(), Y, P); + } + else + { + for(uword col=0; col < X_n_cols; ++col) + { + arrayops::copy(Y.memptr(), X.colptr(col), X_n_rows); + + glue_quantile::worker(out.colptr(col), Y, P); + } + } + } + else + if(dim == 1) + { + out.set_size(X_n_rows, P_n_elem); + + if(out.is_empty()) { return; } + + Col Y(X_n_cols, arma_nozeros_indicator()); + + if(X_n_rows == 1) + { + arrayops::copy(Y.memptr(), X.memptr(), X_n_cols); + + glue_quantile::worker(out.memptr(), Y, P); + } + else + { + Col tmp(P_n_elem, arma_nozeros_indicator()); + + eTb* tmp_mem = tmp.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + eTa* Y_mem = Y.memptr(); + + for(uword col=0; col < X_n_cols; ++col) { Y_mem[col] = X.at(row,col); } + + glue_quantile::worker(tmp_mem, Y, P); + + for(uword i=0; i < P_n_elem; ++i) { out.at(row,i) = tmp_mem[i]; } + } + } + } + } + + + +template +inline +void +glue_quantile::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T2::elem_type eTb; + + const uword dim = expr.aux_uword; + + arma_debug_check( (dim > 1), "quantile(): parameter 'dim' must be 0 or 1" ); + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + arma_debug_check((UA.M.internal_has_nan() || UB.M.internal_has_nan()), "quantile(): detected NaN"); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_quantile::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_quantile::apply_noalias(out, UA.M, UB.M, dim); + } + } + + + +template +inline +void +glue_quantile_default::apply(Mat& out, const mtGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T2::elem_type eTb; + + const quasi_unwrap UA(expr.A); + const quasi_unwrap UB(expr.B); + + const uword dim = (T1::is_xvec) ? uword(UA.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); + + arma_debug_check((UA.M.internal_has_nan() || UB.M.internal_has_nan()), "quantile(): detected NaN"); + + if(UA.is_alias(out) || UB.is_alias(out)) + { + Mat tmp; + + glue_quantile::apply_noalias(tmp, UA.M, UB.M, dim); + + out.steal_mem(tmp); + } + else + { + glue_quantile::apply_noalias(out, UA.M, UB.M, dim); + } + } + + +//! @} diff --git a/src/armadillo_bits/glue_relational_bones.hpp b/src/armadillo_bits/glue_relational_bones.hpp index 7a466002..876ffb7f 100644 --- a/src/armadillo_bits/glue_relational_bones.hpp +++ b/src/armadillo_bits/glue_relational_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_relational_meat.hpp b/src/armadillo_bits/glue_relational_meat.hpp index 0f1b5891..5728a091 100644 --- a/src/armadillo_bits/glue_relational_meat.hpp +++ b/src/armadillo_bits/glue_relational_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_solve_bones.hpp b/src/armadillo_bits/glue_solve_bones.hpp index b46126ca..20c01659 100644 --- a/src/armadillo_bits/glue_solve_bones.hpp +++ b/src/armadillo_bits/glue_solve_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,21 +22,40 @@ -class glue_solve_gen +class glue_solve_gen_default { public: template struct traits { - static const bool is_row = false; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; - template inline static void apply(Mat& out, const Glue& X); + template inline static void apply(Mat& out, const Glue& X); - template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr); + }; + + + +class glue_solve_gen_full + { + public: + + template + struct traits + { + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template inline static void apply(Mat& out, const Glue& X); + + template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); }; @@ -46,9 +67,9 @@ class glue_solve_tri_default template struct traits { - static const bool is_row = false; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template inline static void apply(Mat& out, const Glue& X); @@ -58,19 +79,19 @@ class glue_solve_tri_default -class glue_solve_tri +class glue_solve_tri_full { public: template struct traits { - static const bool is_row = false; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; - template inline static void apply(Mat& out, const Glue& X); + template inline static void apply(Mat& out, const Glue& X); template inline static bool apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags); }; @@ -83,12 +104,13 @@ namespace solve_opts { const uword flags; - inline explicit opts(const uword in_flags); + inline constexpr explicit opts(const uword in_flags); inline const opts operator+(const opts& rhs) const; }; inline + constexpr opts::opts(const uword in_flags) : flags(in_flags) {} @@ -105,44 +127,47 @@ namespace solve_opts // The values below (eg. 1u << 1) are for internal Armadillo use only. // The values can change without notice. - static const uword flag_none = uword(0 ); - static const uword flag_fast = uword(1u << 0); - static const uword flag_equilibrate = uword(1u << 1); - static const uword flag_no_approx = uword(1u << 2); - static const uword flag_triu = uword(1u << 3); - static const uword flag_tril = uword(1u << 4); - static const uword flag_no_band = uword(1u << 5); - static const uword flag_no_sympd = uword(1u << 6); - static const uword flag_allow_ugly = uword(1u << 7); - static const uword flag_likely_sympd = uword(1u << 8); - static const uword flag_refine = uword(1u << 9); - static const uword flag_no_trimat = uword(1u << 10); - - struct opts_none : public opts { inline opts_none() : opts(flag_none ) {} }; - struct opts_fast : public opts { inline opts_fast() : opts(flag_fast ) {} }; - struct opts_equilibrate : public opts { inline opts_equilibrate() : opts(flag_equilibrate ) {} }; - struct opts_no_approx : public opts { inline opts_no_approx() : opts(flag_no_approx ) {} }; - struct opts_triu : public opts { inline opts_triu() : opts(flag_triu ) {} }; - struct opts_tril : public opts { inline opts_tril() : opts(flag_tril ) {} }; - struct opts_no_band : public opts { inline opts_no_band() : opts(flag_no_band ) {} }; - struct opts_no_sympd : public opts { inline opts_no_sympd() : opts(flag_no_sympd ) {} }; - struct opts_allow_ugly : public opts { inline opts_allow_ugly() : opts(flag_allow_ugly ) {} }; - struct opts_likely_sympd : public opts { inline opts_likely_sympd() : opts(flag_likely_sympd) {} }; - struct opts_refine : public opts { inline opts_refine() : opts(flag_refine ) {} }; - struct opts_no_trimat : public opts { inline opts_no_trimat() : opts(flag_no_trimat ) {} }; - - static const opts_none none; - static const opts_fast fast; - static const opts_equilibrate equilibrate; - static const opts_no_approx no_approx; - static const opts_triu triu; - static const opts_tril tril; - static const opts_no_band no_band; - static const opts_no_sympd no_sympd; - static const opts_allow_ugly allow_ugly; - static const opts_likely_sympd likely_sympd; - static const opts_refine refine; - static const opts_no_trimat no_trimat; + static constexpr uword flag_none = uword(0 ); + static constexpr uword flag_fast = uword(1u << 0); + static constexpr uword flag_equilibrate = uword(1u << 1); + static constexpr uword flag_no_approx = uword(1u << 2); + static constexpr uword flag_triu = uword(1u << 3); + static constexpr uword flag_tril = uword(1u << 4); + static constexpr uword flag_no_band = uword(1u << 5); + static constexpr uword flag_no_sympd = uword(1u << 6); + static constexpr uword flag_allow_ugly = uword(1u << 7); + static constexpr uword flag_likely_sympd = uword(1u << 8); + static constexpr uword flag_refine = uword(1u << 9); + static constexpr uword flag_no_trimat = uword(1u << 10); + static constexpr uword flag_force_approx = uword(1u << 11); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; + struct opts_equilibrate : public opts { inline constexpr opts_equilibrate() : opts(flag_equilibrate ) {} }; + struct opts_no_approx : public opts { inline constexpr opts_no_approx() : opts(flag_no_approx ) {} }; + struct opts_triu : public opts { inline constexpr opts_triu() : opts(flag_triu ) {} }; + struct opts_tril : public opts { inline constexpr opts_tril() : opts(flag_tril ) {} }; + struct opts_no_band : public opts { inline constexpr opts_no_band() : opts(flag_no_band ) {} }; + struct opts_no_sympd : public opts { inline constexpr opts_no_sympd() : opts(flag_no_sympd ) {} }; + struct opts_allow_ugly : public opts { inline constexpr opts_allow_ugly() : opts(flag_allow_ugly ) {} }; + struct opts_likely_sympd : public opts { inline constexpr opts_likely_sympd() : opts(flag_likely_sympd) {} }; + struct opts_refine : public opts { inline constexpr opts_refine() : opts(flag_refine ) {} }; + struct opts_no_trimat : public opts { inline constexpr opts_no_trimat() : opts(flag_no_trimat ) {} }; + struct opts_force_approx : public opts { inline constexpr opts_force_approx() : opts(flag_force_approx) {} }; + + static constexpr opts_none none; + static constexpr opts_fast fast; + static constexpr opts_equilibrate equilibrate; + static constexpr opts_no_approx no_approx; + static constexpr opts_triu triu; + static constexpr opts_tril tril; + static constexpr opts_no_band no_band; + static constexpr opts_no_sympd no_sympd; + static constexpr opts_allow_ugly allow_ugly; + static constexpr opts_likely_sympd likely_sympd; + static constexpr opts_refine refine; + static constexpr opts_no_trimat no_trimat; + static constexpr opts_force_approx force_approx; } diff --git a/src/armadillo_bits/glue_solve_meat.hpp b/src/armadillo_bits/glue_solve_meat.hpp index 622a3441..1c3bbf0c 100644 --- a/src/armadillo_bits/glue_solve_meat.hpp +++ b/src/armadillo_bits/glue_solve_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,20 +22,21 @@ // -// glue_solve_gen +// glue_solve_gen_default template inline void -glue_solve_gen::apply(Mat& out, const Glue& X) +glue_solve_gen_default::apply(Mat& out, const Glue& X) { arma_extra_debug_sigprint(); - const bool status = glue_solve_gen::apply( out, X.A, X.B, X.aux_uword ); + const bool status = glue_solve_gen_default::apply(out, X.A, X.B); if(status == false) { + out.soft_reset(); arma_stop_runtime_error("solve(): solution not found"); } } @@ -43,82 +46,146 @@ glue_solve_gen::apply(Mat& out, const Glue inline bool -glue_solve_gen::apply(Mat& out, const Base& A_expr, const Base& B_expr, const uword flags) +glue_solve_gen_default::apply(Mat& out, const Base& A_expr, const Base& B_expr) + { + arma_extra_debug_sigprint(); + + return glue_solve_gen_full::apply( out, A_expr, B_expr, uword(0)); + } + + + +// +// glue_solve_gen_full + + +template +inline +void +glue_solve_gen_full::apply(Mat& out, const Glue& X) + { + arma_extra_debug_sigprint(); + + const bool status = glue_solve_gen_full::apply( out, X.A, X.B, X.aux_uword ); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("solve(): solution not found"); + } + } + + + +template +inline +bool +glue_solve_gen_full::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) { arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; - const bool fast = bool(flags & solve_opts::flag_fast ); - const bool equilibrate = bool(flags & solve_opts::flag_equilibrate ); - const bool no_approx = bool(flags & solve_opts::flag_no_approx ); - const bool no_band = bool(flags & solve_opts::flag_no_band ); - const bool no_sympd = bool(flags & solve_opts::flag_no_sympd ); - const bool allow_ugly = bool(flags & solve_opts::flag_allow_ugly ); - const bool likely_sympd = bool(flags & solve_opts::flag_likely_sympd); - const bool refine = bool(flags & solve_opts::flag_refine ); - const bool no_trimat = bool(flags & solve_opts::flag_no_trimat ); + if(has_user_flags == true ) { arma_extra_debug_print("glue_solve_gen_full::apply(): has_user_flags = true" ); } + if(has_user_flags == false) { arma_extra_debug_print("glue_solve_gen_full::apply(): has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & solve_opts::flag_fast ); + const bool equilibrate = has_user_flags && bool(flags & solve_opts::flag_equilibrate ); + const bool no_approx = has_user_flags && bool(flags & solve_opts::flag_no_approx ); + const bool no_band = has_user_flags && bool(flags & solve_opts::flag_no_band ); + const bool no_sympd = has_user_flags && bool(flags & solve_opts::flag_no_sympd ); + const bool allow_ugly = has_user_flags && bool(flags & solve_opts::flag_allow_ugly ); + const bool likely_sympd = has_user_flags && bool(flags & solve_opts::flag_likely_sympd); + const bool refine = has_user_flags && bool(flags & solve_opts::flag_refine ); + const bool no_trimat = has_user_flags && bool(flags & solve_opts::flag_no_trimat ); + const bool force_approx = has_user_flags && bool(flags & solve_opts::flag_force_approx); + + if(has_user_flags) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(equilibrate ) { arma_extra_debug_print("equilibrate"); } + if(no_approx ) { arma_extra_debug_print("no_approx"); } + if(no_band ) { arma_extra_debug_print("no_band"); } + if(no_sympd ) { arma_extra_debug_print("no_sympd"); } + if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); } + if(likely_sympd) { arma_extra_debug_print("likely_sympd"); } + if(refine ) { arma_extra_debug_print("refine"); } + if(no_trimat ) { arma_extra_debug_print("no_trimat"); } + if(force_approx) { arma_extra_debug_print("force_approx"); } + + arma_debug_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); + arma_debug_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); + arma_debug_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); + } - arma_extra_debug_print("glue_solve_gen::apply(): enabled flags:"); + Mat A = A_expr.get_ref(); - if(fast ) { arma_extra_debug_print("fast"); } - if(equilibrate ) { arma_extra_debug_print("equilibrate"); } - if(no_approx ) { arma_extra_debug_print("no_approx"); } - if(no_band ) { arma_extra_debug_print("no_band"); } - if(no_sympd ) { arma_extra_debug_print("no_sympd"); } - if(allow_ugly ) { arma_extra_debug_print("allow_ugly"); } - if(likely_sympd) { arma_extra_debug_print("likely_sympd"); } - if(refine ) { arma_extra_debug_print("refine"); } - if(no_trimat ) { arma_extra_debug_print("no_trimat"); } + if(force_approx) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): forced approximate solution"); + + arma_debug_check( no_approx, "solve(): options 'no_approx' and 'force_approx' are mutually exclusive" ); + + if(fast) { arma_debug_warn_level(2, "solve(): option 'fast' ignored for forced approximate solution" ); } + if(equilibrate) { arma_debug_warn_level(2, "solve(): option 'equilibrate' ignored for forced approximate solution" ); } + if(refine) { arma_debug_warn_level(2, "solve(): option 'refine' ignored for forced approximate solution" ); } + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for forced approximate solution" ); } + + return auxlib::solve_approx_svd(actual_out, A, B_expr.get_ref()); // A is overwritten + } - arma_debug_check( (fast && equilibrate ), "solve(): options 'fast' and 'equilibrate' are mutually exclusive" ); - arma_debug_check( (fast && refine ), "solve(): options 'fast' and 'refine' are mutually exclusive" ); - arma_debug_check( (no_sympd && likely_sympd), "solve(): options 'no_sympd' and 'likely_sympd' are mutually exclusive" ); + // A_expr and B_expr can be used more than once (sympd optimisation fails or approximate solution required), + // so ensure they are not overwritten in case we have aliasing + + bool is_alias = true; // assume we have aliasing until we can prove otherwise + + if(is_Mat::value && is_Mat::value) + { + const quasi_unwrap UA( A_expr.get_ref() ); + const quasi_unwrap UB( B_expr.get_ref() ); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; T rcond = T(0); bool status = false; - Mat A = A_expr.get_ref(); - if(A.n_rows == A.n_cols) { - arma_extra_debug_print("glue_solve_gen::apply(): detected square system"); + arma_extra_debug_print("glue_solve_gen_full::apply(): detected square system"); uword KL = 0; uword KU = 0; - #if defined(ARMA_OPTIMISE_SOLVE_BAND) - const bool is_band = (no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32)); - #else - const bool is_band = false; - #endif + const bool is_band = arma_config::optimise_band && ((no_band || auxlib::crippled_lapack(A)) ? false : band_helper::is_band(KL, KU, A, uword(32))); const bool is_triu = (no_trimat || refine || equilibrate || likely_sympd || is_band ) ? false : trimat_helper::is_triu(A); const bool is_tril = (no_trimat || refine || equilibrate || likely_sympd || is_band || is_triu) ? false : trimat_helper::is_tril(A); - #if defined(ARMA_OPTIMISE_SOLVE_SYMPD) - const bool try_sympd = (no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sympd_helper::guess_sympd(A)); - #else - const bool try_sympd = false; - #endif + const bool try_sympd = arma_config::optimise_sym && ((no_sympd || auxlib::crippled_lapack(A) || is_band || is_triu || is_tril) ? false : (likely_sympd ? true : sym_helper::guess_sympd(A, uword(16)))); if(fast) { // fast mode: solvers without refinement and without rcond estimate - arma_extra_debug_print("glue_solve_gen::apply(): fast mode"); + arma_extra_debug_print("glue_solve_gen_full::apply(): fast mode"); if(is_band) { if( (KL == 1) && (KU == 1) ) { - arma_extra_debug_print("glue_solve_gen::apply(): fast + tridiagonal"); + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + tridiagonal"); status = auxlib::solve_tridiag_fast(out, A, B_expr.get_ref()); } else { - arma_extra_debug_print("glue_solve_gen::apply(): fast + band"); + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + band"); status = auxlib::solve_band_fast(out, A, KL, KU, B_expr.get_ref()); } @@ -126,8 +193,8 @@ glue_solve_gen::apply(Mat& out, const Base& A_expr, const Base else if(is_triu || is_tril) { - if(is_triu) { arma_extra_debug_print("glue_solve_gen::apply(): fast + upper triangular matrix"); } - if(is_tril) { arma_extra_debug_print("glue_solve_gen::apply(): fast + lower triangular matrix"); } + if(is_triu) { arma_extra_debug_print("glue_solve_gen_full::apply(): fast + upper triangular matrix"); } + if(is_tril) { arma_extra_debug_print("glue_solve_gen_full::apply(): fast + lower triangular matrix"); } const uword layout = (is_triu) ? uword(0) : uword(1); @@ -136,22 +203,24 @@ glue_solve_gen::apply(Mat& out, const Base& A_expr, const Base else if(try_sympd) { - arma_extra_debug_print("glue_solve_gen::apply(): fast + try_sympd"); + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + try_sympd"); status = auxlib::solve_sympd_fast(out, A, B_expr.get_ref()); // A is overwritten if(status == false) { - arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_fast() failed; retrying"); - // auxlib::solve_sympd_fast() may have failed because A isn't really sympd + + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_fast() failed; retrying"); + A = A_expr.get_ref(); + status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten } } else { - arma_extra_debug_print("glue_solve_gen::apply(): fast + dense"); + arma_extra_debug_print("glue_solve_gen_full::apply(): fast + dense"); status = auxlib::solve_square_fast(out, A, B_expr.get_ref()); // A is overwritten } @@ -161,134 +230,130 @@ glue_solve_gen::apply(Mat& out, const Base& A_expr, const Base { // refine mode: solvers with refinement and with rcond estimate - arma_extra_debug_print("glue_solve_gen::apply(): refine mode"); + arma_extra_debug_print("glue_solve_gen_full::apply(): refine mode"); if(is_band) { - arma_extra_debug_print("glue_solve_gen::apply(): refine + band"); + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + band"); - status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate, allow_ugly); + status = auxlib::solve_band_refine(out, rcond, A, KL, KU, B_expr, equilibrate); } else if(try_sympd) { - arma_extra_debug_print("glue_solve_gen::apply(): refine + try_sympd"); + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + try_sympd"); - status = auxlib::solve_sympd_refine(out, rcond, A, B_expr.get_ref(), equilibrate, allow_ugly); // A is overwritten + status = auxlib::solve_sympd_refine(out, rcond, A, B_expr.get_ref(), equilibrate); // A is overwritten - if(status == false) + if( (status == false) && (rcond == T(0)) ) { - arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_refine() failed; retrying"); + // auxlib::solve_sympd_refine() may have failed because A isn't really sympd; + // in that case rcond is set to zero + + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_refine() failed; retrying"); - // auxlib::solve_sympd_refine() may have failed because A isn't really sympd A = A_expr.get_ref(); - status = auxlib::solve_square_refine(out, rcond, A, B_expr.get_ref(), equilibrate, allow_ugly); // A is overwritten + + status = auxlib::solve_square_refine(out, rcond, A, B_expr.get_ref(), equilibrate); // A is overwritten } } else { - arma_extra_debug_print("glue_solve_gen::apply(): refine + dense"); + arma_extra_debug_print("glue_solve_gen_full::apply(): refine + dense"); - status = auxlib::solve_square_refine(out, rcond, A, B_expr, equilibrate, allow_ugly); // A is overwritten + status = auxlib::solve_square_refine(out, rcond, A, B_expr, equilibrate); // A is overwritten } } else { // default mode: solvers without refinement but with rcond estimate - arma_extra_debug_print("glue_solve_gen::apply(): default mode"); + arma_extra_debug_print("glue_solve_gen_full::apply(): default mode"); if(is_band) { - arma_extra_debug_print("glue_solve_gen::apply(): rcond + band"); + arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + band"); - status = auxlib::solve_band_rcond(out, rcond, A, KL, KU, B_expr.get_ref(), allow_ugly); + status = auxlib::solve_band_rcond(out, rcond, A, KL, KU, B_expr.get_ref()); } else if(is_triu || is_tril) { - if(is_triu) { arma_extra_debug_print("glue_solve_gen::apply(): rcond + upper triangular matrix"); } - if(is_tril) { arma_extra_debug_print("glue_solve_gen::apply(): rcond + lower triangular matrix"); } + if(is_triu) { arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + upper triangular matrix"); } + if(is_tril) { arma_extra_debug_print("glue_solve_gen_full::apply(): rcond + lower triangular matrix"); } const uword layout = (is_triu) ? uword(0) : uword(1); - status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly); + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); } else if(try_sympd) { - status = auxlib::solve_sympd_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten + bool sympd_state = false; - if(status == false) + status = auxlib::solve_sympd_rcond(out, sympd_state, rcond, A, B_expr.get_ref()); // A is overwritten + + if( (status == false) && (sympd_state == false) ) { - arma_extra_debug_print("glue_solve_gen::apply(): auxlib::solve_sympd_rcond() failed; retrying"); + arma_extra_debug_print("glue_solve_gen_full::apply(): auxlib::solve_sympd_rcond() failed; retrying"); - // auxlib::solve_sympd_rcond() may have failed because A isn't really sympd A = A_expr.get_ref(); - status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten + + status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten } } else { - status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref(), allow_ugly); // A is overwritten + status = auxlib::solve_square_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten } } + } + else + { + arma_extra_debug_print("glue_solve_gen_full::apply(): detected non-square system"); + if(equilibrate) { arma_debug_warn_level(2, "solve(): option 'equilibrate' ignored for non-square matrix" ); } + if(refine) { arma_debug_warn_level(2, "solve(): option 'refine' ignored for non-square matrix" ); } + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for non-square matrix" ); } - - if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) ) + if(fast) { - arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")"); + status = auxlib::solve_rect_fast(out, A, B_expr.get_ref()); // A is overwritten } - - - if( (status == false) && (no_approx == false) ) + else { - arma_extra_debug_print("glue_solve_gen::apply(): solving rank deficient system"); - - if(rcond > T(0)) - { - arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution"); - } - else - { - arma_debug_warn("solve(): system seems singular; attempting approx solution"); - } - - // TODO: conditionally recreate A: have a separate state flag which indicates whether A was previously overwritten - - A = A_expr.get_ref(); // as A may have been overwritten - - status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten + status = auxlib::solve_rect_rcond(out, rcond, A, B_expr.get_ref()); // A is overwritten } } - else + + + if( (status == true) && (fast == false) && (allow_ugly == false) && ((rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond)) ) { - arma_extra_debug_print("glue_solve_gen::apply(): detected non-square system"); - - if(equilibrate) { arma_debug_warn( "solve(): option 'equilibrate' ignored for non-square matrix" ); } - if(refine) { arma_debug_warn( "solve(): option 'refine' ignored for non-square matrix" ); } - if(likely_sympd) { arma_debug_warn( "solve(): option 'likely_sympd' ignored for non-square matrix" ); } + status = false; + } + + + if( (status == false) && (no_approx == false) ) + { + arma_extra_debug_print("glue_solve_gen_full::apply(): solving rank deficient system"); - if(fast) + if(rcond == T(0)) { - status = auxlib::solve_approx_fast(out, A, B_expr.get_ref()); // A is overwritten - - if(status == false) - { - A = A_expr.get_ref(); // as A was overwritten - - status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten - } + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); } else { - status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); } + + // TODO: conditionally recreate A: have a separate state flag which indicates whether A was previously overwritten + + A = A_expr.get_ref(); // as A may have been overwritten + + status = auxlib::solve_approx_svd(out, A, B_expr.get_ref()); // A is overwritten } - - if(status == false) { out.soft_reset(); } + if(is_alias) { actual_out.steal_mem(out); } return status; } @@ -296,7 +361,7 @@ glue_solve_gen::apply(Mat& out, const Base& A_expr, const Base // -// glue_solve_tri +// glue_solve_tri_default template @@ -310,6 +375,7 @@ glue_solve_tri_default::apply(Mat& out, const Glue& actual_out, const Base& A_expr, co typedef typename get_pod_type::result T; - const bool triu = bool(flags & solve_opts::flag_triu); - const bool tril = bool(flags & solve_opts::flag_tril); - const bool allow_ugly = false; + const bool triu = bool(flags & solve_opts::flag_triu); + const bool tril = bool(flags & solve_opts::flag_tril); arma_extra_debug_print("glue_solve_tri_default::apply(): enabled flags:"); if(triu) { arma_extra_debug_print("triu"); } if(tril) { arma_extra_debug_print("tril"); } - const quasi_unwrap U(A_expr.get_ref()); - const Mat& A = U.M; + const quasi_unwrap UA(A_expr.get_ref()); + const Mat& A = UA.M; arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" ); - const uword layout = (triu) ? uword(0) : uword(1); - const bool is_alias = U.is_alias(actual_out); + const uword layout = (triu) ? uword(0) : uword(1); + + bool is_alias = true; + + if(is_Mat::value) + { + const quasi_unwrap UB(B_expr.get_ref()); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } T rcond = T(0); bool status = false; @@ -348,25 +421,26 @@ glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, co Mat tmp; Mat& out = (is_alias) ? tmp : actual_out; - status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly); // A is not modified + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); // A is not modified + - if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) ) + if( (status == true) && ( (rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond) ) ) { - arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")"); + status = false; } if(status == false) { - arma_extra_debug_print("glue_solve_tri::apply(): solving rank deficient system"); + arma_extra_debug_print("glue_solve_tri_default::apply(): solving rank deficient system"); - if(rcond > T(0)) + if(rcond == T(0)) { - arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution"); + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); } else { - arma_debug_warn("solve(): system seems singular; attempting approx solution"); + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); } Mat triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type @@ -375,8 +449,6 @@ glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, co } - if(status == false) { out.soft_reset(); } - if(is_alias) { actual_out.steal_mem(out); } return status; @@ -384,17 +456,22 @@ glue_solve_tri_default::apply(Mat& actual_out, const Base& A_expr, co +// +// glue_solve_tri_full + + template inline void -glue_solve_tri::apply(Mat& out, const Glue& X) +glue_solve_tri_full::apply(Mat& out, const Glue& X) { arma_extra_debug_sigprint(); - const bool status = glue_solve_tri::apply( out, X.A, X.B, X.aux_uword ); + const bool status = glue_solve_tri_full::apply( out, X.A, X.B, X.aux_uword ); if(status == false) { + out.soft_reset(); arma_stop_runtime_error("solve(): solution not found"); } } @@ -404,7 +481,7 @@ glue_solve_tri::apply(Mat& out, const Glue inline bool -glue_solve_tri::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) +glue_solve_tri_full::apply(Mat& actual_out, const Base& A_expr, const Base& B_expr, const uword flags) { arma_extra_debug_sigprint(); @@ -419,8 +496,9 @@ glue_solve_tri::apply(Mat& actual_out, const Base& A_expr, const Base const bool likely_sympd = bool(flags & solve_opts::flag_likely_sympd); const bool refine = bool(flags & solve_opts::flag_refine ); const bool no_trimat = bool(flags & solve_opts::flag_no_trimat ); + const bool force_approx = bool(flags & solve_opts::flag_force_approx); - arma_extra_debug_print("glue_solve_tri::apply(): enabled flags:"); + arma_extra_debug_print("glue_solve_tri_full::apply(): enabled flags:"); if(fast ) { arma_extra_debug_print("fast"); } if(equilibrate ) { arma_extra_debug_print("equilibrate"); } @@ -431,23 +509,32 @@ glue_solve_tri::apply(Mat& actual_out, const Base& A_expr, const Base if(likely_sympd) { arma_extra_debug_print("likely_sympd"); } if(refine ) { arma_extra_debug_print("refine"); } if(no_trimat ) { arma_extra_debug_print("no_trimat"); } + if(force_approx) { arma_extra_debug_print("force_approx"); } - if(no_trimat || equilibrate || refine) + if(no_trimat || equilibrate || refine || force_approx) { const uword mask = ~(solve_opts::flag_triu | solve_opts::flag_tril); - return glue_solve_gen::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask)); + return glue_solve_gen_full::apply(actual_out, ((triu) ? trimatu(A_expr.get_ref()) : trimatl(A_expr.get_ref())), B_expr, (flags & mask)); } - if(likely_sympd) { arma_debug_warn("solve(): option 'likely_sympd' ignored for triangular matrix"); } + if(likely_sympd) { arma_debug_warn_level(2, "solve(): option 'likely_sympd' ignored for triangular matrix"); } - const quasi_unwrap U(A_expr.get_ref()); - const Mat& A = U.M; + const quasi_unwrap UA(A_expr.get_ref()); + const Mat& A = UA.M; arma_debug_check( (A.is_square() == false), "solve(): matrix marked as triangular must be square sized" ); - const uword layout = (triu) ? uword(0) : uword(1); - const bool is_alias = U.is_alias(actual_out); + const uword layout = (triu) ? uword(0) : uword(1); + + bool is_alias = true; + + if(is_Mat::value) + { + const quasi_unwrap UB(B_expr.get_ref()); + + is_alias = UA.is_alias(actual_out) || UB.is_alias(actual_out); + } T rcond = T(0); bool status = false; @@ -461,26 +548,27 @@ glue_solve_tri::apply(Mat& actual_out, const Base& A_expr, const Base } else { - status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout, allow_ugly); // A is not modified + status = auxlib::solve_trimat_rcond(out, rcond, A, B_expr.get_ref(), layout); // A is not modified } - if( (status == true) && (rcond > T(0)) && (rcond < auxlib::epsilon_lapack(A)) ) + + if( (status == true) && (fast == false) && (allow_ugly == false) && ((rcond < std::numeric_limits::epsilon()) || arma_isnan(rcond)) ) { - arma_debug_warn("solve(): solution computed, but system seems singular to working precision (rcond: ", rcond, ")"); + status = false; } if( (status == false) && (no_approx == false) ) { - arma_extra_debug_print("glue_solve_tri::apply(): solving rank deficient system"); + arma_extra_debug_print("glue_solve_tri_full::apply(): solving rank deficient system"); - if(rcond > T(0)) + if(rcond == T(0)) { - arma_debug_warn("solve(): system seems singular (rcond: ", rcond, "); attempting approx solution"); + arma_debug_warn_level(2, "solve(): system is singular; attempting approx solution"); } else { - arma_debug_warn("solve(): system seems singular; attempting approx solution"); + arma_debug_warn_level(2, "solve(): system is singular (rcond: ", rcond, "); attempting approx solution"); } Mat triA = (triu) ? trimatu(A) : trimatl(A); // trimatu() and trimatl() return the same type @@ -489,8 +577,6 @@ glue_solve_tri::apply(Mat& actual_out, const Base& A_expr, const Base } - if(status == false) { out.soft_reset(); } - if(is_alias) { actual_out.steal_mem(out); } return status; diff --git a/src/armadillo_bits/glue_times_bones.hpp b/src/armadillo_bits/glue_times_bones.hpp index 69b5b2c9..5792e4ec 100644 --- a/src/armadillo_bits/glue_times_bones.hpp +++ b/src/armadillo_bits/glue_times_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,18 +24,18 @@ //! \brief //! Template metaprogram depth_lhs //! calculates the number of Glue instances on the left hand side argument of Glue -//! i.e. it recursively expands each Tx, until the type of Tx is not "Glue<..,.., glue_type>" (i.e the "glue_type" changes) +//! ie. it recursively expands each Tx, until the type of Tx is not "Glue<..,.., glue_type>" (i.e the "glue_type" changes) template struct depth_lhs { - static const uword num = 0; + static constexpr uword num = 0; }; template struct depth_lhs< glue_type, Glue > { - static const uword num = 1 + depth_lhs::num; + static constexpr uword num = 1 + depth_lhs::num; }; @@ -113,9 +115,9 @@ class glue_times template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template @@ -152,9 +154,9 @@ class glue_times_diag template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/glue_times_meat.hpp b/src/armadillo_bits/glue_times_meat.hpp index af974b25..0dc8a02e 100644 --- a/src/armadillo_bits/glue_times_meat.hpp +++ b/src/armadillo_bits/glue_times_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,6 @@ template template -arma_hot inline void glue_times_redirect2_helper::apply(Mat& out, const Glue& X) @@ -72,7 +73,6 @@ glue_times_redirect2_helper::apply(Mat& o template -arma_hot inline void glue_times_redirect2_helper::apply(Mat& out, const Glue& X) @@ -81,7 +81,7 @@ glue_times_redirect2_helper::apply(Mat& out, const typedef typename T1::elem_type eT; - if(strip_inv::do_inv == true) + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) { // replace inv(A)*B with solve(A,B) @@ -93,24 +93,10 @@ glue_times_redirect2_helper::apply(Mat& out, const arma_debug_check( (A.is_square() == false), "inv(): given matrix must be square sized" ); - if(strip_inv::do_inv_sympd) + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) ) { - // if(auxlib::rudimentary_sym_check(A) == false) - // { - // if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - // if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - // - // out.soft_reset(); - // arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); - // - // return; - // } - - if( (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) ) - { - if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - } + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } } const unwrap_check B_tmp(X.B, out); @@ -118,11 +104,7 @@ glue_times_redirect2_helper::apply(Mat& out, const arma_debug_assert_mul_size(A, B, "matrix multiplication"); - #if defined(ARMA_OPTIMISE_SOLVE_SYMPD) - const bool status = (strip_inv::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B); - #else - const bool status = auxlib::solve_square_fast(out, A, B); - #endif + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, B) : auxlib::solve_square_fast(out, A, B); if(status == false) { @@ -133,56 +115,41 @@ glue_times_redirect2_helper::apply(Mat& out, const return; } - #if defined(ARMA_OPTIMISE_SOLVE_SYMPD) + if(arma_config::optimise_invexpr && strip_inv::do_inv_spd) { - if(strip_inv::do_inv_sympd) + // replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) ) + // transpose of B is avoided as B is explicitly marked as symmetric + + arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sympd(B)"); + + const Mat At = trans(X.A); + + const strip_inv B_strip(X.B); + + Mat B = B_strip.M; + + arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must be square sized" ); + + if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) { - // replace A*inv_sympd(B) with trans( solve(trans(B),trans(A)) ) - // transpose of B is avoided as B is explicitly marked as symmetric - - arma_extra_debug_print("glue_times_redirect<2>::apply(): detected A*inv_sympd(B)"); - - const Mat At = trans(X.A); - - const strip_inv B_strip(X.B); - - Mat B = B_strip.M; - - arma_debug_check( (B.is_square() == false), "inv_sympd(): given matrix must be square sized" ); - - // if(auxlib::rudimentary_sym_check(B) == false) - // { - // if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - // if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - // - // out.soft_reset(); - // arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); - // - // return; - // } - - if( (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) - { - if(is_cx::no ) { arma_debug_warn("inv_sympd(): given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn("inv_sympd(): given matrix is not hermitian"); } - } - - arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); - - const bool status = auxlib::solve_sympd_fast(out, B, At); - - if(status == false) - { - out.soft_reset(); - arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); - } - - out = trans(out); - - return; + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } } + + arma_debug_assert_mul_size(At.n_cols, At.n_rows, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool status = auxlib::solve_sympd_fast(out, B, At); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("matrix multiplication: problem with matrix inverse; suggest to use solve() instead"); + } + + out = trans(out); + + return; } - #endif glue_times_redirect2_helper::apply(out, X); } @@ -191,7 +158,6 @@ glue_times_redirect2_helper::apply(Mat& out, const template template -arma_hot inline void glue_times_redirect3_helper::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) @@ -249,7 +215,6 @@ glue_times_redirect3_helper::apply(Mat& o template -arma_hot inline void glue_times_redirect3_helper::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) @@ -258,7 +223,7 @@ glue_times_redirect3_helper::apply(Mat& out, const typedef typename T1::elem_type eT; - if(strip_inv::do_inv == true) + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) { // replace inv(A)*B*C with solve(A,B*C); @@ -292,11 +257,13 @@ glue_times_redirect3_helper::apply(Mat& out, const arma_debug_assert_mul_size(A, BC, "matrix multiplication"); - #if defined(ARMA_OPTIMISE_SOLVE_SYMPD) - const bool status = (strip_inv::do_inv_sympd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC); - #else - const bool status = auxlib::solve_square_fast(out, A, BC); - #endif + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(A) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(out, A, BC) : auxlib::solve_square_fast(out, A, BC); if(status == false) { @@ -308,7 +275,7 @@ glue_times_redirect3_helper::apply(Mat& out, const } - if(strip_inv::do_inv == true) + if(arma_config::optimise_invexpr && (strip_inv::do_inv_gen || strip_inv::do_inv_spd)) { // replace A*inv(B)*C with A*solve(B,C) @@ -325,13 +292,15 @@ glue_times_redirect3_helper::apply(Mat& out, const arma_debug_assert_mul_size(B, C, "matrix multiplication"); + if( (strip_inv::do_inv_spd) && (arma_config::debug) && (auxlib::rudimentary_sym_check(B) == false) ) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + Mat solve_result; - #if defined(ARMA_OPTIMISE_SOLVE_SYMPD) - const bool status = (strip_inv::do_inv_sympd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C); - #else - const bool status = auxlib::solve_square_fast(solve_result, B, C); - #endif + const bool status = (strip_inv::do_inv_spd) ? auxlib::solve_sympd_fast(solve_result, B, C) : auxlib::solve_square_fast(solve_result, B, C); if(status == false) { @@ -367,7 +336,6 @@ glue_times_redirect3_helper::apply(Mat& out, const template template -arma_hot inline void glue_times_redirect::apply(Mat& out, const Glue& X) @@ -418,7 +386,6 @@ glue_times_redirect::apply(Mat& out, const Glue -arma_hot inline void glue_times_redirect<2>::apply(Mat& out, const Glue& X) @@ -433,7 +400,6 @@ glue_times_redirect<2>::apply(Mat& out, const Glue -arma_hot inline void glue_times_redirect<3>::apply(Mat& out, const Glue< Glue, T3, glue_times>& X) @@ -448,7 +414,6 @@ glue_times_redirect<3>::apply(Mat& out, const Glue< Glue template -arma_hot inline void glue_times_redirect<4>::apply(Mat& out, const Glue< Glue< Glue, T3, glue_times>, T4, glue_times>& X) @@ -510,16 +475,15 @@ glue_times_redirect<4>::apply(Mat& out, const Glue< Glue template -arma_hot inline void glue_times::apply(Mat& out, const Glue& X) { arma_extra_debug_sigprint(); - const sword N_mat = 1 + depth_lhs< glue_times, Glue >::num; + constexpr uword N_mat = 1 + depth_lhs< glue_times, Glue >::num; - arma_extra_debug_print(arma_str::format("N_mat = %d") % N_mat); + arma_extra_debug_print(arma_str::format("N_mat = %u") % N_mat); glue_times_redirect::apply(out, X); } @@ -527,7 +491,6 @@ glue_times::apply(Mat& out, const Glue template -arma_hot inline void glue_times::apply_inplace(Mat& out, const T1& X) @@ -540,7 +503,6 @@ glue_times::apply_inplace(Mat& out, const T1& X) template -arma_hot inline void glue_times::apply_inplace_plus(Mat& out, const Glue& X, const sword sign) @@ -550,7 +512,7 @@ glue_times::apply_inplace_plus(Mat& out, const Glue::result T; - if( (is_outer_product::value) || (has_op_inv::value) || (has_op_inv::value) || (has_op_inv_sympd::value) || (has_op_inv_sympd::value) ) + if( (is_outer_product::value) || (has_op_inv_any::value) || (has_op_inv_any::value) ) { // partial workaround for corner cases @@ -584,11 +546,7 @@ glue_times::apply_inplace_plus(Mat& out, const Glue sword(0)) ? "addition" : "subtraction" ) ); - if(out.n_elem == 0) - { - return; - } - + if(out.n_elem == 0) { return; } if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) { @@ -677,7 +635,6 @@ template typename TA, typename TB > -arma_hot inline void glue_times::apply @@ -698,12 +655,7 @@ glue_times::apply out.set_size(final_n_rows, final_n_cols); - if( (A.n_elem == 0) || (B.n_elem == 0) ) - { - out.zeros(); - return; - } - + if( (A.n_elem == 0) || (B.n_elem == 0) ) { out.zeros(); return; } if( (do_trans_A == false) && (do_trans_B == false) && (use_alpha == false) ) { @@ -781,7 +733,6 @@ template typename TB, typename TC > -arma_hot inline void glue_times::apply @@ -831,7 +782,6 @@ template typename TC, typename TD > -arma_hot inline void glue_times::apply @@ -876,10 +826,9 @@ glue_times::apply template -arma_hot inline void -glue_times_diag::apply(Mat& out, const Glue& X) +glue_times_diag::apply(Mat& actual_out, const Glue& X) { arma_extra_debug_sigprint(); @@ -895,10 +844,10 @@ glue_times_diag::apply(Mat& out, const Glue A(S1.M, out); + const diagmat_proxy A(S1.M); - const unwrap_check tmp(X.B, out); - const Mat& B = tmp.M; + const quasi_unwrap UB(X.B); + const Mat& B = UB.M; const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; @@ -909,6 +858,13 @@ glue_times_diag::apply(Mat& out, const Glue tmp; + Mat& out = (is_alias) ? tmp : actual_out; + out.zeros(A_n_rows, B_n_cols); for(uword col=0; col < B_n_cols; ++col) @@ -916,21 +872,20 @@ glue_times_diag::apply(Mat& out, const Glue::do_diagmat == false) && (strip_diagmat::do_diagmat == true) ) { arma_extra_debug_print("glue_times_diag::apply(): A * diagmat(B)"); - const unwrap_check tmp(X.A, out); - const Mat& A = tmp.M; + const quasi_unwrap UA(X.A); + const Mat& A = UA.M; - const diagmat_proxy_check B(S2.M, out); + const diagmat_proxy B(S2.M); const uword A_n_rows = A.n_rows; const uword A_n_cols = A.n_cols; @@ -941,6 +896,13 @@ glue_times_diag::apply(Mat& out, const Glue tmp; + Mat& out = (is_alias) ? tmp : actual_out; + out.zeros(A_n_rows, B_n_cols); for(uword col=0; col < B_length; ++col) @@ -950,22 +912,28 @@ glue_times_diag::apply(Mat& out, const Glue::do_diagmat == true) && (strip_diagmat::do_diagmat == true) ) { arma_extra_debug_print("glue_times_diag::apply(): diagmat(A) * diagmat(B)"); - const diagmat_proxy_check A(S1.M, out); - const diagmat_proxy_check B(S2.M, out); + const diagmat_proxy A(S1.M); + const diagmat_proxy B(S2.M); arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + const bool is_alias = (A.is_alias(actual_out) || B.is_alias(actual_out)); + + if(is_alias) { arma_extra_debug_print("glue_times_diag::apply(): aliasing detected"); } + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + out.zeros(A.n_rows, B.n_cols); const uword A_length = (std::min)(A.n_rows, A.n_cols); @@ -973,10 +941,9 @@ glue_times_diag::apply(Mat& out, const Glue + arma_inline static typename arma_not_cx::result dot(const eT* A_mem, const SpMat& B, const uword col); + + template + arma_inline static typename arma_cx_only::result dot(const eT* A_mem, const SpMat& B, const uword col); + }; + + + +class glue_times_dense_sparse + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const SpToDGlue& expr); + + template + inline static void apply_noalias(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); + }; + + + +class glue_times_sparse_dense + { + public: + + template + struct traits + { + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; + }; + + template + inline static void apply(Mat& out, const SpToDGlue& expr); + + template + inline static void apply_noalias(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_noalias_trans(Mat& out, const T1& x, const T2& y); + + template + inline static void apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); + }; + + + +//! @} diff --git a/src/armadillo_bits/glue_times_misc_meat.hpp b/src/armadillo_bits/glue_times_misc_meat.hpp new file mode 100644 index 00000000..cafbb98a --- /dev/null +++ b/src/armadillo_bits/glue_times_misc_meat.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup glue_times_misc +//! @{ + + + +template +arma_inline +typename arma_not_cx::result +dense_sparse_helper::dot(const eT* A_mem, const SpMat& B, const uword col) + { + arma_extra_debug_sigprint(); + + uword col_offset = B.col_ptrs[col ]; + const uword next_col_offset = B.col_ptrs[col + 1]; + + const uword* start_ptr = &(B.row_indices[ col_offset]); + const uword* end_ptr = &(B.row_indices[next_col_offset]); + + const eT* B_values = B.values; + + eT acc = eT(0); + + for(const uword* ptr = start_ptr; ptr != end_ptr; ++ptr) + { + const uword index = (*ptr); + + acc += A_mem[index] * B_values[col_offset]; + + ++col_offset; + } + + return acc; + } + + + +template +arma_inline +typename arma_cx_only::result +dense_sparse_helper::dot(const eT* A_mem, const SpMat& B, const uword col) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + uword col_offset = B.col_ptrs[col ]; + const uword next_col_offset = B.col_ptrs[col + 1]; + + const uword* start_ptr = &(B.row_indices[ col_offset]); + const uword* end_ptr = &(B.row_indices[next_col_offset]); + + const eT* B_values = B.values; + + T acc_real = T(0); + T acc_imag = T(0); + + for(const uword* ptr = start_ptr; ptr != end_ptr; ++ptr) + { + const uword index = (*ptr); + + const std::complex& X = A_mem[index]; + const std::complex& Y = B_values[col_offset]; + + const T a = X.real(); + const T b = X.imag(); + + const T c = Y.real(); + const T d = Y.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + + ++col_offset; + } + + return std::complex(acc_real, acc_imag); + } + + + +template +inline +void +glue_times_dense_sparse::apply(Mat& out, const SpToDGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_op_diagmat::value) { out = SpMat(expr.A) * expr.B; return; } // SpMat has specialised handling for op_diagmat + + const quasi_unwrap UA(expr.A); + + if(UA.is_alias(out)) + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, UA.M, expr.B); + + out.steal_mem(tmp); + } + else + { + glue_times_dense_sparse::apply_noalias(out, UA.M, expr.B); + } + } + + + +template +inline +void +glue_times_dense_sparse::apply_noalias(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const quasi_unwrap UA(x); + const Mat& A = UA.M; + + const unwrap_spmat UB(y); + const SpMat& B = UB.M; + + arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + out.set_size(A.n_rows, B.n_cols); + + if((A.n_elem == 0) || (B.n_nonzero == 0)) { out.zeros(); return; } + + if((resolves_to_rowvector::value) || (A.n_rows == 1)) + { + arma_extra_debug_print("using row vector specialisation"); + + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (B.n_cols >= 2) && mp_gate::eval(B.n_nonzero) ) + { + #if defined(ARMA_USE_OPENMP) + { + arma_extra_debug_print("openmp implementation"); + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + + const uword B_n_cols = B.n_cols; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < B_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(A_mem, B, col); + } + } + #endif + } + else + { + arma_extra_debug_print("serial implementation"); + + eT* out_mem = out.memptr(); + const eT* A_mem = A.memptr(); + + const uword B_n_cols = B.n_cols; + + for(uword col=0; col < B_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(A_mem, B, col); + } + } + } + else + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (A.n_rows <= (A.n_cols / uword(100))) ) + { + #if defined(ARMA_USE_OPENMP) + { + arma_extra_debug_print("using parallelised multiplication"); + + const uword B_n_cols = B.n_cols; + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword i=0; i < B_n_cols; ++i) + { + const uword col_offset_1 = B.col_ptrs[i ]; + const uword col_offset_2 = B.col_ptrs[i+1]; + + const uword col_offset_delta = col_offset_2 - col_offset_1; + + const uvec indices(const_cast(&(B.row_indices[col_offset_1])), col_offset_delta, false, false); + const Col B_col(const_cast< eT*>(&( B.values[col_offset_1])), col_offset_delta, false, false); + + out.col(i) = A.cols(indices) * B_col; + } + } + #endif + } + else + { + arma_extra_debug_print("using standard multiplication"); + + out.zeros(); + + typename SpMat::const_iterator B_it = B.begin(); + + const uword nnz = B.n_nonzero; + const uword out_n_rows = out.n_rows; + + for(uword count = 0; count < nnz; ++count, ++B_it) + { + const eT B_it_val = (*B_it); + const uword B_it_col = B_it.col(); + const uword B_it_row = B_it.row(); + + const eT* A_col = A.colptr(B_it_row); + eT* out_col = out.colptr(B_it_col); + + for(uword row = 0; row < out_n_rows; ++row) + { + out_col[row] += A_col[row] * B_it_val; + } + } + } + } + + + +template +inline +void +glue_times_dense_sparse::apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat AA = conv_to< Mat >::from(A); + + const SpMat& BB = reinterpret_cast< const SpMat& >(B); + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat& AA = reinterpret_cast< const Mat& >(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + else + { + // upgrade T1 and T2 + + const quasi_unwrap UA(X); + const unwrap_spmat UB(Y); + + const Mat& A = UA.M; + const SpMat& B = UB.M; + + const Mat AA = conv_to< Mat >::from(A); + + SpMat BB(arma_layout_indicator(), B); + + for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } + + glue_times_dense_sparse::apply_noalias(out, AA, BB); + } + } + + + +// + + + +template +inline +void +glue_times_sparse_dense::apply(Mat& out, const SpToDGlue& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_op_diagmat::value) { out = expr.A * SpMat(expr.B); return; } // SpMat has specialised handling for op_diagmat + + const quasi_unwrap UB(expr.B); + + if((sp_strip_trans::do_htrans && is_cx::no) || (sp_strip_trans::do_strans)) + { + arma_extra_debug_print("detected non-conjugate transpose of A"); + + const sp_strip_trans x_strip(expr.A); + + if(UB.is_alias(out)) + { + Mat tmp; + + glue_times_sparse_dense::apply_noalias_trans(tmp, x_strip.M, UB.M); + + out.steal_mem(tmp); + } + else + { + glue_times_sparse_dense::apply_noalias_trans(out, x_strip.M, UB.M); + } + } + else + { + if(UB.is_alias(out)) + { + Mat tmp; + + glue_times_sparse_dense::apply_noalias(tmp, expr.A, UB.M); + + out.steal_mem(tmp); + } + else + { + glue_times_sparse_dense::apply_noalias(out, expr.A, UB.M); + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_noalias(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(x); + const SpMat& A = UA.M; + + const quasi_unwrap UB(y); + const Mat& B = UB.M; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); + + if((resolves_to_colvector::value) || (B_n_cols == 1)) + { + arma_extra_debug_print("using column vector specialisation"); + + out.zeros(A_n_rows, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + out_mem[A_it_row] += A_it_val * B_mem[A_it_col]; + } + } + else + if(B_n_cols >= (B_n_rows / uword(100))) + { + arma_extra_debug_print("using transpose-based multiplication"); + + const SpMat At = A.st(); + const Mat Bt = B.st(); + + if(A_n_rows == B_n_cols) + { + glue_times_dense_sparse::apply_noalias(out, Bt, At); + + op_strans::apply_mat(out, out); // since 'out' is square-sized, this will do an inplace transpose + } + else + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, Bt, At); + + op_strans::apply_mat(out, tmp); + } + } + else + { + arma_extra_debug_print("using standard multiplication"); + + out.zeros(A_n_rows, B_n_cols); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + for(uword col = 0; col < B_n_cols; ++col) + { + out.at(A_it_row, col) += A_it_val * B.at(A_it_col, col); + } + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_noalias_trans(Mat& out, const T1& x, const T2& y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat UA(x); + const SpMat& A = UA.M; // NOTE: this is the given matrix without the transpose operation applied + + const quasi_unwrap UB(y); + const Mat& B = UB.M; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + arma_debug_assert_mul_size(A_n_cols, A_n_rows, B_n_rows, B_n_cols, "matrix multiplication"); + + if((resolves_to_colvector::value) || (B_n_cols == 1)) + { + arma_extra_debug_print("using column vector specialisation (avoiding transpose of A)"); + + if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (A_n_cols >= 2) && mp_gate::eval(A.n_nonzero) ) + { + arma_extra_debug_print("opemp implementation"); + + #if defined(ARMA_USE_OPENMP) + { + out.zeros(A_n_cols, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + const int n_threads = mp_thread_limit::get(); + + #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=0; col < A_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(B_mem, A, col); + } + } + #endif + } + else + { + arma_extra_debug_print("serial implementation"); + + out.zeros(A_n_cols, 1); + + eT* out_mem = out.memptr(); + const eT* B_mem = B.memptr(); + + for(uword col=0; col < A_n_cols; ++col) + { + out_mem[col] = dense_sparse_helper::dot(B_mem, A, col); + } + } + } + else + if(B_n_cols >= (B_n_rows / uword(100))) + { + arma_extra_debug_print("using transpose-based multiplication (avoiding transpose of A)"); + + const Mat Bt = B.st(); + + if(A_n_cols == B_n_cols) + { + glue_times_dense_sparse::apply_noalias(out, Bt, A); + + op_strans::apply_mat(out, out); // since 'out' is square-sized, this will do an inplace transpose + } + else + { + Mat tmp; + + glue_times_dense_sparse::apply_noalias(tmp, Bt, A); + + op_strans::apply_mat(out, tmp); + } + } + else + { + arma_extra_debug_print("using standard multiplication (avoiding transpose of A)"); + + out.zeros(A_n_cols, B_n_cols); + + typename SpMat::const_iterator A_it = A.begin(); + + const uword nnz = A.n_nonzero; + + for(uword count = 0; count < nnz; ++count, ++A_it) + { + const eT A_it_val = (*A_it); + const uword A_it_row = A_it.row(); + const uword A_it_col = A_it.col(); + + for(uword col = 0; col < B_n_cols; ++col) + { + out.at(A_it_col, col) += A_it_val * B.at(A_it_row, col); + } + } + } + } + + + +template +inline +void +glue_times_sparse_dense::apply_mixed(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT1; + typedef typename T2::elem_type eT2; + + typedef typename promote_type::result out_eT; + + promote_type::check(); + + if( (is_same_type::no) && (is_same_type::yes) ) + { + // upgrade T1 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const Mat& BB = reinterpret_cast< const Mat& >(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + else + if( (is_same_type::yes) && (is_same_type::no) ) + { + // upgrade T2 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + const SpMat& AA = reinterpret_cast< const SpMat& >(A); + + const Mat BB = conv_to< Mat >::from(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + else + { + // upgrade T1 and T2 + + const unwrap_spmat UA(X); + const quasi_unwrap UB(Y); + + const SpMat& A = UA.M; + const Mat& B = UB.M; + + SpMat AA(arma_layout_indicator(), A); + + for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } + + const Mat BB = conv_to< Mat >::from(B); + + glue_times_sparse_dense::apply_noalias(out, AA, BB); + } + } + + + +//! @} diff --git a/src/armadillo_bits/glue_toeplitz_bones.hpp b/src/armadillo_bits/glue_toeplitz_bones.hpp index af8d23d3..338de148 100644 --- a/src/armadillo_bits/glue_toeplitz_bones.hpp +++ b/src/armadillo_bits/glue_toeplitz_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/glue_toeplitz_meat.hpp b/src/armadillo_bits/glue_toeplitz_meat.hpp index da57337f..77f9a090 100644 --- a/src/armadillo_bits/glue_toeplitz_meat.hpp +++ b/src/armadillo_bits/glue_toeplitz_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -38,7 +40,7 @@ glue_toeplitz::apply(Mat& out, const Glue struct traits { - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = true; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; }; template inline static void apply(Mat& out, const Glue& in); diff --git a/src/armadillo_bits/glue_trapz_meat.hpp b/src/armadillo_bits/glue_trapz_meat.hpp index 04276719..ed7b577d 100644 --- a/src/armadillo_bits/glue_trapz_meat.hpp +++ b/src/armadillo_bits/glue_trapz_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/gmm_diag_bones.hpp b/src/armadillo_bits/gmm_diag_bones.hpp index 1a294608..386a40d8 100644 --- a/src/armadillo_bits/gmm_diag_bones.hpp +++ b/src/armadillo_bits/gmm_diag_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -62,11 +64,11 @@ class gmm_diag inline Col generate() const; inline Mat generate(const uword N) const; - template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = 0) const; - template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = 0) const; + template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; - template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = 0) const; - template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = 0) const; + template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; template inline eT sum_log_p(const Base& expr) const; template inline eT sum_log_p(const Base& expr, const uword gaus_id) const; @@ -74,8 +76,8 @@ class gmm_diag template inline eT avg_log_p(const Base& expr) const; template inline eT avg_log_p(const Base& expr, const uword gaus_id) const; - template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = 0) const; - template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = 0) const; + template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = nullptr) const; + template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = nullptr) const; template inline urowvec raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; template inline Row norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; diff --git a/src/armadillo_bits/gmm_diag_meat.hpp b/src/armadillo_bits/gmm_diag_meat.hpp index ea623d82..1b6681ed 100644 --- a/src/armadillo_bits/gmm_diag_meat.hpp +++ b/src/armadillo_bits/gmm_diag_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -149,9 +151,9 @@ gmm_diag::set_params(const Base& in_means_expr, const Base& in "gmm_diag::set_params(): given parameters have inconsistent and/or wrong sizes" ); - arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_params(): given means have non-finite values" ); - arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_params(): given dcovs have non-finite values" ); - arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_params(): given hefts have non-finite values" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_params(): given means have non-finite values" ); + arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_params(): given dcovs have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_params(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_params(): given dcovs have negative or zero values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_params(): given hefts have negative values" ); @@ -182,7 +184,7 @@ gmm_diag::set_means(const Base& in_means_expr) const Mat& in_means = tmp.M; arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_diag::set_means(): given means have incompatible size" ); - arma_debug_check( (in_means.is_finite() == false), "gmm_diag::set_means(): given means have non-finite values" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_diag::set_means(): given means have non-finite values" ); access::rw(means) = in_means; } @@ -202,7 +204,7 @@ gmm_diag::set_dcovs(const Base& in_dcovs_expr) const Mat& in_dcovs = tmp.M; arma_debug_check( (arma::size(in_dcovs) != arma::size(dcovs)), "gmm_diag::set_dcovs(): given dcovs have incompatible size" ); - arma_debug_check( (in_dcovs.is_finite() == false), "gmm_diag::set_dcovs(): given dcovs have non-finite values" ); + arma_debug_check( (in_dcovs.internal_has_nonfinite()), "gmm_diag::set_dcovs(): given dcovs have non-finite values" ); arma_debug_check( (any(vectorise(in_dcovs) <= eT(0))), "gmm_diag::set_dcovs(): given dcovs have negative or zero values" ); access::rw(dcovs) = in_dcovs; @@ -225,7 +227,7 @@ gmm_diag::set_hefts(const Base& in_hefts_expr) const Mat& in_hefts = tmp.M; arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_diag::set_hefts(): given hefts have incompatible size" ); - arma_debug_check( (in_hefts.is_finite() == false), "gmm_diag::set_hefts(): given hefts have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_diag::set_hefts(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_diag::set_hefts(): given hefts have negative values" ); const eT s = accu(in_hefts); @@ -283,7 +285,7 @@ gmm_diag::load(const std::string name) if( (status == false) || (Q.n_slices != 2) ) { reset(); - arma_debug_warn("gmm_diag::load(): problem with loading or incompatible format"); + arma_debug_warn_level(3, "gmm_diag::load(): problem with loading or incompatible format"); return false; } @@ -311,7 +313,7 @@ gmm_diag::save(const std::string name) const { arma_extra_debug_sigprint(); - Cube Q(means.n_rows + 1, means.n_cols, 2); + Cube Q(means.n_rows + 1, means.n_cols, 2, arma_nozeros_indicator()); if(Q.n_elem > 0) { @@ -643,7 +645,7 @@ gmm_diag::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) if(acc == eT(0)) { acc = eT(1); } - Row out(hist_n_elem); + Row out(hist_n_elem, arma_nozeros_indicator()); eT* out_mem = out.memptr(); @@ -688,8 +690,8 @@ gmm_diag::learn const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; - if(X.is_empty() ) { arma_debug_warn("gmm_diag::learn(): given matrix is empty" ); return false; } - if(X.is_finite() == false) { arma_debug_warn("gmm_diag::learn(): given matrix has non-finite values"); return false; } + if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "gmm_diag::learn(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } @@ -718,14 +720,14 @@ gmm_diag::learn if(seed_mode == keep_existing) { - if(means.is_empty() ) { arma_debug_warn("gmm_diag::learn(): no existing means" ); return false; } - if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_diag::learn(): dimensionality mismatch"); return false; } + if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_diag::learn(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_diag::learn(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { - if(X.n_cols < N_gaus) { arma_debug_warn("gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_diag::learn(): number of vectors is less than number of gaussians"); return false; } reset(X.n_rows, N_gaus); @@ -749,7 +751,7 @@ gmm_diag::learn stream_state.restore(get_cout_stream()); - if(status == false) { arma_debug_warn("gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } + if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } } @@ -776,7 +778,7 @@ gmm_diag::learn stream_state.restore(get_cout_stream()); - if(status == false) { arma_debug_warn("gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } + if(status == false) { arma_debug_warn_level(3, "gmm_diag::learn(): EM algorithm failed"); init(orig); return false; } } mah_aux.reset(); @@ -816,8 +818,8 @@ gmm_diag::kmeans_wrapper const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; - if(X.is_empty() ) { arma_debug_warn("kmeans(): given matrix is empty" ); return false; } - if(X.is_finite() == false) { arma_debug_warn("kmeans(): given matrix has non-finite values"); return false; } + if(X.is_empty() ) { arma_debug_warn_level(3, "kmeans(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "kmeans(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } @@ -828,14 +830,14 @@ gmm_diag::kmeans_wrapper { access::rw(means) = user_means; - if(means.is_empty() ) { arma_debug_warn("kmeans(): no existing means" ); return false; } - if(X.n_rows != means.n_rows) { arma_debug_warn("kmeans(): dimensionality mismatch"); return false; } + if(means.is_empty() ) { arma_debug_warn_level(3, "kmeans(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "kmeans(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { - if(X.n_cols < N_gaus) { arma_debug_warn("kmeans(): number of vectors is less than number of means"); return false; } + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "kmeans(): number of vectors is less than number of means"); return false; } access::rw(means).zeros(X.n_rows, N_gaus); @@ -857,7 +859,7 @@ gmm_diag::kmeans_wrapper stream_state.restore(get_cout_stream()); - if(status == false) { arma_debug_warn("kmeans(): clustering failed; not enough data, or too many means requested"); return false; } + if(status == false) { arma_debug_warn_level(3, "kmeans(): clustering failed; not enough data, or too many means requested"); return false; } } return true; @@ -970,7 +972,7 @@ gmm_diag::init_constants() // - const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum::pi); + const eT tmp = (eT(N_dims)/eT(2)) * std::log(Datum::tau); log_det_etc.set_size(N_gaus); @@ -1013,12 +1015,12 @@ gmm_diag::internal_gen_boundaries(const uword N) const const uword n_threads_avail = (omp_in_parallel()) ? uword(1) : uword(omp_get_max_threads()); const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1; #else - static const uword n_threads = 1; + static constexpr uword n_threads = 1; #endif // get_cout_stream() << "gmm_diag::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; - umat boundaries(2, n_threads); + umat boundaries(2, n_threads, arma_nozeros_indicator()); if(N > 0) { @@ -1050,7 +1052,6 @@ gmm_diag::internal_gen_boundaries(const uword N) const template -arma_hot inline eT gmm_diag::internal_scalar_log_p(const eT* x) const @@ -1083,7 +1084,6 @@ gmm_diag::internal_scalar_log_p(const eT* x) const template -arma_hot inline eT gmm_diag::internal_scalar_log_p(const eT* x, const uword g) const @@ -1135,7 +1135,7 @@ gmm_diag::internal_vec_log_p(const Mat& X) const const uword N = X.n_cols; - Row out(N); + Row out(N, arma_nozeros_indicator()); if(N > 0) { @@ -1188,7 +1188,7 @@ gmm_diag::internal_vec_log_p(const Mat& X, const uword gaus_id) const const uword N = X.n_cols; - Row out(N); + Row out(N, arma_nozeros_indicator()); if(N > 0) { @@ -1249,7 +1249,7 @@ gmm_diag::internal_sum_log_p(const Mat& X) const const uword n_threads = boundaries.n_cols; - Col t_accs(n_threads, fill::zeros); + Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) @@ -1306,7 +1306,7 @@ gmm_diag::internal_sum_log_p(const Mat& X, const uword gaus_id) const const uword n_threads = boundaries.n_cols; - Col t_accs(n_threads, fill::zeros); + Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) @@ -1911,10 +1911,10 @@ gmm_diag::generate_initial_params(const Mat& X, const eT var_floor) // as the covariances are calculated via accumulators, // the means also need to be calculated via accumulators to ensure numerical consistency - Mat acc_means(N_dims, N_gaus, fill::zeros); - Mat acc_dcovs(N_dims, N_gaus, fill::zeros); + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Mat acc_dcovs(N_dims, N_gaus, arma_zeros_indicator()); - Row acc_hefts(N_gaus, fill::zeros); + Row acc_hefts(N_gaus, arma_zeros_indicator()); uword* acc_hefts_mem = acc_hefts.memptr(); @@ -2072,9 +2072,9 @@ gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verb const eT* mah_aux_mem = mah_aux.memptr(); - Mat acc_means(N_dims, N_gaus, fill::zeros); - Row acc_hefts(N_gaus, fill::zeros); - Row last_indx(N_gaus, fill::zeros); + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Row acc_hefts( N_gaus, arma_zeros_indicator()); + Row last_indx( N_gaus, arma_zeros_indicator()); Mat new_means = means; Mat old_means = means; @@ -2157,6 +2157,10 @@ gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verb } #else { + acc_hefts.zeros(); + acc_means.zeros(); + last_indx.zeros(); + uword* acc_hefts_mem = acc_hefts.memptr(); uword* last_indx_mem = last_indx.memptr(); @@ -2273,7 +2277,7 @@ gmm_diag::km_iterate(const Mat& X, const uword max_iter, const bool verb access::rw(means) = old_means; - if(means.is_finite() == false) { return false; } + if(means.internal_has_nonfinite()) { return false; } return true; } @@ -2314,7 +2318,7 @@ gmm_diag::em_iterate(const Mat& X, const uword max_iter, const eT var_fl field< Col > t_acc_norm_lhoods(n_threads); field< Col > t_gaus_log_lhoods(n_threads); - Col t_progress_log_lhood(n_threads); + Col t_progress_log_lhood(n_threads, arma_nozeros_indicator()); for(uword t=0; t::em_iterate(const Mat& X, const uword max_iter, const eT var_fl if(any(vectorise(dcovs) <= eT(0))) { return false; } - if(means.is_finite() == false ) { return false; } - if(dcovs.is_finite() == false ) { return false; } - if(hefts.is_finite() == false ) { return false; } + if(means.internal_has_nonfinite()) { return false; } + if(dcovs.internal_has_nonfinite()) { return false; } + if(hefts.internal_has_nonfinite()) { return false; } return true; } diff --git a/src/armadillo_bits/gmm_full_bones.hpp b/src/armadillo_bits/gmm_full_bones.hpp index cfdc1c60..a842a62f 100644 --- a/src/armadillo_bits/gmm_full_bones.hpp +++ b/src/armadillo_bits/gmm_full_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -62,11 +64,11 @@ class gmm_full inline Col generate() const; inline Mat generate(const uword N) const; - template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = 0) const; - template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = 0) const; + template inline eT log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; + template inline eT log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk2 = nullptr) const; - template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = 0) const; - template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = 0) const; + template inline Row log_p(const T1& expr, const gmm_empty_arg& junk1 = gmm_empty_arg(), typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; + template inline Row log_p(const T1& expr, const uword gaus_id, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk2 = nullptr) const; template inline eT sum_log_p(const Base& expr) const; template inline eT sum_log_p(const Base& expr, const uword gaus_id) const; @@ -74,8 +76,8 @@ class gmm_full template inline eT avg_log_p(const Base& expr) const; template inline eT avg_log_p(const Base& expr, const uword gaus_id) const; - template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = 0) const; - template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = 0) const; + template inline uword assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == true ))>::result* junk = nullptr) const; + template inline urowvec assign(const T1& expr, const gmm_dist_mode& dist, typename enable_if<((is_arma_type::value) && (resolves_to_colvector::value == false))>::result* junk = nullptr) const; template inline urowvec raw_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; template inline Row norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) const; diff --git a/src/armadillo_bits/gmm_full_meat.hpp b/src/armadillo_bits/gmm_full_meat.hpp index d003b5df..5bbcce06 100644 --- a/src/armadillo_bits/gmm_full_meat.hpp +++ b/src/armadillo_bits/gmm_full_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -149,9 +151,9 @@ gmm_full::set_params(const Base& in_means_expr, const BaseCube "gmm_full::set_params(): given parameters have inconsistent and/or wrong sizes" ); - arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_params(): given means have non-finite values" ); - arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_params(): given fcovs have non-finite values" ); - arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_params(): given hefts have non-finite values" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_full::set_params(): given means have non-finite values" ); + arma_debug_check( (in_fcovs.internal_has_nonfinite()), "gmm_full::set_params(): given fcovs have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_full::set_params(): given hefts have non-finite values" ); for(uword g=0; g < in_fcovs.n_slices; ++g) { @@ -186,7 +188,7 @@ gmm_full::set_means(const Base& in_means_expr) const Mat& in_means = tmp.M; arma_debug_check( (arma::size(in_means) != arma::size(means)), "gmm_full::set_means(): given means have incompatible size" ); - arma_debug_check( (in_means.is_finite() == false), "gmm_full::set_means(): given means have non-finite values" ); + arma_debug_check( (in_means.internal_has_nonfinite()), "gmm_full::set_means(): given means have non-finite values" ); access::rw(means) = in_means; } @@ -206,7 +208,7 @@ gmm_full::set_fcovs(const BaseCube& in_fcovs_expr) const Cube& in_fcovs = tmp.M; arma_debug_check( (arma::size(in_fcovs) != arma::size(fcovs)), "gmm_full::set_fcovs(): given fcovs have incompatible size" ); - arma_debug_check( (in_fcovs.is_finite() == false), "gmm_full::set_fcovs(): given fcovs have non-finite values" ); + arma_debug_check( (in_fcovs.internal_has_nonfinite()), "gmm_full::set_fcovs(): given fcovs have non-finite values" ); for(uword i=0; i < in_fcovs.n_slices; ++i) { @@ -233,7 +235,7 @@ gmm_full::set_hefts(const Base& in_hefts_expr) const Mat& in_hefts = tmp.M; arma_debug_check( (arma::size(in_hefts) != arma::size(hefts)), "gmm_full::set_hefts(): given hefts have incompatible size" ); - arma_debug_check( (in_hefts.is_finite() == false), "gmm_full::set_hefts(): given hefts have non-finite values" ); + arma_debug_check( (in_hefts.internal_has_nonfinite()), "gmm_full::set_hefts(): given hefts have non-finite values" ); arma_debug_check( (any(vectorise(in_hefts) < eT(0))), "gmm_full::set_hefts(): given hefts have negative values" ); const eT s = accu(in_hefts); @@ -291,7 +293,7 @@ gmm_full::load(const std::string name) if( (status == false) || (storage.n_elem < 2) ) { reset(); - arma_debug_warn("gmm_full::load(): problem with loading or incompatible format"); + arma_debug_warn_level(3, "gmm_full::load(): problem with loading or incompatible format"); return false; } @@ -306,7 +308,7 @@ gmm_full::load(const std::string name) if( (storage.n_elem != (N_gaus + 2)) || (storage_hefts.n_rows != 1) || (storage_hefts.n_cols != N_gaus) ) { reset(); - arma_debug_warn("gmm_full::load(): incompatible format"); + arma_debug_warn_level(3, "gmm_full::load(): incompatible format"); return false; } @@ -322,7 +324,7 @@ gmm_full::load(const std::string name) if( (storage_fcov.n_rows != N_dims) || (storage_fcov.n_cols != N_dims) ) { reset(); - arma_debug_warn("gmm_full::load(): incompatible format"); + arma_debug_warn_level(3, "gmm_full::load(): incompatible format"); return false; } @@ -374,8 +376,8 @@ gmm_full::generate() const const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; - Col out( (N_gaus > 0) ? N_dims : uword(0) ); - Col tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn ); + Col out( (N_gaus > 0) ? N_dims : uword(0), arma_nozeros_indicator() ); + Col tmp( (N_gaus > 0) ? N_dims : uword(0), fill::randn ); if(N_gaus > 0) { @@ -410,8 +412,8 @@ gmm_full::generate(const uword N_vec) const const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; - Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec ); - Mat tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn ); + Mat out( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, arma_nozeros_indicator() ); + Mat tmp( ( (N_gaus > 0) ? N_dims : uword(0) ), N_vec, fill::randn ); if(N_gaus > 0) { @@ -682,7 +684,7 @@ gmm_full::norm_hist(const Base& expr, const gmm_dist_mode& dist_mode) if(acc == eT(0)) { acc = eT(1); } - Row out(hist_n_elem); + Row out(hist_n_elem, arma_nozeros_indicator()); eT* out_mem = out.memptr(); @@ -727,8 +729,8 @@ gmm_full::learn const unwrap tmp_X(data.get_ref()); const Mat& X = tmp_X.M; - if(X.is_empty() ) { arma_debug_warn("gmm_full::learn(): given matrix is empty" ); return false; } - if(X.is_finite() == false) { arma_debug_warn("gmm_full::learn(): given matrix has non-finite values"); return false; } + if(X.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix is empty" ); return false; } + if(X.internal_has_nonfinite()) { arma_debug_warn_level(3, "gmm_full::learn(): given matrix has non-finite values"); return false; } if(N_gaus == 0) { reset(); return true; } @@ -757,14 +759,14 @@ gmm_full::learn if(seed_mode == keep_existing) { - if(means.is_empty() ) { arma_debug_warn("gmm_full::learn(): no existing means" ); return false; } - if(X.n_rows != means.n_rows) { arma_debug_warn("gmm_full::learn(): dimensionality mismatch"); return false; } + if(means.is_empty() ) { arma_debug_warn_level(3, "gmm_full::learn(): no existing means" ); return false; } + if(X.n_rows != means.n_rows) { arma_debug_warn_level(3, "gmm_full::learn(): dimensionality mismatch"); return false; } // TODO: also check for number of vectors? } else { - if(X.n_cols < N_gaus) { arma_debug_warn("gmm_full::learn(): number of vectors is less than number of gaussians"); return false; } + if(X.n_cols < N_gaus) { arma_debug_warn_level(3, "gmm_full::learn(): number of vectors is less than number of gaussians"); return false; } reset(X.n_rows, N_gaus); @@ -788,7 +790,7 @@ gmm_full::learn stream_state.restore(get_cout_stream()); - if(status == false) { arma_debug_warn("gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } + if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): k-means algorithm failed; not enough data, or too many gaussians requested"); init(orig); return false; } } @@ -815,7 +817,7 @@ gmm_full::learn stream_state.restore(get_cout_stream()); - if(status == false) { arma_debug_warn("gmm_full::learn(): EM algorithm failed"); init(orig); return false; } + if(status == false) { arma_debug_warn_level(3, "gmm_full::learn(): EM algorithm failed"); init(orig); return false; } } mah_aux.reset(); @@ -920,7 +922,7 @@ gmm_full::init_constants(const bool calc_chol) const uword N_dims = means.n_rows; const uword N_gaus = means.n_cols; - const eT tmp = (eT(N_dims)/eT(2)) * std::log(eT(2) * Datum::pi); + const eT tmp = (eT(N_dims)/eT(2)) * std::log(Datum::tau); // @@ -940,9 +942,9 @@ gmm_full::init_constants(const bool calc_chol) eT log_det_val = eT(0); eT log_det_sign = eT(0); - log_det(log_det_val, log_det_sign, fcov); + const bool log_det_status = log_det(log_det_val, log_det_sign, fcov); - const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); + const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); if(inv_ok && log_det_ok) { @@ -1030,12 +1032,12 @@ gmm_full::internal_gen_boundaries(const uword N) const const uword n_threads_avail = uword(omp_get_max_threads()); const uword n_threads = (n_threads_avail > 0) ? ( (n_threads_avail <= N) ? n_threads_avail : 1 ) : 1; #else - static const uword n_threads = 1; + static constexpr uword n_threads = 1; #endif // get_cout_stream() << "gmm_full::internal_gen_boundaries(): n_threads: " << n_threads << '\n'; - umat boundaries(2, n_threads); + umat boundaries(2, n_threads, arma_nozeros_indicator()); if(N > 0) { @@ -1143,7 +1145,7 @@ gmm_full::internal_vec_log_p(const Mat& X) const arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); - Row out(N_samples); + Row out(N_samples, arma_nozeros_indicator()); if(N_samples > 0) { @@ -1197,7 +1199,7 @@ gmm_full::internal_vec_log_p(const Mat& X, const uword gaus_id) const arma_debug_check( (X.n_rows != N_dims), "gmm_full::log_p(): incompatible dimensions" ); arma_debug_check( (gaus_id >= means.n_cols), "gmm_full::log_p(): specified gaussian is out of range" ); - Row out(N_samples); + Row out(N_samples, arma_nozeros_indicator()); if(N_samples > 0) { @@ -1258,7 +1260,7 @@ gmm_full::internal_sum_log_p(const Mat& X) const const uword n_threads = boundaries.n_cols; - Col t_accs(n_threads, fill::zeros); + Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) @@ -1315,7 +1317,7 @@ gmm_full::internal_sum_log_p(const Mat& X, const uword gaus_id) const const uword n_threads = boundaries.n_cols; - Col t_accs(n_threads, fill::zeros); + Col t_accs(n_threads, arma_zeros_indicator()); #pragma omp parallel for schedule(static) for(uword t=0; t < n_threads; ++t) @@ -1938,10 +1940,10 @@ gmm_full::generate_initial_params(const Mat& X, const eT var_floor) // as the covariances are calculated via accumulators, // the means also need to be calculated via accumulators to ensure numerical consistency - Mat acc_means(N_dims, N_gaus, fill::zeros); - Mat acc_dcovs(N_dims, N_gaus, fill::zeros); + Mat acc_means(N_dims, N_gaus); + Mat acc_dcovs(N_dims, N_gaus); - Row acc_hefts(N_gaus, fill::zeros); + Row acc_hefts(N_gaus, arma_zeros_indicator()); uword* acc_hefts_mem = acc_hefts.memptr(); @@ -2101,9 +2103,9 @@ gmm_full::km_iterate(const Mat& X, const uword max_iter, const bool verb const eT* mah_aux_mem = mah_aux.memptr(); - Mat acc_means(N_dims, N_gaus, fill::zeros); - Row acc_hefts(N_gaus, fill::zeros); - Row last_indx(N_gaus, fill::zeros); + Mat acc_means(N_dims, N_gaus, arma_zeros_indicator()); + Row acc_hefts( N_gaus, arma_zeros_indicator()); + Row last_indx( N_gaus, arma_zeros_indicator()); Mat new_means = means; Mat old_means = means; @@ -2186,6 +2188,10 @@ gmm_full::km_iterate(const Mat& X, const uword max_iter, const bool verb } #else { + acc_hefts.zeros(); + acc_means.zeros(); + last_indx.zeros(); + uword* acc_hefts_mem = acc_hefts.memptr(); uword* last_indx_mem = last_indx.memptr(); @@ -2302,7 +2308,7 @@ gmm_full::km_iterate(const Mat& X, const uword max_iter, const bool verb access::rw(means) = old_means; - if(means.is_finite() == false) { return false; } + if(means.internal_has_nonfinite()) { return false; } return true; } @@ -2341,7 +2347,7 @@ gmm_full::em_iterate(const Mat& X, const uword max_iter, const eT var_fl field< Col > t_acc_norm_lhoods(n_threads); field< Col > t_gaus_log_lhoods(n_threads); - Col t_progress_log_lhood(n_threads); + Col t_progress_log_lhood(n_threads, arma_nozeros_indicator()); for(uword t=0; t::em_iterate(const Mat& X, const uword max_iter, const eT var_fl if(any(vectorise(fcov.diag()) <= eT(0))) { return false; } } - if(means.is_finite() == false) { return false; } - if(fcovs.is_finite() == false) { return false; } - if(hefts.is_finite() == false) { return false; } + if(means.internal_has_nonfinite()) { return false; } + if(fcovs.internal_has_nonfinite()) { return false; } + if(hefts.internal_has_nonfinite()) { return false; } return true; } @@ -2476,7 +2482,7 @@ gmm_full::em_update_params eT* hefts_mem = access::rw(hefts).memptr(); - Mat mean_outer(N_dims, N_dims); + Mat mean_outer(N_dims, N_dims, arma_nozeros_indicator()); //// update each component without sanity checking @@ -2535,14 +2541,14 @@ gmm_full::em_update_params if(val < var_floor) { val = var_floor; } } - if(acc_fcov.is_finite() == false) { continue; } + if(acc_fcov.internal_has_nonfinite()) { continue; } eT log_det_val = eT(0); eT log_det_sign = eT(0); - log_det(log_det_val, log_det_sign, acc_fcov); + const bool log_det_status = log_det(log_det_val, log_det_sign, acc_fcov); - const bool log_det_ok = ( (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); + const bool log_det_ok = ( log_det_status && (arma_isfinite(log_det_val)) && (log_det_sign > eT(0)) ); const bool inv_ok = (log_det_ok) ? bool(auxlib::inv_sympd(mean_outer, acc_fcov)) : bool(false); // mean_outer is used as a junk matrix diff --git a/src/armadillo_bits/gmm_misc_bones.hpp b/src/armadillo_bits/gmm_misc_bones.hpp index c4519b4b..44507d41 100644 --- a/src/armadillo_bits/gmm_misc_bones.hpp +++ b/src/armadillo_bits/gmm_misc_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,37 +20,37 @@ //! @{ -struct gmm_dist_mode { const uword id; inline explicit gmm_dist_mode(const uword in_id) : id(in_id) {} }; +struct gmm_dist_mode { const uword id; inline constexpr explicit gmm_dist_mode(const uword in_id) : id(in_id) {} }; inline bool operator==(const gmm_dist_mode& a, const gmm_dist_mode& b) { return (a.id == b.id); } inline bool operator!=(const gmm_dist_mode& a, const gmm_dist_mode& b) { return (a.id != b.id); } -struct gmm_dist_eucl : public gmm_dist_mode { inline gmm_dist_eucl() : gmm_dist_mode(1) {} }; -struct gmm_dist_maha : public gmm_dist_mode { inline gmm_dist_maha() : gmm_dist_mode(2) {} }; -struct gmm_dist_prob : public gmm_dist_mode { inline gmm_dist_prob() : gmm_dist_mode(3) {} }; +struct gmm_dist_eucl : public gmm_dist_mode { inline constexpr gmm_dist_eucl() : gmm_dist_mode(1) {} }; +struct gmm_dist_maha : public gmm_dist_mode { inline constexpr gmm_dist_maha() : gmm_dist_mode(2) {} }; +struct gmm_dist_prob : public gmm_dist_mode { inline constexpr gmm_dist_prob() : gmm_dist_mode(3) {} }; -static const gmm_dist_eucl eucl_dist; -static const gmm_dist_maha maha_dist; -static const gmm_dist_prob prob_dist; +static constexpr gmm_dist_eucl eucl_dist; +static constexpr gmm_dist_maha maha_dist; +static constexpr gmm_dist_prob prob_dist; -struct gmm_seed_mode { const uword id; inline explicit gmm_seed_mode(const uword in_id) : id(in_id) {} }; +struct gmm_seed_mode { const uword id; inline constexpr explicit gmm_seed_mode(const uword in_id) : id(in_id) {} }; inline bool operator==(const gmm_seed_mode& a, const gmm_seed_mode& b) { return (a.id == b.id); } inline bool operator!=(const gmm_seed_mode& a, const gmm_seed_mode& b) { return (a.id != b.id); } -struct gmm_seed_keep_existing : public gmm_seed_mode { inline gmm_seed_keep_existing() : gmm_seed_mode(1) {} }; -struct gmm_seed_static_subset : public gmm_seed_mode { inline gmm_seed_static_subset() : gmm_seed_mode(2) {} }; -struct gmm_seed_static_spread : public gmm_seed_mode { inline gmm_seed_static_spread() : gmm_seed_mode(3) {} }; -struct gmm_seed_random_subset : public gmm_seed_mode { inline gmm_seed_random_subset() : gmm_seed_mode(4) {} }; -struct gmm_seed_random_spread : public gmm_seed_mode { inline gmm_seed_random_spread() : gmm_seed_mode(5) {} }; +struct gmm_seed_keep_existing : public gmm_seed_mode { inline constexpr gmm_seed_keep_existing() : gmm_seed_mode(1) {} }; +struct gmm_seed_static_subset : public gmm_seed_mode { inline constexpr gmm_seed_static_subset() : gmm_seed_mode(2) {} }; +struct gmm_seed_static_spread : public gmm_seed_mode { inline constexpr gmm_seed_static_spread() : gmm_seed_mode(3) {} }; +struct gmm_seed_random_subset : public gmm_seed_mode { inline constexpr gmm_seed_random_subset() : gmm_seed_mode(4) {} }; +struct gmm_seed_random_spread : public gmm_seed_mode { inline constexpr gmm_seed_random_spread() : gmm_seed_mode(5) {} }; -static const gmm_seed_keep_existing keep_existing; -static const gmm_seed_static_subset static_subset; -static const gmm_seed_static_spread static_spread; -static const gmm_seed_random_subset random_subset; -static const gmm_seed_random_spread random_spread; +static constexpr gmm_seed_keep_existing keep_existing; +static constexpr gmm_seed_static_subset static_subset; +static constexpr gmm_seed_static_spread static_spread; +static constexpr gmm_seed_random_subset random_subset; +static constexpr gmm_seed_random_spread random_spread; namespace gmm_priv @@ -99,7 +101,7 @@ struct distance {}; template struct distance { - arma_inline arma_hot static eT eval(const uword N, const eT* A, const eT* B, const eT*); + arma_inline static eT eval(const uword N, const eT* A, const eT* B, const eT*); }; @@ -107,7 +109,7 @@ struct distance template struct distance { - arma_inline arma_hot static eT eval(const uword N, const eT* A, const eT* B, const eT* C); + arma_inline static eT eval(const uword N, const eT* A, const eT* B, const eT* C); }; diff --git a/src/armadillo_bits/gmm_misc_meat.hpp b/src/armadillo_bits/gmm_misc_meat.hpp index 3c787f2b..3276b46f 100644 --- a/src/armadillo_bits/gmm_misc_meat.hpp +++ b/src/armadillo_bits/gmm_misc_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -60,7 +62,6 @@ running_mean_scalar::operator=(const running_mean_scalar& in) template -arma_hot inline void running_mean_scalar::operator() (const eT X) @@ -124,7 +125,6 @@ running_mean_scalar::mean() const template arma_inline -arma_hot eT distance::eval(const uword N, const eT* A, const eT* B, const eT*) { @@ -158,7 +158,6 @@ distance::eval(const uword N, const eT* A, const eT* B, const eT*) template arma_inline -arma_hot eT distance::eval(const uword N, const eT* A, const eT* B, const eT* C) { diff --git a/src/armadillo_bits/hdf5_misc.hpp b/src/armadillo_bits/hdf5_misc.hpp index 5da340aa..0dd4a7a1 100644 --- a/src/armadillo_bits/hdf5_misc.hpp +++ b/src/armadillo_bits/hdf5_misc.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -44,7 +46,7 @@ inline hid_t get_hdf5_type< unsigned char >() { - return arma_H5Tcopy(arma_H5T_NATIVE_UCHAR); + return H5Tcopy(H5T_NATIVE_UCHAR); } template<> @@ -52,7 +54,7 @@ inline hid_t get_hdf5_type< char >() { - return arma_H5Tcopy(arma_H5T_NATIVE_CHAR); + return H5Tcopy(H5T_NATIVE_CHAR); } template<> @@ -60,7 +62,7 @@ inline hid_t get_hdf5_type< short >() { - return arma_H5Tcopy(arma_H5T_NATIVE_SHORT); + return H5Tcopy(H5T_NATIVE_SHORT); } template<> @@ -68,7 +70,7 @@ inline hid_t get_hdf5_type< unsigned short >() { - return arma_H5Tcopy(arma_H5T_NATIVE_USHORT); + return H5Tcopy(H5T_NATIVE_USHORT); } template<> @@ -76,7 +78,7 @@ inline hid_t get_hdf5_type< int >() { - return arma_H5Tcopy(arma_H5T_NATIVE_INT); + return H5Tcopy(H5T_NATIVE_INT); } template<> @@ -84,7 +86,7 @@ inline hid_t get_hdf5_type< unsigned int >() { - return arma_H5Tcopy(arma_H5T_NATIVE_UINT); + return H5Tcopy(H5T_NATIVE_UINT); } template<> @@ -92,7 +94,7 @@ inline hid_t get_hdf5_type< long >() { - return arma_H5Tcopy(arma_H5T_NATIVE_LONG); + return H5Tcopy(H5T_NATIVE_LONG); } template<> @@ -100,35 +102,31 @@ inline hid_t get_hdf5_type< unsigned long >() { - return arma_H5Tcopy(arma_H5T_NATIVE_ULONG); + return H5Tcopy(H5T_NATIVE_ULONG); } +template<> +inline +hid_t +get_hdf5_type< long long >() + { + return H5Tcopy(H5T_NATIVE_LLONG); + } -#if defined(ARMA_USE_U64S64) && defined(ULLONG_MAX) - template<> - inline - hid_t - get_hdf5_type< long long >() - { - return arma_H5Tcopy(arma_H5T_NATIVE_LLONG); - } - - template<> - inline - hid_t - get_hdf5_type< unsigned long long >() - { - return arma_H5Tcopy(arma_H5T_NATIVE_ULLONG); - } -#endif - +template<> +inline +hid_t +get_hdf5_type< unsigned long long >() + { + return H5Tcopy(H5T_NATIVE_ULLONG); + } template<> inline hid_t get_hdf5_type< float >() { - return arma_H5Tcopy(arma_H5T_NATIVE_FLOAT); + return H5Tcopy(H5T_NATIVE_FLOAT); } template<> @@ -136,7 +134,7 @@ inline hid_t get_hdf5_type< double >() { - return arma_H5Tcopy(arma_H5T_NATIVE_DOUBLE); + return H5Tcopy(H5T_NATIVE_DOUBLE); } @@ -156,10 +154,10 @@ inline hid_t get_hdf5_type< std::complex >() { - hid_t type = arma_H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); + hid_t type = H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); - arma_H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), arma_H5T_NATIVE_FLOAT); - arma_H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), arma_H5T_NATIVE_FLOAT); + H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), H5T_NATIVE_FLOAT); + H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), H5T_NATIVE_FLOAT); return type; } @@ -171,10 +169,10 @@ inline hid_t get_hdf5_type< std::complex >() { - hid_t type = arma_H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); + hid_t type = H5Tcreate(H5T_COMPOUND, sizeof(hdf5_complex_t)); - arma_H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), arma_H5T_NATIVE_DOUBLE); - arma_H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), arma_H5T_NATIVE_DOUBLE); + H5Tinsert(type, "real", HOFFSET(hdf5_complex_t, real), H5T_NATIVE_DOUBLE); + H5Tinsert(type, "imag", HOFFSET(hdf5_complex_t, imag), H5T_NATIVE_DOUBLE); return type; } @@ -194,85 +192,77 @@ is_supported_arma_hdf5_type(hid_t datatype) // start with most likely used types: double, complex, float, complex search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type< std::complex >(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type< std::complex >(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } // remaining supported types: u8, s8, u16, s16, u32, s32, u64, s64, ulng_t, slng_t search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } - - #if defined(ARMA_USE_U64S64) - { - search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } - - search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } - } - #endif - - #if defined(ARMA_ALLOW_LONG) - { - search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } - - search_type = get_hdf5_type(); - is_equal = ( arma_H5Tequal(datatype, search_type) > 0 ); - arma_H5Tclose(search_type); - if (is_equal) { return true; } - } - #endif + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } + + search_type = get_hdf5_type(); + is_equal = ( H5Tequal(datatype, search_type) > 0 ); + H5Tclose(search_type); + if(is_equal) { return true; } return false; } @@ -304,16 +294,16 @@ hdf5_search_callback hdf5_search_info* search_info = (hdf5_search_info*) operator_data; // We are looking for datasets. - if (info->type == H5O_TYPE_DATASET) + if(info->type == H5O_TYPE_DATASET) { // Check type of dataset to see if we could even load it. - hid_t dataset = arma_H5Dopen(loc_id, name, H5P_DEFAULT); - hid_t datatype = arma_H5Dget_type(dataset); + hid_t dataset = H5Dopen(loc_id, name, H5P_DEFAULT); + hid_t datatype = H5Dget_type(dataset); const bool is_supported = is_supported_arma_hdf5_type(datatype); - arma_H5Tclose(datatype); - arma_H5Dclose(dataset); + H5Tclose(datatype); + H5Dclose(dataset); if(is_supported == false) { @@ -323,7 +313,7 @@ hdf5_search_callback // Now we have to check against our set of names. // Only check names which could be better. - for (size_t string_pos = 0; string_pos < search_info->best_match_position; ++string_pos) + for(size_t string_pos = 0; string_pos < search_info->best_match_position; ++string_pos) { // name is the full path (/path/to/dataset); names[string_pos] may be // "dataset", "/to/dataset", or "/path/to/dataset". @@ -333,48 +323,48 @@ hdf5_search_callback // Count the number of forward slashes in names[string_pos]. uword name_count = 0; - for (uword i = 0; i < search_info->names[string_pos].length(); ++i) + for(uword i = 0; i < search_info->names[string_pos].length(); ++i) { - if ((search_info->names[string_pos])[i] == '/') { ++name_count; } + if((search_info->names[string_pos])[i] == '/') { ++name_count; } } // Count the number of forward slashes in the full name. uword count = 0; const std::string str = std::string(name); - for (uword i = 0; i < str.length(); ++i) + for(uword i = 0; i < str.length(); ++i) { - if (str[i] == '/') { ++count; } + if(str[i] == '/') { ++count; } } // Is the full string the same? - if (str == search_info->names[string_pos]) + if(str == search_info->names[string_pos]) { // We found it exactly. - hid_t match_candidate = arma_H5Dopen(loc_id, name, H5P_DEFAULT); + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); - if (match_candidate < 0) + if(match_candidate < 0) { return -1; } // Ensure that the dataset is valid and of the correct dimensionality. - hid_t filespace = arma_H5Dget_space(match_candidate); - int num_dims = arma_H5Sget_simple_extent_ndims(filespace); + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); - if (num_dims <= search_info->num_dims) + if(num_dims <= search_info->num_dims) { // Valid dataset -- we'll keep it. // If we already have an existing match we have to close it. - if (search_info->best_match != -1) + if(search_info->best_match != -1) { - arma_H5Dclose(search_info->best_match); + H5Dclose(search_info->best_match); } search_info->best_match_position = string_pos; search_info->best_match = match_candidate; } - arma_H5Sclose(filespace); + H5Sclose(filespace); // There is no possibility of anything better, so terminate the search. return 1; } @@ -382,16 +372,16 @@ hdf5_search_callback // If we are asking for more slashes than we have, this can't be a match. // Skip to below, where we decide whether or not to keep it anyway based // on the exactness condition of the search. - if (count <= name_count) + if(count <= name_count) { size_t start_pos = (count == 0) ? 0 : std::string::npos; - while (count > 0) + while(count > 0) { // Move pointer to previous slash. start_pos = str.rfind('/', start_pos); // Break if we've run out of slashes. - if (start_pos == std::string::npos) { break; } + if(start_pos == std::string::npos) { break; } --count; } @@ -400,10 +390,10 @@ hdf5_search_callback const std::string substring = str.substr(start_pos); // Are they the same? - if (substring == search_info->names[string_pos]) + if(substring == search_info->names[string_pos]) { // We have found the object; it must be better than our existing match. - hid_t match_candidate = arma_H5Dopen(loc_id, name, H5P_DEFAULT); + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); // arma_check(match_candidate < 0, "Mat::load(): cannot open an HDF5 dataset"); @@ -414,31 +404,31 @@ hdf5_search_callback // Ensure that the dataset is valid and of the correct dimensionality. - hid_t filespace = arma_H5Dget_space(match_candidate); - int num_dims = arma_H5Sget_simple_extent_ndims(filespace); + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); - if (num_dims <= search_info->num_dims) + if(num_dims <= search_info->num_dims) { // Valid dataset -- we'll keep it. // If we already have an existing match we have to close it. - if (search_info->best_match != -1) + if(search_info->best_match != -1) { - arma_H5Dclose(search_info->best_match); + H5Dclose(search_info->best_match); } search_info->best_match_position = string_pos; search_info->best_match = match_candidate; } - arma_H5Sclose(filespace); + H5Sclose(filespace); } } // If they are not the same, but we have not found anything and we don't need an exact match, take this. - if ((search_info->exact == false) && (search_info->best_match == -1)) + if((search_info->exact == false) && (search_info->best_match == -1)) { - hid_t match_candidate = arma_H5Dopen(loc_id, name, H5P_DEFAULT); + hid_t match_candidate = H5Dopen(loc_id, name, H5P_DEFAULT); // arma_check(match_candidate < 0, "Mat::load(): cannot open an HDF5 dataset"); if(match_candidate < 0) @@ -446,16 +436,16 @@ hdf5_search_callback return -1; } - hid_t filespace = arma_H5Dget_space(match_candidate); - int num_dims = arma_H5Sget_simple_extent_ndims(filespace); + hid_t filespace = H5Dget_space(match_candidate); + int num_dims = H5Sget_simple_extent_ndims(filespace); - if (num_dims <= search_info->num_dims) + if(num_dims <= search_info->num_dims) { // Valid dataset -- we'll keep it. - search_info->best_match = arma_H5Dopen(loc_id, name, H5P_DEFAULT); + search_info->best_match = H5Dopen(loc_id, name, H5P_DEFAULT); } - arma_H5Sclose(filespace); + H5Sclose(filespace); } } } @@ -485,7 +475,7 @@ search_hdf5_file hdf5_search_info search_info = { names, num_dims, exact, -1, names.size() }; // We'll use the H5Ovisit to track potential entries. - herr_t status = arma_H5Ovisit(hdf5_file, H5_INDEX_NAME, H5_ITER_NATIVE, hdf5_search_callback, void_ptr(&search_info)); + herr_t status = H5Ovisit(hdf5_file, H5_INDEX_NAME, H5_ITER_NATIVE, hdf5_search_callback, void_ptr(&search_info)); // Return the best match; it will be -1 if there was a problem. return (status < 0) ? -1 : search_info.best_match; @@ -518,13 +508,13 @@ load_and_convert_hdf5 // u8 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -533,13 +523,13 @@ load_and_convert_hdf5 // s8 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -548,13 +538,13 @@ load_and_convert_hdf5 // u16 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -563,13 +553,13 @@ load_and_convert_hdf5 // s16 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -578,13 +568,13 @@ load_and_convert_hdf5 // u32 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -593,96 +583,88 @@ load_and_convert_hdf5 // s32 search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; } - #if defined(ARMA_USE_U64S64) + // u64 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) { - // u64 - search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); - - if(is_equal) - { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); - arrayops::convert(dest, v.memptr(), n_elem); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); - return status; - } - - - // s64 - search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); - - if(is_equal) - { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); - arrayops::convert(dest, v.memptr(), n_elem); + return status; + } + + + // s64 + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); - return status; - } + return status; } - #endif - #if defined(ARMA_ALLOW_LONG) + // ulng_t + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) { - // ulng_t - search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); - - if(is_equal) - { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); - arrayops::convert(dest, v.memptr(), n_elem); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); - return status; - } - - - // slng_t - search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); - - if(is_equal) - { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); - arrayops::convert(dest, v.memptr(), n_elem); + return status; + } + + + // slng_t + search_type = get_hdf5_type(); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); + + if(is_equal) + { + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + arrayops::convert(dest, v.memptr(), n_elem); - return status; - } + return status; } - #endif // float search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -691,13 +673,13 @@ load_and_convert_hdf5 // double search_type = get_hdf5_type(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { - Col v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert(dest, v.memptr(), n_elem); return status; @@ -706,8 +688,8 @@ load_and_convert_hdf5 // complex float search_type = get_hdf5_type< std::complex >(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { @@ -716,8 +698,8 @@ load_and_convert_hdf5 return -1; // can't read complex data into non-complex matrix/cube } - Col< std::complex > v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col< std::complex > v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert_cx(dest, v.memptr(), n_elem); return status; @@ -726,8 +708,8 @@ load_and_convert_hdf5 // complex double search_type = get_hdf5_type< std::complex >(); - is_equal = (arma_H5Tequal(datatype, search_type) > 0); - arma_H5Tclose(search_type); + is_equal = (H5Tequal(datatype, search_type) > 0); + H5Tclose(search_type); if(is_equal) { @@ -736,8 +718,8 @@ load_and_convert_hdf5 return -1; // can't read complex data into non-complex matrix/cube } - Col< std::complex > v(n_elem); - hid_t status = arma_H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); + Col< std::complex > v(n_elem, arma_nozeros_indicator()); + hid_t status = H5Dread(dataset, datatype, H5S_ALL, H5S_ALL, H5P_DEFAULT, void_ptr(v.memptr())); arrayops::convert_cx(dest, v.memptr(), n_elem); return status; @@ -751,7 +733,7 @@ load_and_convert_hdf5 struct hdf5_suspend_printing_errors { - #if defined(ARMA_PRINT_HDF5_ERRORS) + #if (ARMA_WARN_LEVEL >= 3) inline hdf5_suspend_printing_errors() {} @@ -765,16 +747,16 @@ struct hdf5_suspend_printing_errors hdf5_suspend_printing_errors() { // Save old error handler. - arma_H5Eget_auto(H5E_DEFAULT, &old_client_func, &old_client_data); + H5Eget_auto(H5E_DEFAULT, &old_client_func, &old_client_data); // Disable annoying HDF5 error messages. - arma_H5Eset_auto(H5E_DEFAULT, NULL, NULL); + H5Eset_auto(H5E_DEFAULT, NULL, NULL); } inline ~hdf5_suspend_printing_errors() { - arma_H5Eset_auto(H5E_DEFAULT, old_client_func, old_client_data); + H5Eset_auto(H5E_DEFAULT, old_client_func, old_client_data); } #endif diff --git a/src/armadillo_bits/hdf5_name.hpp b/src/armadillo_bits/hdf5_name.hpp new file mode 100644 index 00000000..8dd38ccf --- /dev/null +++ b/src/armadillo_bits/hdf5_name.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup diskio +//! @{ + + +namespace hdf5_opts + { + typedef unsigned int flag_type; + + struct opts + { + const flag_type flags; + + inline constexpr explicit opts(const flag_type in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const flag_type in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 0) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr flag_type flag_none = flag_type(0 ); + static constexpr flag_type flag_trans = flag_type(1u << 0); + static constexpr flag_type flag_append = flag_type(1u << 1); + static constexpr flag_type flag_replace = flag_type(1u << 2); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_trans : public opts { inline constexpr opts_trans() : opts(flag_trans ) {} }; + struct opts_append : public opts { inline constexpr opts_append() : opts(flag_append ) {} }; + struct opts_replace : public opts { inline constexpr opts_replace() : opts(flag_replace) {} }; + + static constexpr opts_none none; + static constexpr opts_trans trans; + static constexpr opts_append append; + static constexpr opts_replace replace; + } + + +struct hdf5_name + { + const std::string filename; + const std::string dsname; + const hdf5_opts::opts opts; + + inline + hdf5_name(const std::string& in_filename) + : filename(in_filename ) + , dsname (std::string() ) + , opts (hdf5_opts::none) + {} + + inline + hdf5_name(const std::string& in_filename, const std::string& in_dsname, const hdf5_opts::opts& in_opts = hdf5_opts::none) + : filename(in_filename) + , dsname (in_dsname ) + , opts (in_opts ) + {} + }; + + +//! @} diff --git a/src/armadillo_bits/include_hdf5.hpp b/src/armadillo_bits/include_hdf5.hpp index 9df3372e..a639f78a 100644 --- a/src/armadillo_bits/include_hdf5.hpp +++ b/src/armadillo_bits/include_hdf5.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,24 +17,29 @@ #if defined(ARMA_USE_HDF5) - #if !defined(ARMA_HDF5_INCLUDE_DIR) - #include + + #undef H5_USE_110_API + #define H5_USE_110_API + + #if defined(__has_include) + #if __has_include() + #include + #else + #undef ARMA_USE_HDF5 + #pragma message ("WARNING: use of HDF5 disabled; hdf5.h header not found") + #endif #else - #define ARMA_STR1(x) x - #define ARMA_STR2(x) ARMA_STR1(x) - - #define ARMA_HDF5_HEADER ARMA_STR2(ARMA_HDF5_INCLUDE_DIR)ARMA_STR2(hdf5.h) - - #include ARMA_INCFILE_WRAP(ARMA_HDF5_HEADER) - - #undef ARMA_STR1 - #undef ARMA_STR2 - #undef ARMA_HDF5_HEADER + #include #endif - - #if defined(H5_USE_16_API_DEFAULT) || defined(H5_USE_16_API) - #pragma message ("WARNING: disabling use of HDF5 due to its incompatible configuration") + + #if defined(H5_USE_16_API) || defined(H5_USE_16_API_DEFAULT) + #pragma message ("WARNING: use of HDF5 disabled; incompatible configuration: H5_USE_16_API or H5_USE_16_API_DEFAULT") #undef ARMA_USE_HDF5 - #undef ARMA_USE_HDF5_ALT #endif + + // // TODO + // #if defined(H5_USE_18_API) || defined(H5_USE_18_API_DEFAULT) + // #pragma message ("WARNING: detected possibly incompatible configuration of HDF5: H5_USE_18_API or H5_USE_18_API_DEFAULT") + // #endif + #endif diff --git a/src/armadillo_bits/include_superlu.hpp b/src/armadillo_bits/include_superlu.hpp index 577fc165..43fa0a73 100644 --- a/src/armadillo_bits/include_superlu.hpp +++ b/src/armadillo_bits/include_superlu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -52,8 +54,7 @@ #if defined(ARMA_USE_SUPERLU) - -#if defined(ARMA_USE_SUPERLU_HEADERS) || defined(ARMA_SUPERLU_INCLUDE_DIR) +#undef ARMA_SLU_HEADERS_FOUND // Since we need to suport float, double, cx_float and cx_double, // as well as preserve the sanity of the user, @@ -68,142 +69,173 @@ namespace arma { - namespace superlu { // slu_*defs.h has int typedefed to int_t. // I'll just write it as int for simplicity, where I can, but supermatrix.h needs int_t. typedef int int_t; - + } +} + +#if defined(ARMA_USE_SUPERLU_HEADERS) || defined(ARMA_SUPERLU_INCLUDE_DIR) + +namespace arma +{ +namespace superlu + { // Include supermatrix.h. This gives us SuperMatrix. // Put it in the slu namespace. // For versions of SuperLU I am familiar with, supermatrix.h does not include any other files. // Therefore, putting it in the superlu namespace is reasonably safe. // This same reasoning is true for superlu_enum_consts.h. + #undef ARMA_SLU_HEADER_A + #undef ARMA_SLU_HEADER_B + #if defined(ARMA_SUPERLU_INCLUDE_DIR) - #define ARMA_SLU_STR(x) x - #define ARMA_SLU_STR2(x) ARMA_SLU_STR(x) + #undef ARMA_SLU_STR1 + #undef ARMA_SLU_STR2 - #define ARMA_SLU_SUPERMATRIX_H ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(supermatrix.h) - #define ARMA_SLU_SUPERLU_ENUM_CONSTS_H ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(superlu_enum_consts.h) + #define ARMA_SLU_STR1(x) x + #define ARMA_SLU_STR2(x) ARMA_SLU_STR1(x) + + #define ARMA_SLU_HEADER_A ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(supermatrix.h) + #define ARMA_SLU_HEADER_B ARMA_SLU_STR2(ARMA_SUPERLU_INCLUDE_DIR)ARMA_SLU_STR2(superlu_enum_consts.h) #else - #define ARMA_SLU_SUPERMATRIX_H supermatrix.h - #define ARMA_SLU_SUPERLU_ENUM_CONSTS_H superlu_enum_consts.h + #define ARMA_SLU_HEADER_A supermatrix.h + #define ARMA_SLU_HEADER_B superlu_enum_consts.h #endif - #include ARMA_INCFILE_WRAP(ARMA_SLU_SUPERMATRIX_H) - #include ARMA_INCFILE_WRAP(ARMA_SLU_SUPERLU_ENUM_CONSTS_H) - - #undef ARMA_SLU_SUPERMATRIX_H - #undef ARMA_SLU_SUPERLU_ENUM_CONSTS_H - - - typedef struct - { - int* panel_histo; - double* utime; - float* ops; - int TinyPivots; - int RefineSteps; - int expansions; - } SuperLUStat_t; - - - typedef struct - { - fact_t Fact; - yes_no_t Equil; - colperm_t ColPerm; - trans_t Trans; - IterRefine_t IterRefine; - double DiagPivotThresh; - yes_no_t SymmetricMode; - yes_no_t PivotGrowth; - yes_no_t ConditionNumber; - rowperm_t RowPerm; - int ILU_DropRule; - double ILU_DropTol; - double ILU_FillFactor; - norm_t ILU_Norm; - double ILU_FillTol; - milu_t ILU_MILU; - double ILU_MILU_Dim; - yes_no_t ParSymbFact; - yes_no_t ReplaceTinyPivot; - yes_no_t SolveInitialized; - yes_no_t RefineInitialized; - yes_no_t PrintStat; - int nnzL, nnzU; - int num_lookaheads; - yes_no_t lookahead_etree; - yes_no_t SymPattern; - } superlu_options_t; - - - typedef struct - { - float for_lu; - float total_needed; - } mem_usage_t; - - - typedef struct e_node - { - int size; - void* mem; - } ExpHeader; - - - typedef struct - { - int size; - int used; - int top1; - int top2; - void* array; - } LU_stack_t; + #if defined(__has_include) + #if __has_include(ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A)) && __has_include(ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B)) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B) + #define ARMA_SLU_HEADERS_FOUND + #endif + #else + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_A) + #include ARMA_INCFILE_WRAP(ARMA_SLU_HEADER_B) + #define ARMA_SLU_HEADERS_FOUND + #endif + #undef ARMA_SLU_STR1 + #undef ARMA_SLU_STR2 + + #undef ARMA_SLU_HEADER_A + #undef ARMA_SLU_HEADER_B - typedef struct - { - int* xsup; - int* supno; - int* lsub; - int* xlsub; - void* lusup; - int* xlusup; - void* ucol; - int* usub; - int* xusub; - int nzlmax; - int nzumax; - int nzlumax; - int n; - LU_space_t MemModel; - int num_expansions; - ExpHeader* expanders; - LU_stack_t stack; - } GlobalLU_t; + #if defined(ARMA_SLU_HEADERS_FOUND) + + typedef struct + { + int* panel_histo; + double* utime; + float* ops; + int TinyPivots; + int RefineSteps; + int expansions; + } SuperLUStat_t; + + typedef struct + { + fact_t Fact; + yes_no_t Equil; + colperm_t ColPerm; + trans_t Trans; + IterRefine_t IterRefine; + double DiagPivotThresh; + yes_no_t SymmetricMode; + yes_no_t PivotGrowth; + yes_no_t ConditionNumber; + rowperm_t RowPerm; + int ILU_DropRule; + double ILU_DropTol; + double ILU_FillFactor; + norm_t ILU_Norm; + double ILU_FillTol; + milu_t ILU_MILU; + double ILU_MILU_Dim; + yes_no_t ParSymbFact; + yes_no_t ReplaceTinyPivot; + yes_no_t SolveInitialized; + yes_no_t RefineInitialized; + yes_no_t PrintStat; + int nnzL, nnzU; + int num_lookaheads; + yes_no_t lookahead_etree; + yes_no_t SymPattern; + } superlu_options_t; + + typedef struct + { + float for_lu; + float total_needed; + } mem_usage_t; + + typedef struct e_node + { + int size; + void* mem; + } ExpHeader; + + typedef struct + { + int size; + int used; + int top1; + int top2; + void* array; + } LU_stack_t; + + typedef struct + { + int* xsup; + int* supno; + int* lsub; + int* xlsub; + void* lusup; + int* xlusup; + void* ucol; + int* usub; + int* xusub; + int nzlmax; + int nzumax; + int nzlumax; + int n; + LU_space_t MemModel; + int num_expansions; + ExpHeader* expanders; + LU_stack_t stack; + } GlobalLU_t; + + #endif } } -#else +#endif + +#if defined(ARMA_USE_SUPERLU_HEADERS) && !defined(ARMA_SLU_HEADERS_FOUND) + #undef ARMA_USE_SUPERLU + #pragma message ("WARNING: use of SuperLU disabled; required headers not found") +#endif + +#endif + + + +#if defined(ARMA_USE_SUPERLU) && !defined(ARMA_SLU_HEADERS_FOUND) // Not using any SuperLU headers, so define all required enums and structs. -// -// CAVEAT: -// This code requires SuperLU version 5.2, -// and assumes that newer 5.x versions will have no API changes. + +#if defined(ARMA_SUPERLU_INCLUDE_DIR) + #pragma message ("WARNING: SuperLU headers not found; using built-in definitions") +#endif namespace arma { - namespace superlu { - typedef int int_t; - typedef enum { SLU_NC, @@ -216,7 +248,6 @@ namespace superlu SLU_NR_loc } Stype_t; - typedef enum { SLU_S, @@ -225,7 +256,6 @@ namespace superlu SLU_Z } Dtype_t; - typedef enum { SLU_GE, @@ -239,7 +269,6 @@ namespace superlu SLU_HEU } Mtype_t; - typedef struct { Stype_t Stype; @@ -250,7 +279,6 @@ namespace superlu void* Store; } SuperMatrix; - typedef struct { int* panel_histo; @@ -261,7 +289,6 @@ namespace superlu int expansions; } SuperLUStat_t; - typedef enum {NO, YES} yes_no_t; typedef enum {DOFACT, SamePattern, SamePattern_SameRowPerm, FACTORED} fact_t; typedef enum {NOROWPERM, LargeDiag, MY_PERMR} rowperm_t; @@ -273,7 +300,6 @@ namespace superlu typedef enum {ONE_NORM, TWO_NORM, INF_NORM} norm_t; typedef enum {SILU, SMILU_1, SMILU_2, SMILU_3} milu_t; - typedef struct { fact_t Fact; @@ -304,14 +330,12 @@ namespace superlu yes_no_t SymPattern; } superlu_options_t; - typedef struct { float for_lu; float total_needed; } mem_usage_t; - typedef struct { int_t nnz; @@ -320,21 +344,18 @@ namespace superlu int_t* colptr; } NCformat; - typedef struct { int_t lda; void* nzval; } DNformat; - typedef struct e_node { int size; void* mem; } ExpHeader; - typedef struct { int size; @@ -344,7 +365,6 @@ namespace superlu void* array; } LU_stack_t; - typedef struct { int* xsup; @@ -368,7 +388,6 @@ namespace superlu } } -#endif - +#undef ARMA_SLU_HEADERS_FOUND #endif diff --git a/src/armadillo_bits/injector_bones.hpp b/src/armadillo_bits/injector_bones.hpp index 70d89a8e..80e2a173 100644 --- a/src/armadillo_bits/injector_bones.hpp +++ b/src/armadillo_bits/injector_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,22 +21,6 @@ -template -class mat_injector_row - { - public: - - inline mat_injector_row(); - - inline void insert(const eT val) const; - - mutable uword n_cols; - mutable podarray A; - mutable podarray B; - }; - - - template class mat_injector { @@ -42,21 +28,20 @@ class mat_injector typedef typename T1::elem_type elem_type; - inline void insert(const elem_type val) const; - inline void end_of_row() const; - inline ~mat_injector(); + arma_cold inline void insert(const elem_type val) const; + arma_cold inline void end_of_row() const; + arma_cold inline ~mat_injector(); private: inline mat_injector(T1& in_X, const elem_type val); - inline mat_injector(T1& in_X, const injector_end_of_row<>& x); + inline mat_injector(T1& in_X, const injector_end_of_row<>&); - T1& X; - mutable uword n_rows; + T1& parent; - mutable podarray< mat_injector_row* >* AA; - mutable podarray< mat_injector_row* >* BB; + mutable std::vector values; + mutable std::vector rowend; friend class Mat; friend class Row; @@ -69,23 +54,6 @@ class mat_injector -template -class field_injector_row - { - public: - - inline field_injector_row(); - inline ~field_injector_row(); - - inline void insert(const oT& val) const; - - mutable uword n_cols; - mutable field* AA; - mutable field* BB; - }; - - - template class field_injector { @@ -93,21 +61,20 @@ class field_injector typedef typename T1::object_type object_type; - inline void insert(const object_type& val) const; - inline void end_of_row() const; - inline ~field_injector(); + arma_cold inline void insert(const object_type& val) const; + arma_cold inline void end_of_row() const; + arma_cold inline ~field_injector(); private: inline field_injector(T1& in_X, const object_type& val); - inline field_injector(T1& in_X, const injector_end_of_row<>& x); + inline field_injector(T1& in_X, const injector_end_of_row<>&); - T1& X; - mutable uword n_rows; + T1& parent; - mutable podarray< field_injector_row* >* AA; - mutable podarray< field_injector_row* >* BB; + mutable std::vector values; + mutable std::vector rowend; friend class field; }; diff --git a/src/armadillo_bits/injector_meat.hpp b/src/armadillo_bits/injector_meat.hpp index 32191cc2..81962c06 100644 --- a/src/armadillo_bits/injector_meat.hpp +++ b/src/armadillo_bits/injector_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,201 +21,158 @@ -template +template inline -mat_injector_row::mat_injector_row() - : n_cols(0) +mat_injector::mat_injector(T1& in_parent, const typename mat_injector::elem_type val) + : parent(in_parent) { arma_extra_debug_sigprint(); - A.set_size( podarray_prealloc_n_elem::val ); - } - - - -template -inline -void -mat_injector_row::insert(const eT val) const - { - arma_extra_debug_sigprint(); + values.reserve(16); + rowend.reserve(16); - if(n_cols < A.n_elem) - { - A[n_cols] = val; - ++n_cols; - } - else - { - B.set_size(2 * A.n_elem); - - arrayops::copy(B.memptr(), A.memptr(), n_cols); - - B[n_cols] = val; - ++n_cols; - - std::swap( access::rw(A.mem), access::rw(B.mem) ); - std::swap( access::rw(A.n_elem), access::rw(B.n_elem) ); - } + insert(val); } -// -// -// - - - template inline -mat_injector::mat_injector(T1& in_X, const typename mat_injector::elem_type val) - : X(in_X) - , n_rows(1) +mat_injector::mat_injector(T1& in_parent, const injector_end_of_row<>&) + : parent(in_parent) { arma_extra_debug_sigprint(); - typedef typename mat_injector::elem_type eT; + values.reserve(16); + rowend.reserve(16); - AA = new podarray< mat_injector_row* >; - BB = new podarray< mat_injector_row* >; - - podarray< mat_injector_row* >& A = *AA; - - A.set_size(n_rows); - - for(uword row=0; row; - } - - (*(A[0])).insert(val); + end_of_row(); } template inline -mat_injector::mat_injector(T1& in_X, const injector_end_of_row<>& x) - : X(in_X) - , n_rows(1) +mat_injector::~mat_injector() { arma_extra_debug_sigprint(); - arma_ignore(x); - typedef typename mat_injector::elem_type eT; + const uword N = values.size(); + + if(N == 0) { return; } - AA = new podarray< mat_injector_row* >; - BB = new podarray< mat_injector_row* >; + uword n_rows = 1; + uword n_cols = 0; - podarray< mat_injector_row* >& A = *AA; + for(uword i=0; i; + if(rowend[i]) + { + n_cols = (std::max)(n_cols, n_cols_in_row); + n_cols_in_row = 0; + } + else + { + ++n_cols_in_row; + } } - (*this).end_of_row(); - } - - - -template -inline -mat_injector::~mat_injector() - { - arma_extra_debug_sigprint(); + n_rows = (rowend[N-1]) ? (n_rows-1) : n_rows; + n_cols = (std::max)(n_cols, n_cols_in_row); - typedef typename mat_injector::elem_type eT; - - podarray< mat_injector_row* >& A = *AA; - - if(n_rows > 0) + if(is_Row::value) { - uword max_n_cols = (*(A[0])).n_cols; + arma_debug_check( (n_rows > 1), "matrix initialisation: incompatible dimensions" ); + + parent.zeros(1,n_cols); - for(uword row=1; row::value) + { + const bool is_vec = ((n_cols == 1) || (n_rows == 1)); - const uword max_n_rows = ((*(A[n_rows-1])).n_cols == 0) ? n_rows-1 : n_rows; + arma_debug_check( (is_vec == false), "matrix initialisation: incompatible dimensions" ); - if(is_Mat_only::value == true) + if(n_cols == 1) { - X.set_size(max_n_rows, max_n_cols); + parent.zeros(n_rows,1); + + uword row = 0; - for(uword row=0; row 0) && rowend[i-1]) { ++row; } } - - for(uword col=n_cols; col::value == true) - { - arma_debug_check( (max_n_rows > 1), "matrix initialisation: incompatible dimensions" ); - - const uword n_cols = (*(A[0])).n_cols; - - X.set_size(1, n_cols); - - arrayops::copy( X.memptr(), (*(A[0])).A.memptr(), n_cols ); - } - else - if(is_Col::value == true) + if(n_rows == 1) { - const bool is_vec = ( (max_n_rows == 1) || (max_n_cols == 1) ); - - arma_debug_check( (is_vec == false), "matrix initialisation: incompatible dimensions" ); + parent.zeros(n_cols,1); - const uword n_elem = (std::max)(max_n_rows, max_n_cols); + uword row = 0; - X.set_size(n_elem, 1); - - uword i = 0; - for(uword row=0; row::insert(const typename mat_injector::elem_type val) const { arma_extra_debug_sigprint(); - typedef typename mat_injector::elem_type eT; - - podarray< mat_injector_row* >& A = *AA; - - (*(A[n_rows-1])).insert(val); + values.push_back(val ); + rowend.push_back(char(0)); } @@ -244,27 +200,14 @@ mat_injector::end_of_row() const typedef typename mat_injector::elem_type eT; - podarray< mat_injector_row* >& A = *AA; - podarray< mat_injector_row* >& B = *BB; - - B.set_size( n_rows+1 ); - - arrayops::copy(B.memptr(), A.memptr(), n_rows); - - for(uword row=n_rows; row<(n_rows+1); ++row) - { - B[row] = new mat_injector_row; - } - - std::swap(AA, BB); - - n_rows += 1; + values.push_back( eT(0)); + rowend.push_back(char(1)); } template -arma_inline +inline const mat_injector& operator<<(const mat_injector& ref, const typename mat_injector::elem_type val) { @@ -278,12 +221,11 @@ operator<<(const mat_injector& ref, const typename mat_injector::elem_ty template -arma_inline +inline const mat_injector& -operator<<(const mat_injector& ref, const injector_end_of_row<>& x) +operator<<(const mat_injector& ref, const injector_end_of_row<>&) { arma_extra_debug_sigprint(); - arma_ignore(x); ref.end_of_row(); @@ -292,222 +234,87 @@ operator<<(const mat_injector& ref, const injector_end_of_row<>& x) -//// using a mixture of operator << and , doesn't work yet -//// e.g. A << 1, 2, 3 << endr -//// in the above "3 << endr" requires special handling. -//// similarly, special handling is necessary for "endr << 3" -//// -// template -// arma_inline -// const mat_injector& -// operator,(const mat_injector& ref, const typename mat_injector::elem_type val) -// { -// arma_extra_debug_sigprint(); -// -// ref.insert(val); -// -// return ref; -// } - - - -// template -// arma_inline -// const mat_injector& -// operator,(const mat_injector& ref, const injector_end_of_row<>& x) -// { -// arma_extra_debug_sigprint(); -// arma_ignore(x); -// -// ref.end_of_row(); -// -// return ref; -// } - - - - // // // -template -inline -field_injector_row::field_injector_row() - : n_cols(0) - { - arma_extra_debug_sigprint(); - - AA = new field; - BB = new field; - - field& A = *AA; - - A.set_size( field_prealloc_n_elem::val ); - } - - - -template -inline -field_injector_row::~field_injector_row() - { - arma_extra_debug_sigprint(); - - delete AA; - delete BB; - } - - - -template +template inline -void -field_injector_row::insert(const oT& val) const +field_injector::field_injector(T1& in_parent, const typename field_injector::object_type& val) + : parent(in_parent) { arma_extra_debug_sigprint(); - field& A = *AA; - field& B = *BB; - - if(n_cols < A.n_elem) - { - A[n_cols] = val; - ++n_cols; - } - else - { - B.set_size(2 * A.n_elem); - - for(uword i=0; i inline -field_injector::field_injector(T1& in_X, const typename field_injector::object_type& val) - : X(in_X) - , n_rows(1) +field_injector::field_injector(T1& in_parent, const injector_end_of_row<>&) + : parent(in_parent) { arma_extra_debug_sigprint(); - typedef typename field_injector::object_type oT; - - AA = new podarray< field_injector_row* >; - BB = new podarray< field_injector_row* >; - - podarray< field_injector_row* >& A = *AA; - - A.set_size(n_rows); - - for(uword row=0; row; - } - - (*(A[0])).insert(val); + end_of_row(); } template inline -field_injector::field_injector(T1& in_X, const injector_end_of_row<>& x) - : X(in_X) - , n_rows(1) +field_injector::~field_injector() { arma_extra_debug_sigprint(); - arma_ignore(x); - typedef typename field_injector::object_type oT; + const uword N = values.size(); + + if(N == 0) { return; } - AA = new podarray< field_injector_row* >; - BB = new podarray< field_injector_row* >; + uword n_rows = 1; + uword n_cols = 0; - podarray< field_injector_row* >& A = *AA; + for(uword i=0; i; + if(rowend[i]) + { + n_cols = (std::max)(n_cols, n_cols_in_row); + n_cols_in_row = 0; + } + else + { + ++n_cols_in_row; + } } - (*this).end_of_row(); - } - - - -template -inline -field_injector::~field_injector() - { - arma_extra_debug_sigprint(); + n_rows = (rowend[N-1]) ? (n_rows-1) : n_rows; + n_cols = (std::max)(n_cols, n_cols_in_row); - typedef typename field_injector::object_type oT; + parent.set_size(n_rows,n_cols); - podarray< field_injector_row* >& A = *AA; + uword row = 0; + uword col = 0; - if(n_rows > 0) + for(uword i=0; i& tmp = *((*(A[row])).AA); - X.at(row,col) = tmp[col]; - } - - for(uword col=n_cols; col::insert(const typename field_injector::object_type& val) { arma_extra_debug_sigprint(); - typedef typename field_injector::object_type oT; - - podarray< field_injector_row* >& A = *AA; - - (*(A[n_rows-1])).insert(val); + values.push_back(val ); + rowend.push_back(char(0)); } @@ -538,30 +342,14 @@ field_injector::end_of_row() const typedef typename field_injector::object_type oT; - podarray< field_injector_row* >& A = *AA; - podarray< field_injector_row* >& B = *BB; - - B.set_size( n_rows+1 ); - - for(uword row=0; row; - } - - std::swap(AA, BB); - - n_rows += 1; + values.push_back(oT() ); + rowend.push_back(char(1)); } template -arma_inline +inline const field_injector& operator<<(const field_injector& ref, const typename field_injector::object_type& val) { @@ -575,12 +363,11 @@ operator<<(const field_injector& ref, const typename field_injector::obj template -arma_inline +inline const field_injector& -operator<<(const field_injector& ref, const injector_end_of_row<>& x) +operator<<(const field_injector& ref, const injector_end_of_row<>&) { arma_extra_debug_sigprint(); - arma_ignore(x); ref.end_of_row(); diff --git a/src/armadillo_bits/memory.hpp b/src/armadillo_bits/memory.hpp index e0e6041b..ffa4d2c8 100644 --- a/src/armadillo_bits/memory.hpp +++ b/src/armadillo_bits/memory.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,10 +24,7 @@ class memory { public: - inline arma_deprecated static uword enlarge_to_mult_of_chunksize(const uword n_elem); - - template inline arma_malloc static eT* acquire(const uword n_elem); - template inline arma_deprecated static eT* acquire_chunked(const uword n_elem); + template arma_malloc inline static eT* acquire(const uword n_elem); template arma_inline static void release(eT* mem); @@ -36,24 +35,13 @@ class memory -//! no longer used; this function will be removed -inline -arma_deprecated -uword -memory::enlarge_to_mult_of_chunksize(const uword n_elem) - { - return n_elem; - } - - - template -inline arma_malloc +inline eT* memory::acquire(const uword n_elem) { - if(n_elem == 0) { return NULL; } + if(n_elem == 0) { return nullptr; } arma_debug_check ( @@ -63,7 +51,11 @@ memory::acquire(const uword n_elem) eT* out_memptr; - #if defined(ARMA_USE_TBB_ALLOC) + #if defined(ARMA_ALIEN_MEM_ALLOC_FUNCTION) + { + out_memptr = (eT *) ARMA_ALIEN_MEM_ALLOC_FUNCTION(sizeof(eT)*n_elem); + } + #elif defined(ARMA_USE_TBB_ALLOC) { out_memptr = (eT *) scalable_malloc(sizeof(eT)*n_elem); } @@ -73,7 +65,7 @@ memory::acquire(const uword n_elem) } #elif defined(ARMA_HAVE_POSIX_MEMALIGN) { - eT* memptr = NULL; + eT* memptr = nullptr; const size_t n_bytes = sizeof(eT)*size_t(n_elem); const size_t alignment = (n_bytes >= size_t(1024)) ? size_t(32) : size_t(16); @@ -81,10 +73,12 @@ memory::acquire(const uword n_elem) // TODO: investigate apparent memory leak when using alignment >= 64 (as shown on Fedora 28, glibc 2.27) int status = posix_memalign((void **)&memptr, ( (alignment >= sizeof(void*)) ? alignment : sizeof(void*) ), n_bytes); - out_memptr = (status == 0) ? memptr : NULL; + out_memptr = (status == 0) ? memptr : nullptr; } #elif defined(_MSC_VER) { + // Windoze is too primitive to handle C++17 std::aligned_alloc() + //out_memptr = (eT *) malloc(sizeof(eT)*n_elem); //out_memptr = (eT *) _aligned_malloc( sizeof(eT)*n_elem, 16 ); // lives in malloc.h @@ -102,33 +96,25 @@ memory::acquire(const uword n_elem) // TODO: for mingw, use __mingw_aligned_malloc - arma_check_bad_alloc( (out_memptr == NULL), "arma::memory::acquire(): out of memory" ); + arma_check_bad_alloc( (out_memptr == nullptr), "arma::memory::acquire(): out of memory" ); return out_memptr; } -//! no longer used; this function will be removed; replace with call to memory::acquire() -template -inline -arma_deprecated -eT* -memory::acquire_chunked(const uword n_elem) - { - return memory::acquire(n_elem); - } - - - template arma_inline void memory::release(eT* mem) { - if(mem == NULL) { return; } + if(mem == nullptr) { return; } - #if defined(ARMA_USE_TBB_ALLOC) + #if defined(ARMA_ALIEN_MEM_FREE_FUNCTION) + { + ARMA_ALIEN_MEM_FREE_FUNCTION( (void *)(mem) ); + } + #elif defined(ARMA_USE_TBB_ALLOC) { scalable_free( (void *)(mem) ); } @@ -196,6 +182,9 @@ memory::mark_as_aligned(eT*& mem) } #endif + // TODO: look into C++20 std::assume_aligned() + // TODO: https://en.cppreference.com/w/cpp/memory/assume_aligned + // TODO: MSVC? __assume( (mem & 0x0F) == 0 ); // // http://comments.gmane.org/gmane.comp.gcc.patches/239430 diff --git a/src/armadillo_bits/mp_misc.hpp b/src/armadillo_bits/mp_misc.hpp index 7024ffa6..b323ffeb 100644 --- a/src/armadillo_bits/mp_misc.hpp +++ b/src/armadillo_bits/mp_misc.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtGlueCube_bones.hpp b/src/armadillo_bits/mtGlueCube_bones.hpp index 727b9a8b..846d8050 100644 --- a/src/armadillo_bits/mtGlueCube_bones.hpp +++ b/src/armadillo_bits/mtGlueCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,7 @@ template -class mtGlueCube : public BaseCube > +class mtGlueCube : public BaseCube< out_eT, mtGlueCube > { public: diff --git a/src/armadillo_bits/mtGlueCube_meat.hpp b/src/armadillo_bits/mtGlueCube_meat.hpp index 29eab48e..dd27ecd4 100644 --- a/src/armadillo_bits/mtGlueCube_meat.hpp +++ b/src/armadillo_bits/mtGlueCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtGlue_bones.hpp b/src/armadillo_bits/mtGlue_bones.hpp index 1b0061a7..5937d89f 100644 --- a/src/armadillo_bits/mtGlue_bones.hpp +++ b/src/armadillo_bits/mtGlue_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class mtGlue : public Base > +class mtGlue : public Base< out_eT, mtGlue > { public: typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = glue_type::template traits::is_row; - static const bool is_col = glue_type::template traits::is_col; - static const bool is_xvec = glue_type::template traits::is_xvec; + static constexpr bool is_row = glue_type::template traits::is_row; + static constexpr bool is_col = glue_type::template traits::is_col; + static constexpr bool is_xvec = glue_type::template traits::is_xvec; arma_inline mtGlue(const T1& in_A, const T2& in_B); arma_inline mtGlue(const T1& in_A, const T2& in_B, const uword in_aux_uword); diff --git a/src/armadillo_bits/mtGlue_meat.hpp b/src/armadillo_bits/mtGlue_meat.hpp index 9c915a0f..cf3afccc 100644 --- a/src/armadillo_bits/mtGlue_meat.hpp +++ b/src/armadillo_bits/mtGlue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtOpCube_bones.hpp b/src/armadillo_bits/mtOpCube_bones.hpp index 5af556b5..ea9addfd 100644 --- a/src/armadillo_bits/mtOpCube_bones.hpp +++ b/src/armadillo_bits/mtOpCube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,15 +25,15 @@ struct mtOpCube_dual_aux_indicator {}; template -class mtOpCube : public BaseCube > +class mtOpCube : public BaseCube< out_eT, mtOpCube > { public: typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; - + typedef typename T1::elem_type in_eT; - + inline explicit mtOpCube(const T1& in_m); inline mtOpCube(const T1& in_m, const in_eT in_aux); inline mtOpCube(const T1& in_m, const uword in_aux_uword_a, const uword in_aux_uword_b, const uword in_aux_uword_c); diff --git a/src/armadillo_bits/mtOpCube_meat.hpp b/src/armadillo_bits/mtOpCube_meat.hpp index b9e140c4..9ce8b176 100644 --- a/src/armadillo_bits/mtOpCube_meat.hpp +++ b/src/armadillo_bits/mtOpCube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtOp_bones.hpp b/src/armadillo_bits/mtOp_bones.hpp index 10992c32..ff0e4c38 100644 --- a/src/armadillo_bits/mtOp_bones.hpp +++ b/src/armadillo_bits/mtOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,7 +24,7 @@ struct mtOp_dual_aux_indicator {}; template -class mtOp : public Base > +class mtOp : public Base< out_eT, mtOp > { public: @@ -31,9 +33,9 @@ class mtOp : public Base > typedef typename T1::elem_type in_eT; - static const bool is_row = op_type::template traits::is_row; - static const bool is_col = op_type::template traits::is_col; - static const bool is_xvec = op_type::template traits::is_xvec; + static constexpr bool is_row = op_type::template traits::is_row; + static constexpr bool is_col = op_type::template traits::is_col; + static constexpr bool is_xvec = op_type::template traits::is_xvec; inline explicit mtOp(const T1& in_m); inline mtOp(const T1& in_m, const in_eT in_aux); diff --git a/src/armadillo_bits/mtOp_meat.hpp b/src/armadillo_bits/mtOp_meat.hpp index 635032ea..c5b53b8a 100644 --- a/src/armadillo_bits/mtOp_meat.hpp +++ b/src/armadillo_bits/mtOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtSpGlue_bones.hpp b/src/armadillo_bits/mtSpGlue_bones.hpp index 3ea8b12c..3690914e 100644 --- a/src/armadillo_bits/mtSpGlue_bones.hpp +++ b/src/armadillo_bits/mtSpGlue_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class mtSpGlue : public SpBase > +class mtSpGlue : public SpBase< out_eT, mtSpGlue > { public: typedef out_eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = spglue_type::template traits::is_row; - static const bool is_col = spglue_type::template traits::is_col; - static const bool is_xvec = spglue_type::template traits::is_xvec; + static constexpr bool is_row = spglue_type::template traits::is_row; + static constexpr bool is_col = spglue_type::template traits::is_col; + static constexpr bool is_xvec = spglue_type::template traits::is_xvec; inline mtSpGlue(const T1& in_A, const T2& in_B); inline ~mtSpGlue(); diff --git a/src/armadillo_bits/mtSpGlue_meat.hpp b/src/armadillo_bits/mtSpGlue_meat.hpp index 397e28e8..41ede458 100644 --- a/src/armadillo_bits/mtSpGlue_meat.hpp +++ b/src/armadillo_bits/mtSpGlue_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mtSpOp_bones.hpp b/src/armadillo_bits/mtSpOp_bones.hpp index 4f7f688f..9c73727b 100644 --- a/src/armadillo_bits/mtSpOp_bones.hpp +++ b/src/armadillo_bits/mtSpOp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,7 +25,7 @@ template -class mtSpOp : public SpBase > +class mtSpOp : public SpBase< out_eT, mtSpOp > { public: @@ -32,9 +34,9 @@ class mtSpOp : public SpBase > typedef typename T1::elem_type in_eT; - static const bool is_row = spop_type::template traits::is_row; - static const bool is_col = spop_type::template traits::is_col; - static const bool is_xvec = spop_type::template traits::is_xvec; + static constexpr bool is_row = spop_type::template traits::is_row; + static constexpr bool is_col = spop_type::template traits::is_col; + static constexpr bool is_xvec = spop_type::template traits::is_xvec; inline explicit mtSpOp(const T1& in_m); inline mtSpOp(const T1& in_m, const uword aux_uword_a, const uword aux_uword_b); diff --git a/src/armadillo_bits/mtSpOp_meat.hpp b/src/armadillo_bits/mtSpOp_meat.hpp index ce145076..2273f088 100644 --- a/src/armadillo_bits/mtSpOp_meat.hpp +++ b/src/armadillo_bits/mtSpOp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/mul_gemm.hpp b/src/armadillo_bits/mul_gemm.hpp index d3cda651..27e31832 100644 --- a/src/armadillo_bits/mul_gemm.hpp +++ b/src/armadillo_bits/mul_gemm.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -189,7 +191,7 @@ class gemm_emul const TB& B, const eT alpha = eT(1), const eT beta = eT(0), - const typename arma_not_cx::result* junk = 0 + const typename arma_not_cx::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -212,7 +214,7 @@ class gemm_emul const Mat& B, const eT alpha = eT(1), const eT beta = eT(0), - const typename arma_cx_only::result* junk = 0 + const typename arma_cx_only::result* junk = nullptr ) { arma_extra_debug_sigprint(); @@ -237,8 +239,8 @@ class gemm_emul //! \brief -//! Wrapper for ATLAS/BLAS dgemm function, using template arguments to control the arguments passed to dgemm. -//! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) +//! Wrapper for BLAS dgemm function, using template arguments to control the arguments passed to dgemm. +//! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes) template class gemm @@ -261,7 +263,7 @@ class gemm } else { - Mat BB(B.n_rows, B.n_rows); + Mat BB(B.n_rows, B.n_rows, arma_nozeros_indicator()); op_strans::apply_mat_noalias_tinysq(BB, B); @@ -278,9 +280,9 @@ class gemm atlas::cblas_gemm ( - atlas::CblasColMajor, - (do_trans_A) ? ( is_cx::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, - (do_trans_B) ? ( is_cx::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + (do_trans_B) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, C.n_rows, C.n_cols, (do_trans_A) ? A.n_rows : A.n_cols, diff --git a/src/armadillo_bits/mul_gemm_mixed.hpp b/src/armadillo_bits/mul_gemm_mixed.hpp index d36095fd..749cdb17 100644 --- a/src/armadillo_bits/mul_gemm_mixed.hpp +++ b/src/armadillo_bits/mul_gemm_mixed.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,7 +24,7 @@ //! \brief //! Matrix multplication where the matrices have differing element types. //! Uses caching for speedup. -//! Matrix 'C' is assumed to have been set to the correct size (i.e. taking into account transposes) +//! Matrix 'C' is assumed to have been set to the correct size (ie. taking into account transposes) template class gemm_mixed_large diff --git a/src/armadillo_bits/mul_gemv.hpp b/src/armadillo_bits/mul_gemv.hpp index e9c2124c..2580e4ab 100644 --- a/src/armadillo_bits/mul_gemv.hpp +++ b/src/armadillo_bits/mul_gemv.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,9 +31,9 @@ class gemv_emul_tinysq template struct pos { - static const uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2); - static const uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3); - static const uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4); + static constexpr uword n2 = (do_trans_A == false) ? (row + col*2) : (col + row*2); + static constexpr uword n3 = (do_trans_A == false) ? (row + col*3) : (col + row*3); + static constexpr uword n4 = (do_trans_A == false) ? (row + col*4) : (col + row*4); }; @@ -207,8 +209,8 @@ class gemv_emul_helper //! \brief -//! Partial emulation of ATLAS/BLAS gemv(). -//! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) +//! Partial emulation of BLAS gemv(). +//! 'y' is assumed to have been set to the correct size (ie. taking into account the transpose) template class gemv_emul @@ -291,8 +293,8 @@ class gemv_emul //! \brief -//! Wrapper for ATLAS/BLAS gemv function, using template arguments to control the arguments passed to gemv. -//! 'y' is assumed to have been set to the correct size (i.e. taking into account the transpose) +//! Wrapper for BLAS gemv function, using template arguments to control the arguments passed to gemv. +//! 'y' is assumed to have been set to the correct size (ie. taking into account the transpose) template class gemv @@ -325,9 +327,9 @@ class gemv atlas::cblas_gemm ( - atlas::CblasColMajor, - (do_trans_A) ? ( is_cx::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, - atlas::CblasNoTrans, + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, + atlas_CblasNoTrans, (do_trans_A) ? A.n_cols : A.n_rows, 1, (do_trans_A) ? A.n_rows : A.n_cols, @@ -347,8 +349,8 @@ class gemv atlas::cblas_gemv ( - atlas::CblasColMajor, - (do_trans_A) ? ( is_cx::yes ? CblasConjTrans : atlas::CblasTrans ) : atlas::CblasNoTrans, + atlas_CblasColMajor, + (do_trans_A) ? ( is_cx::yes ? atlas_CblasConjTrans : atlas_CblasTrans ) : atlas_CblasNoTrans, A.n_rows, A.n_cols, (use_alpha) ? alpha : eT(1), diff --git a/src/armadillo_bits/mul_herk.hpp b/src/armadillo_bits/mul_herk.hpp index 55424406..e6b13b2b 100644 --- a/src/armadillo_bits/mul_herk.hpp +++ b/src/armadillo_bits/mul_herk.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -46,9 +48,9 @@ class herk_helper template - static arma_hot inline + static eT dot_conj_row(const uword n_elem, const eT* const A, const Mat& B, const uword row) { @@ -316,7 +318,7 @@ class herk inline static void - apply_blas_type( Mat >& C, const TA& A, const T alpha = T(1), const T beta = T(0) ) + apply_blas_type( Mat>& C, const TA& A, const T alpha = T(1), const T beta = T(0) ) { arma_extra_debug_sigprint(); @@ -324,7 +326,7 @@ class herk if(A.is_vec()) { - // work around poor handling of vectors by herk() in ATLAS 3.8.4 and standard BLAS + // work around poor handling of vectors by herk() in standard BLAS herk_vec::apply(C,A,alpha,beta); @@ -345,7 +347,7 @@ class herk typedef typename std::complex eT; // use a temporary matrix, as we can't assume that matrix C is already symmetric - Mat D(C.n_rows, C.n_cols); + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); herk::apply_blas_type(D,A,alpha); @@ -357,9 +359,9 @@ class herk atlas::cblas_herk ( - atlas::CblasColMajor, - atlas::CblasUpper, - (do_trans_A) ? CblasConjTrans : atlas::CblasNoTrans, + atlas_CblasColMajor, + atlas_CblasUpper, + (do_trans_A) ? atlas_CblasConjTrans : atlas_CblasNoTrans, C.n_cols, (do_trans_A) ? A.n_rows : A.n_cols, (use_alpha) ? alpha : T(1), @@ -379,7 +381,7 @@ class herk typedef typename std::complex eT; // use a temporary matrix, as we can't assume that matrix C is already symmetric - Mat D(C.n_rows, C.n_cols); + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); herk::apply_blas_type(D,A,alpha); @@ -436,7 +438,7 @@ class herk inline static void - apply( Mat& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx::result* junk = 0 ) + apply( Mat& C, const TA& A, const eT alpha = eT(1), const eT beta = eT(0), const typename arma_not_cx::result* junk = nullptr ) { arma_ignore(C); arma_ignore(A); diff --git a/src/armadillo_bits/mul_syrk.hpp b/src/armadillo_bits/mul_syrk.hpp index fd8ecf13..c2da3a27 100644 --- a/src/armadillo_bits/mul_syrk.hpp +++ b/src/armadillo_bits/mul_syrk.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -284,7 +286,7 @@ class syrk if(A.is_vec()) { - // work around poor handling of vectors by syrk() in ATLAS 3.8.4 and standard BLAS + // work around poor handling of vectors by syrk() in standard BLAS syrk_vec::apply(C,A,alpha,beta); @@ -304,7 +306,7 @@ class syrk if(use_beta == true) { // use a temporary matrix, as we can't assume that matrix C is already symmetric - Mat D(C.n_rows, C.n_cols); + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); syrk::apply_blas_type(D,A,alpha); @@ -316,9 +318,9 @@ class syrk atlas::cblas_syrk ( - atlas::CblasColMajor, - atlas::CblasUpper, - (do_trans_A) ? atlas::CblasTrans : atlas::CblasNoTrans, + atlas_CblasColMajor, + atlas_CblasUpper, + (do_trans_A) ? atlas_CblasTrans : atlas_CblasNoTrans, C.n_cols, (do_trans_A) ? A.n_rows : A.n_cols, (use_alpha) ? alpha : eT(1), @@ -336,7 +338,7 @@ class syrk if(use_beta == true) { // use a temporary matrix, as we can't assume that matrix C is already symmetric - Mat D(C.n_rows, C.n_cols); + Mat D(C.n_rows, C.n_cols, arma_nozeros_indicator()); syrk::apply_blas_type(D,A,alpha); diff --git a/src/armadillo_bits/newarp_DenseGenMatProd_bones.hpp b/src/armadillo_bits/newarp_DenseGenMatProd_bones.hpp index 03823014..90c3b5ad 100644 --- a/src/armadillo_bits/newarp_DenseGenMatProd_bones.hpp +++ b/src/armadillo_bits/newarp_DenseGenMatProd_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_DenseGenMatProd_meat.hpp b/src/armadillo_bits/newarp_DenseGenMatProd_meat.hpp index a9c02b99..89092452 100644 --- a/src/armadillo_bits/newarp_DenseGenMatProd_meat.hpp +++ b/src/armadillo_bits/newarp_DenseGenMatProd_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_DoubleShiftQR_bones.hpp b/src/armadillo_bits/newarp_DoubleShiftQR_bones.hpp index 28f43d0e..1599568d 100644 --- a/src/armadillo_bits/newarp_DoubleShiftQR_bones.hpp +++ b/src/armadillo_bits/newarp_DoubleShiftQR_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_DoubleShiftQR_meat.hpp b/src/armadillo_bits/newarp_DoubleShiftQR_meat.hpp index 0996ea42..1c7497d0 100644 --- a/src/armadillo_bits/newarp_DoubleShiftQR_meat.hpp +++ b/src/armadillo_bits/newarp_DoubleShiftQR_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -123,7 +125,7 @@ DoubleShiftQR::update_block(uword il, uword iu) // Apply the first reflector apply_PX(mat_H, il, il, 3, n - il, il); - apply_XP(mat_H, 0, il, il + std::min(bsize, uword(4)), 3, il); + apply_XP(mat_H, 0, il, il + (std::min)(bsize, uword(4)), 3, il); // Calculate the following reflectors // If entering this loop, block size is at least 4. @@ -132,7 +134,7 @@ DoubleShiftQR::update_block(uword il, uword iu) compute_reflector(mat_H.colptr(il + i - 1) + il + i, il + i); // Apply the reflector to X apply_PX(mat_H, il + i, il + i - 1, 3, n + 1 - il - i, il + i); - apply_XP(mat_H, 0, il + i, il + std::min(bsize, uword(i + 4)), 3, il + i); + apply_XP(mat_H, 0, il + i, il + (std::min)(bsize, uword(i + 4)), 3, il + i); } // The last reflector diff --git a/src/armadillo_bits/newarp_EigsSelect.hpp b/src/armadillo_bits/newarp_EigsSelect.hpp index b9ab73d6..d518c64b 100644 --- a/src/armadillo_bits/newarp_EigsSelect.hpp +++ b/src/armadillo_bits/newarp_EigsSelect.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_GenEigsSolver_bones.hpp b/src/armadillo_bits/newarp_GenEigsSolver_bones.hpp index 98d80efd..eabaf063 100644 --- a/src/armadillo_bits/newarp_GenEigsSolver_bones.hpp +++ b/src/armadillo_bits/newarp_GenEigsSolver_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,18 +25,18 @@ template class GenEigsSolver { protected: - - const OpType& op; // object to conduct matrix operation, e.g. matrix-vector product + + const OpType& op; // object to conduct matrix operation, eg. matrix-vector product const uword nev; // number of eigenvalues requested Col< std::complex > ritz_val; // ritz values - + // Sort the first nev Ritz pairs in decreasing magnitude order // This is used to return the final results virtual void sort_ritzpair(); - - + + private: - + const uword dim_n; // dimension of matrix A const uword ncv; // number of ritz values uword nmatop; // number of matrix operations called @@ -46,52 +48,56 @@ class GenEigsSolver Col< std::complex > ritz_est; // last row of ritz_vec std::vector ritz_conv; // indicator of the convergence of ritz values const eT eps; // the machine precision - // e.g. ~= 1e-16 for double type + // eg. ~= 1e-16 for double type const eT approx0; // a number that is approximately zero // approx0 = eps^(2/3) // used to test the orthogonality of vectors, // and in convergence test, tol*approx0 is // the absolute tolerance - + + std::mt19937_64 local_rng; // local random number generator + + inline void fill_rand(eT* dest, const uword N, const uword seed_val); + // Arnoldi factorisation starting from step-k inline void factorise_from(uword from_k, uword to_m, const Col& fk); - + // Implicitly restarted Arnoldi factorisation inline void restart(uword k); - + // Calculate the number of converged Ritz values inline uword num_converged(eT tol); - + // Return the adjusted nev for restarting inline uword nev_adjusted(uword nconv); - + // Retrieve and sort ritz values and ritz vectors inline void retrieve_ritzpair(); - - + + public: - + //! Constructor to create a solver object. inline GenEigsSolver(const OpType& op_, uword nev_, uword ncv_); - + //! Providing the initial residual vector for the algorithm. inline void init(eT* init_resid); - + //! Providing a random initial residual vector. inline void init(); - + //! Conducting the major computation procedure. inline uword compute(uword maxit = 1000, eT tol = 1e-10); - + //! Returning the number of iterations used in the computation. inline int num_iterations() { return niter; } - + //! Returning the number of matrix operations used in the computation. inline int num_operations() { return nmatop; } - + //! Returning the converged eigenvalues. inline Col< std::complex > eigenvalues(); - + //! Returning the eigenvectors associated with the converged eigenvalues. inline Mat< std::complex > eigenvectors(uword nvec); diff --git a/src/armadillo_bits/newarp_GenEigsSolver_meat.hpp b/src/armadillo_bits/newarp_GenEigsSolver_meat.hpp index c043ff7a..290fa4f2 100644 --- a/src/armadillo_bits/newarp_GenEigsSolver_meat.hpp +++ b/src/armadillo_bits/newarp_GenEigsSolver_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,6 +20,24 @@ namespace newarp { +template +inline +void +GenEigsSolver::fill_rand(eT* dest, const uword N, const uword seed_val) + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type seed_type; + + local_rng.seed( seed_type(seed_val) ); + + std::uniform_real_distribution dist(-1.0, +1.0); + + for(uword i=0; i < N; ++i) { dest[i] = eT(dist(local_rng)); } + } + + + template inline void @@ -29,7 +49,7 @@ GenEigsSolver::factorise_from(uword from_k, uword to_ fac_f = fk; - Col w(dim_n); + Col w(dim_n, arma_zeros_indicator()); eT beta = norm(fac_f); // Keep the upperleft k x k submatrix of H and set other elements to 0 fac_H.tail_cols(ncv - from_k).zeros(); @@ -42,12 +62,16 @@ GenEigsSolver::factorise_from(uword from_k, uword to_ // to the current V, which we call a restart if(beta < eps) { + // // Generate new random vector for fac_f + // blas_int idist = 2; + // blas_int iseed[4] = {1, 3, 5, 7}; + // iseed[0] = (i + 100) % 4095; + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + // Generate new random vector for fac_f - blas_int idist = 2; - blas_int iseed[4] = {1, 3, 5, 7}; - iseed[0] = (i + 100) % 4095; - blas_int n = dim_n; - lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + fill_rand(fac_f.memptr(), dim_n, i+1); + // f <- f - V * V' * f, so that f is orthogonal to V Mat Vs(fac_V.memptr(), dim_n, i, false); // First i columns Col Vf = Vs.t() * fac_f; @@ -151,10 +175,11 @@ GenEigsSolver::restart(uword k) fac_H.diag() += ritz_val(i).real(); } } + // V -> VQ // Q has some elements being zero // The first (ncv - k + i) elements of the i-th column of Q are non-zero - Mat Vs(dim_n, k + 1); + Mat Vs(dim_n, k + 1, arma_nozeros_indicator()); uword nnz; for(uword i = 0; i < k; i++) { @@ -186,7 +211,7 @@ GenEigsSolver::num_converged(eT tol) const eT f_norm = arma::norm(fac_f); for(uword i = 0; i < nev; i++) { - eT thresh = tol * std::max(approx0, std::abs(ritz_val(i))); + eT thresh = tol * (std::max)(approx0, std::abs(ritz_val(i))); eT resid = std::abs(ritz_est(i)) * f_norm; ritz_conv[i] = (resid < thresh); } @@ -210,7 +235,7 @@ GenEigsSolver::nev_adjusted(uword nconv) if(std::abs(ritz_est(i)) < eps) { nev_new++; } } // Adjust nev_new again, according to dnaup2.f line 660~674 in ARPACK - nev_new += std::min(nconv, (ncv - nev_new) / 2); + nev_new += (std::min)(nconv, (ncv - nev_new) / 2); if(nev_new == 1 && ncv >= 6) { nev_new = ncv / 2; @@ -278,8 +303,8 @@ GenEigsSolver::sort_ritzpair() std::vector ind = sorting.index(); - Col< std::complex > new_ritz_val(ncv); - Mat< std::complex > new_ritz_vec(ncv, nev); + Col< std::complex > new_ritz_val(ncv, arma_zeros_indicator() ); + Mat< std::complex > new_ritz_vec(ncv, nev, arma_nozeros_indicator()); std::vector new_ritz_conv(nev); for(uword i = 0; i < nev; i++) @@ -342,7 +367,7 @@ GenEigsSolver::init(eT* init_resid) arma_check( (rnorm < eps), "newarp::GenEigsSolver::init(): initial residual vector cannot be zero" ); v = r / rnorm; - Col w(dim_n); + Col w(dim_n, arma_zeros_indicator()); op.perform_op(v.memptr(), w.memptr()); nmatop++; @@ -359,11 +384,17 @@ GenEigsSolver::init() { arma_extra_debug_sigprint(); + // podarray init_resid(dim_n); + // blas_int idist = 2; // Uniform(-1, 1) + // blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + // init(init_resid.memptr()); + podarray init_resid(dim_n); - blas_int idist = 2; // Uniform(-1, 1) - blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed - blas_int n = dim_n; - lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + + fill_rand(init_resid.memptr(), dim_n, 0); + init(init_resid.memptr()); } @@ -394,7 +425,7 @@ GenEigsSolver::compute(uword maxit, eT tol) niter = i + 1; - return std::min(nev, nconv); + return (std::min)(nev, nconv); } @@ -407,7 +438,7 @@ GenEigsSolver::eigenvalues() arma_extra_debug_sigprint(); uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); - Col< std::complex > res(nconv); + Col< std::complex > res(nconv, arma_zeros_indicator()); if(nconv > 0) { @@ -435,12 +466,12 @@ GenEigsSolver::eigenvectors(uword nvec) arma_extra_debug_sigprint(); uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); - nvec = std::min(nvec, nconv); + nvec = (std::min)(nvec, nconv); Mat< std::complex > res(dim_n, nvec); if(nvec > 0) { - Mat< std::complex > ritz_vec_conv(ncv, nvec); + Mat< std::complex > ritz_vec_conv(ncv, nvec, arma_zeros_indicator()); uword j = 0; for(uword i = 0; (i < nev) && (j < nvec); i++) { diff --git a/src/armadillo_bits/newarp_SortEigenvalue.hpp b/src/armadillo_bits/newarp_SortEigenvalue.hpp index 4a9720d7..5f2c357e 100644 --- a/src/armadillo_bits/newarp_SortEigenvalue.hpp +++ b/src/armadillo_bits/newarp_SortEigenvalue.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_SparseGenMatProd_bones.hpp b/src/armadillo_bits/newarp_SparseGenMatProd_bones.hpp index 431a1ebe..2028aee2 100644 --- a/src/armadillo_bits/newarp_SparseGenMatProd_bones.hpp +++ b/src/armadillo_bits/newarp_SparseGenMatProd_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,6 +27,7 @@ class SparseGenMatProd private: const SpMat& op_mat; + SpMat op_mat_st; public: diff --git a/src/armadillo_bits/newarp_SparseGenMatProd_meat.hpp b/src/armadillo_bits/newarp_SparseGenMatProd_meat.hpp index e7e57e29..bbe539a1 100644 --- a/src/armadillo_bits/newarp_SparseGenMatProd_meat.hpp +++ b/src/armadillo_bits/newarp_SparseGenMatProd_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,6 +28,8 @@ SparseGenMatProd::SparseGenMatProd(const SpMat& mat_obj) , n_cols(mat_obj.n_cols) { arma_extra_debug_sigprint(); + + op_mat_st = op_mat.st(); // pre-calculate transpose } @@ -39,10 +43,20 @@ SparseGenMatProd::perform_op(eT* x_in, eT* y_out) const { arma_extra_debug_sigprint(); - const Col x(x_in , n_cols, false, true); - Col y(y_out, n_rows, false, true); + // // OLD METHOD + // + // const Col x(x_in , n_cols, false, true); + // Col y(y_out, n_rows, false, true); + // + // y = op_mat * x; + + + // NEW METHOD + + const Row x(x_in , n_cols, false, true); + Row y(y_out, n_rows, false, true); - y = op_mat * x; + y = x * op_mat_st; } diff --git a/src/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp b/src/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp new file mode 100644 index 00000000..a47575dd --- /dev/null +++ b/src/armadillo_bits/newarp_SparseGenRealShiftSolve_bones.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! Define matrix operations on existing matrix objects +template +class SparseGenRealShiftSolve + { + private: + + #if defined(ARMA_USE_SUPERLU) + // The following objects are read-only in perform_op() + mutable superlu_supermatrix_wrangler l; + mutable superlu_supermatrix_wrangler u; + mutable superlu_array_wrangler perm_c; + mutable superlu_array_wrangler perm_r; + #endif + + + public: + + bool valid = false; + + const uword n_rows; // number of rows of the underlying matrix + const uword n_cols; // number of columns of the underlying matrix + + inline SparseGenRealShiftSolve(const SpMat& mat_obj, const eT shift); + + inline void perform_op(eT* x_in, eT* y_out) const; + }; + + +} // namespace newarp diff --git a/src/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp b/src/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp new file mode 100644 index 00000000..ea206186 --- /dev/null +++ b/src/armadillo_bits/newarp_SparseGenRealShiftSolve_meat.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +SparseGenRealShiftSolve::SparseGenRealShiftSolve(const SpMat& mat_obj, const eT shift) + #if defined(ARMA_USE_SUPERLU) + : perm_c(mat_obj.n_cols + 1) + , perm_r(mat_obj.n_rows + 1) + , n_rows(mat_obj.n_rows) + , n_cols(mat_obj.n_cols) + #else + : n_rows(0) + , n_cols(0) + #endif + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + // Derived from sp_auxlib::run_aupd_shiftinvert() + superlu_opts superlu_opts_default; + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, superlu_opts_default); + + superlu::GlobalLU_t Glu; + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler xC; + superlu_array_wrangler etree(mat_obj.n_cols+1); + + // Copy A-shift*I to x + const bool status_x = sp_auxlib::copy_to_supermatrix_with_shift(x.get_ref(), mat_obj, shift); + + if(status_x == false) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::SparseGenRealShiftSolve(): could not construct SuperLU matrix"); return; } + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int slu_info = 0; // Return code + int lwork = 0; // lwork = 0: allocate space internally by system malloc + + superlu_stat_wrangler stat; + + arma_extra_debug_print("superlu::gstrf()"); + superlu::get_permutation_c(options.ColPerm, x.get_ptr(), perm_c.get_ptr()); + superlu::sp_preorder_mat(&options, x.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), xC.get_ptr()); + superlu::gstrf(&options, xC.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), &Glu, stat.get_ptr(), &slu_info); + + if(slu_info != 0) + { + arma_debug_warn_level(2, "matrix is singular to working precision"); + return; + } + + eT x_norm_val = sp_auxlib::norm1(x.get_ptr()); + eT x_rcond = sp_auxlib::lu_rcond(l.get_ptr(), u.get_ptr(), x_norm_val); + + if( (x_rcond < std::numeric_limits::epsilon()) || arma_isnan(x_rcond) ) + { + if(x_rcond == eT(0)) { arma_debug_warn_level(2, "matrix is singular to working precision"); } + else { arma_debug_warn_level(2, "matrix is singular to working precision (rcond: ", x_rcond, ")"); } + return; + } + + valid = true; + } + #else + { + arma_ignore(mat_obj); + arma_ignore(shift); + } + #endif + } + + + +// Perform the shift-solve operation \f$y=(A-\sigma I)^{-1}x\f$. +// y_out = inv(A - sigma * I) * x_in +template +inline +void +SparseGenRealShiftSolve::perform_op(eT* x_in, eT* y_out) const + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + const Col x(x_in , n_cols, false, true); + Col y(y_out, n_rows, false, true); + + // Derived from sp_auxlib::run_aupd_shiftinvert() + y = x; + superlu_supermatrix_wrangler out_slu; + + const bool status_out_slu = sp_auxlib::wrap_to_supermatrix(out_slu.get_ref(), y); + + if(status_out_slu == false) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::perform_op(): could not construct SuperLU matrix"); return; } + + superlu_stat_wrangler stat; + int info = 0; + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(superlu::NOTRANS, l.get_ptr(), u.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), out_slu.get_ptr(), stat.get_ptr(), &info); + + if(info != 0) { arma_stop_runtime_error("newarp::SparseGenRealShiftSolve::perform_op(): could not solve linear equation"); return; } + + // No need to modify memory further since it was all done in-place. + } + #else + { + arma_ignore(x_in); + arma_ignore(y_out); + } + #endif + } + + +} // namespace newarp diff --git a/src/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp b/src/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp new file mode 100644 index 00000000..bf3231f6 --- /dev/null +++ b/src/armadillo_bits/newarp_SymEigsShiftSolver_bones.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +//! This class implements the eigen solver for real symmetric matrices in the shift-and-invert mode. +template +class SymEigsShiftSolver : public SymEigsSolver + { + private: + + const eT sigma; + + // Sort the first nev Ritz pairs in ascending algebraic order + // This is used to return the final results + void sort_ritzpair(); + + + public: + + //! Constructor to create a solver object. + inline SymEigsShiftSolver(const OpType& op_, uword nev_, uword ncv_, const eT sigma_); + }; + + +} // namespace newarp diff --git a/src/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp b/src/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp new file mode 100644 index 00000000..bfb29132 --- /dev/null +++ b/src/armadillo_bits/newarp_SymEigsShiftSolver_meat.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +namespace newarp +{ + + +template +inline +void +SymEigsShiftSolver::sort_ritzpair() + { + arma_extra_debug_sigprint(); + + // First transform back the Ritz values, and then sort + for(uword i = 0; i < this->nev; i++) + { + this->ritz_val(i) = eT(1.0) / this->ritz_val(i) + sigma; + } + SymEigsSolver::sort_ritzpair(); + } + + + +template +inline +SymEigsShiftSolver::SymEigsShiftSolver(const OpType& op_, uword nev_, uword ncv_, const eT sigma_) + : SymEigsSolver::SymEigsSolver(op_, nev_, ncv_) + , sigma(sigma_) + { + arma_extra_debug_sigprint(); + } + + +} // namespace newarp diff --git a/src/armadillo_bits/newarp_SymEigsSolver_bones.hpp b/src/armadillo_bits/newarp_SymEigsSolver_bones.hpp index b2ecf34d..612f92a6 100644 --- a/src/armadillo_bits/newarp_SymEigsSolver_bones.hpp +++ b/src/armadillo_bits/newarp_SymEigsSolver_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,18 +25,18 @@ template class SymEigsSolver { protected: - - const OpType& op; // object to conduct matrix operation, e.g. matrix-vector product + + const OpType& op; // object to conduct matrix operation, eg. matrix-vector product const uword nev; // number of eigenvalues requested Col ritz_val; // ritz values - - // Sort the first nev Ritz pairs in decreasing magnitude order + + // Sort the first nev Ritz pairs in ascending algebraic order // This is used to return the final results virtual void sort_ritzpair(); - - + + private: - + const uword dim_n; // dimension of matrix A const uword ncv; // number of ritz values uword nmatop; // number of matrix operations called @@ -46,54 +48,57 @@ class SymEigsSolver Col ritz_est; // last row of ritz_vec std::vector ritz_conv; // indicator of the convergence of ritz values const eT eps; // the machine precision - // e.g. ~= 1e-16 for double type - const eT approx0; // a number that is approximately zero - // approx0 = eps^(2/3) - // used to test the orthogonality of vectors, - // and in convergence test, tol*approx0 is - // the absolute tolerance - + // eg. ~= 1e-16 for double type + const eT eps23; // eps^(2/3), used in convergence test + // tol*eps23 is the absolute tolerance + const eT near0; // a very small value, but 1/near0 does not overflow + + std::mt19937_64 local_rng; // local random number generator + + inline void fill_rand(eT* dest, const uword N, const uword seed_val); + // Arnoldi factorisation starting from step-k inline void factorise_from(uword from_k, uword to_m, const Col& fk); - + // Implicitly restarted Arnoldi factorisation inline void restart(uword k); - + // Calculate the number of converged Ritz values inline uword num_converged(eT tol); - + // Return the adjusted nev for restarting inline uword nev_adjusted(uword nconv); - + // Retrieve and sort ritz values and ritz vectors inline void retrieve_ritzpair(); - - + + public: - + //! Constructor to create a solver object. inline SymEigsSolver(const OpType& op_, uword nev_, uword ncv_); - + //! Providing the initial residual vector for the algorithm. inline void init(eT* init_resid); - + //! Providing a random initial residual vector. inline void init(); - + //! Conducting the major computation procedure. inline uword compute(uword maxit = 1000, eT tol = 1e-10); - + //! Returning the number of iterations used in the computation. inline uword num_iterations() { return niter; } - + //! Returning the number of matrix operations used in the computation. inline uword num_operations() { return nmatop; } - + //! Returning the converged eigenvalues. inline Col eigenvalues(); - + //! Returning the eigenvectors associated with the converged eigenvalues. inline Mat eigenvectors(uword nvec); + //! Returning all converged eigenvectors. inline Mat eigenvectors() { return eigenvectors(nev); } }; diff --git a/src/armadillo_bits/newarp_SymEigsSolver_meat.hpp b/src/armadillo_bits/newarp_SymEigsSolver_meat.hpp index 09a1ddc7..2223328c 100644 --- a/src/armadillo_bits/newarp_SymEigsSolver_meat.hpp +++ b/src/armadillo_bits/newarp_SymEigsSolver_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,6 +20,24 @@ namespace newarp { +template +inline +void +SymEigsSolver::fill_rand(eT* dest, const uword N, const uword seed_val) + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type seed_type; + + local_rng.seed( seed_type(seed_val) ); + + std::uniform_real_distribution dist(-1.0, +1.0); + + for(uword i=0; i < N; ++i) { dest[i] = eT(dist(local_rng)); } + } + + + template inline void @@ -29,8 +49,11 @@ SymEigsSolver::factorise_from(uword from_k, uword to_ fac_f = fk; - Col w(dim_n); - eT beta = norm(fac_f), Hii = 0.0; + Col w(dim_n, arma_zeros_indicator()); + // Norm of f + eT beta = norm(fac_f); + // Used to test beta~=0 + const eT beta_thresh = eps * eop_aux::sqrt(dim_n); // Keep the upperleft k x k submatrix of H and set other elements to 0 fac_H.tail_cols(ncv - from_k).zeros(); fac_H.submat(span(from_k, ncv - 1), span(0, from_k - 1)).zeros(); @@ -40,14 +63,18 @@ SymEigsSolver::factorise_from(uword from_k, uword to_ // If beta = 0, then the next V is not full rank // We need to generate a new residual vector that is orthogonal // to the current V, which we call a restart - if(beta < eps) + if(beta < near0) { + // // Generate new random vector for fac_f + // blas_int idist = 2; + // blas_int iseed[4] = {1, 3, 5, 7}; + // iseed[0] = (i + 100) % 4095; + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + // Generate new random vector for fac_f - blas_int idist = 2; - blas_int iseed[4] = {1, 3, 5, 7}; - iseed[0] = (i + 100) % 4095; - blas_int n = dim_n; - lapack::larnv(&idist, &iseed[0], &n, fac_f.memptr()); + fill_rand(fac_f.memptr(), dim_n, i+1); + // f <- f - V * V' * f, so that f is orthogonal to V Mat Vs(fac_V.memptr(), dim_n, i, false); // First i columns Col Vf = Vs.t() * fac_f; @@ -63,14 +90,14 @@ SymEigsSolver::factorise_from(uword from_k, uword to_ v = fac_f / beta; // Note that H[i+1, i] equals to the unrestarted beta - if(restart) { fac_H(i, i - 1) = 0.0; } else { fac_H(i, i - 1) = beta; } + fac_H(i, i - 1) = restart ? eT(0) : beta; // w <- A * v, v = fac_V.col(i) op.perform_op(v.memptr(), w.memptr()); nmatop++; - Hii = dot(v, w); fac_H(i - 1, i) = fac_H(i, i - 1); // Due to symmetry + eT Hii = dot(v, w); fac_H(i, i) = Hii; // f <- w - V * V' * w = w - H[i+1, i] * V{i} - H[i+1, i+1] * V{i+1} @@ -90,10 +117,23 @@ SymEigsSolver::factorise_from(uword from_k, uword to_ // whether V' * (f/||f||) ~= 0 Mat Vs(fac_V.memptr(), dim_n, i + 1, false); // First i+1 columns Col Vf = Vs.t() * fac_f; + eT ortho_err = abs(Vf).max(); // If not, iteratively correct the residual uword count = 0; - while(count < 5 && abs(Vf).max() > approx0 * beta) + while(count < 5 && ortho_err > eps * beta) { + // There is an edge case: when beta=||f|| is close to zero, f mostly consists + // of rounding errors, so the test [ortho_err < eps * beta] is very + // likely to fail. In particular, if beta=0, then the test is ensured to fail. + // Hence when this happens, we force f to be zero, and then restart in the + // next iteration. + if(beta < beta_thresh) + { + fac_f.zeros(); + beta = eT(0); + break; + } + // f <- f - V * Vf fac_f -= Vs * Vf; // h <- h + Vf @@ -104,6 +144,7 @@ SymEigsSolver::factorise_from(uword from_k, uword to_ beta = norm(fac_f); Vf = Vs.t() * fac_f; + ortho_err = abs(Vf).max(); count++; } } @@ -121,7 +162,7 @@ SymEigsSolver::restart(uword k) if(k >= ncv) { return; } TridiagQR decomp; - Mat Q = eye< Mat >(ncv, ncv); + Mat Q(ncv, ncv, fill::eye); for(uword i = k; i < ncv; i++) { @@ -142,22 +183,27 @@ SymEigsSolver::restart(uword k) // V -> VQ, only need to update the first k+1 columns // Q has some elements being zero // The first (ncv - k + i) elements of the i-th column of Q are non-zero - Mat Vs(dim_n, k + 1); + Mat Vs(dim_n, k + 1, arma_nozeros_indicator()); uword nnz; for(uword i = 0; i < k; i++) { nnz = ncv - k + i + 1; Mat V(fac_V.memptr(), dim_n, nnz, false); Col q(Q.colptr(i), nnz, false); - Vs.col(i) = V * q; + // OLD CODE: + // Vs.col(i) = V * q; + // NEW CODE: + Col v(Vs.colptr(i), dim_n, false, true); + v = V * q; } + Vs.col(k) = fac_V * Q.col(k); fac_V.head_cols(k + 1) = Vs; Col fk = fac_f * Q(ncv - 1, k - 1) + fac_V.col(k) * fac_H(k, k - 1); factorise_from(k, ncv, fk); retrieve_ritzpair(); -} + } @@ -172,7 +218,7 @@ SymEigsSolver::num_converged(eT tol) const eT f_norm = norm(fac_f); for(uword i = 0; i < nev; i++) { - eT thresh = tol * std::max(approx0, std::abs(ritz_val(i))); + eT thresh = tol * (std::max)(eps23, std::abs(ritz_val(i))); eT resid = std::abs(ritz_est(i)) * f_norm; ritz_conv[i] = (resid < thresh); } @@ -192,20 +238,18 @@ SymEigsSolver::nev_adjusted(uword nconv) uword nev_new = nev; for(uword i = nev; i < ncv; i++) { - if(std::abs(ritz_est(i)) < eps) { nev_new++; } + if(std::abs(ritz_est(i)) < near0) { nev_new++; } } // Adjust nev_new, according to dsaup2.f line 677~684 in ARPACK - nev_new += std::min(nconv, (ncv - nev_new) / 2); + nev_new += (std::min)(nconv, (ncv - nev_new) / 2); + if(nev_new >= ncv) { nev_new = ncv - 1; } - if(nev_new == 1 && ncv >= 6) - { - nev_new = ncv / 2; - } - else - if(nev_new == 1 && ncv > 2) + + if(nev_new == 1) { - nev_new = 2; + if(ncv >= 6) { nev_new = ncv / 2; } + else if(ncv > 2) { nev_new = 2; } } return nev_new; @@ -242,7 +286,8 @@ SymEigsSolver::retrieve_ritzpair() { // If i is even, pick values from the left (large values) // If i is odd, pick values from the right (small values) - if(i % 2 == 0) { ind[i] = ind_copy[i / 2]; } else { ind[i] = ind_copy[ncv - 1 - i / 2]; } + + ind[i] = (i % 2 == 0) ? ind_copy[i / 2] : ind_copy[ncv - 1 - i / 2]; } } @@ -269,13 +314,13 @@ SymEigsSolver::sort_ritzpair() // SortEigenvalue sorting(ritz_val.memptr(), nev); - // sort Ritz values in ascending algebraic, to be consistent with ARPACK + // Sort Ritz values in ascending algebraic, to be consistent with ARPACK SortEigenvalue sorting(ritz_val.memptr(), nev); std::vector ind = sorting.index(); - Col new_ritz_val(ncv); - Mat new_ritz_vec(ncv, nev); + Col new_ritz_val(ncv, arma_zeros_indicator() ); + Mat new_ritz_vec(ncv, nev, arma_nozeros_indicator()); std::vector new_ritz_conv(nev); for(uword i = 0; i < nev; i++) @@ -302,7 +347,8 @@ SymEigsSolver::SymEigsSolver(const OpType& op_, uword , nmatop(0) , niter(0) , eps(std::numeric_limits::epsilon()) - , approx0(std::pow(eps, eT(2.0) / 3)) + , eps23(std::pow(eps, eT(2.0) / 3)) + , near0(std::numeric_limits::min() * eT(10)) { arma_extra_debug_sigprint(); @@ -335,15 +381,19 @@ SymEigsSolver::init(eT* init_resid) // The first column of fac_V Col v(fac_V.colptr(0), dim_n, false); eT rnorm = norm(r); - arma_check( (rnorm < eps), "newarp::SymEigsSolver::init(): initial residual vector cannot be zero" ); + arma_check( (rnorm < near0), "newarp::SymEigsSolver::init(): initial residual vector cannot be zero" ); v = r / rnorm; - Col w(dim_n); + Col w(dim_n, arma_zeros_indicator()); op.perform_op(v.memptr(), w.memptr()); nmatop++; fac_H(0, 0) = dot(v, w); fac_f = w - v * fac_H(0, 0); + + // In some cases f is zero in exact arithmetics, but due to rounding errors + // it may contain tiny fluctuations. When this happens, we force f to be zero + if(abs(fac_f).max() < eps) { fac_f.zeros(); } } @@ -355,11 +405,17 @@ SymEigsSolver::init() { arma_extra_debug_sigprint(); + // podarray init_resid(dim_n); + // blas_int idist = 2; // Uniform(-1, 1) + // blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed + // blas_int n = dim_n; + // lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + // init(init_resid.memptr()); + podarray init_resid(dim_n); - blas_int idist = 2; // Uniform(-1, 1) - blas_int iseed[4] = {1, 3, 5, 7}; // Fixed random seed - blas_int n = dim_n; - lapack::larnv(&idist, &iseed[0], &n, init_resid.memptr()); + + fill_rand(init_resid.memptr(), dim_n, 0); + init(init_resid.memptr()); } @@ -390,7 +446,7 @@ SymEigsSolver::compute(uword maxit, eT tol) niter = i + 1; - return std::min(nev, nconv); + return (std::min)(nev, nconv); } @@ -403,18 +459,15 @@ SymEigsSolver::eigenvalues() arma_extra_debug_sigprint(); uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); - Col res(nconv); + Col res(nconv, arma_zeros_indicator()); if(nconv > 0) { uword j = 0; - for(uword i = 0; i < nev; i++) + + for(uword i=0; i < nev; i++) { - if(ritz_conv[i]) - { - res(j) = ritz_val(i); - j++; - } + if(ritz_conv[i]) { res(j) = ritz_val(i); j++; } } } @@ -431,21 +484,18 @@ SymEigsSolver::eigenvectors(uword nvec) arma_extra_debug_sigprint(); uword nconv = std::count(ritz_conv.begin(), ritz_conv.end(), true); - nvec = std::min(nvec, nconv); + nvec = (std::min)(nvec, nconv); Mat res(dim_n, nvec); if(nvec > 0) { - Mat ritz_vec_conv(ncv, nvec); + Mat ritz_vec_conv(ncv, nvec, arma_zeros_indicator()); uword j = 0; - for(uword i = 0; i < nev && j < nvec; i++) + + for(uword i=0; i < nev && j < nvec; i++) { - if(ritz_conv[i]) - { - ritz_vec_conv.col(j) = ritz_vec.col(i); - j++; - } + if(ritz_conv[i]) { ritz_vec_conv.col(j) = ritz_vec.col(i); j++; } } res = fac_V * ritz_vec_conv; diff --git a/src/armadillo_bits/newarp_TridiagEigen_bones.hpp b/src/armadillo_bits/newarp_TridiagEigen_bones.hpp index aa648bba..9664a3c5 100644 --- a/src/armadillo_bits/newarp_TridiagEigen_bones.hpp +++ b/src/armadillo_bits/newarp_TridiagEigen_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_TridiagEigen_meat.hpp b/src/armadillo_bits/newarp_TridiagEigen_meat.hpp index 0073fdc7..b11cfec4 100644 --- a/src/armadillo_bits/newarp_TridiagEigen_meat.hpp +++ b/src/armadillo_bits/newarp_TridiagEigen_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -59,37 +61,40 @@ TridiagEigen::compute(const Mat& mat_obj) evecs.set_size(n, n); char compz = 'I'; - blas_int lwork = blas_int(-1); - eT lwork_opt = eT(0); - - blas_int liwork = blas_int(-1); - blas_int liwork_opt = blas_int(0); + blas_int lwork_min = 1 + 4*n + n*n; + blas_int liwork_min = 3 + 5*n; blas_int info = blas_int(0); - // query for lwork and liwork - lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, &lwork_opt, &lwork, &liwork_opt, &liwork, &info); + blas_int lwork_proposed = 0; + blas_int liwork_proposed = 0; - if(info == 0) + if(n >= 32) { - lwork = blas_int(lwork_opt); - liwork = liwork_opt; - } - else - { - lwork = 1 + 4 * n + n * n; - liwork = 3 + 5 * n; + eT work_query[2] = {}; + blas_int lwork_query = blas_int(-1); + + blas_int iwork_query[2] = {}; + blas_int liwork_query = blas_int(-1); + + arma_extra_debug_print("lapack::stedc()"); + lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, &work_query[0], &lwork_query, &iwork_query[0], &liwork_query, &info); + + if(info != 0) { arma_stop_runtime_error("lapack::stedc(): couldn't get size of work arrays"); return; } + + lwork_proposed = static_cast( work_query[0] ); + liwork_proposed = iwork_query[0]; } - info = blas_int(0); + blas_int lwork = (std::max)( lwork_min, lwork_proposed); + blas_int liwork = (std::max)(liwork_min, liwork_proposed); - podarray work(static_cast(lwork) ); - podarray iwork(static_cast(liwork)); + podarray work( static_cast( lwork) ); + podarray iwork( static_cast(liwork) ); + arma_extra_debug_print("lapack::stedc()"); lapack::stedc(&compz, &n, main_diag.memptr(), sub_diag.memptr(), evecs.memptr(), &n, work.memptr(), &lwork, iwork.memptr(), &liwork, &info); - if(info < 0) { arma_stop_logic_error("lapack::stedc(): illegal value"); return; } - - if(info > 0) { arma_stop_runtime_error("lapack::stedc(): failed to compute all eigenvalues"); return; } + if(info != 0) { arma_stop_runtime_error("lapack::stedc(): failed to compute all eigenvalues"); return; } computed = true; } diff --git a/src/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp b/src/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp index b5569a67..668adbe5 100644 --- a/src/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp +++ b/src/armadillo_bits/newarp_UpperHessenbergEigen_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,36 +21,36 @@ namespace newarp //! Calculate the eigenvalues and eigenvectors of an upper Hessenberg matrix. -//! This class is a wrapper of the Lapack functions `_lahqr` and `_trevc`. +//! This class is uses lapack::lahqr() and lapack::trevc() template class UpperHessenbergEigen { private: - - blas_int n; + + uword n_rows; Mat mat_Z; // In the first stage, H = ZTZ', Z is an orthogonal matrix // In the second stage, Z will be overwritten by the eigenvectors of H Mat mat_T; // H = ZTZ', T is a Schur form matrix Col< std::complex > evals; // eigenvalues of H bool computed; - - + + public: - + //! Default constructor. Computation can //! be performed later by calling the compute() method. inline UpperHessenbergEigen(); - + //! Constructor to create an object that calculates the eigenvalues //! and eigenvectors of an upper Hessenberg matrix `mat_obj`. inline UpperHessenbergEigen(const Mat& mat_obj); - + //! Compute the eigenvalue decomposition of an upper Hessenberg matrix. inline void compute(const Mat& mat_obj); - + //! Retrieve the eigenvalues. inline Col< std::complex > eigenvalues(); - + //! Retrieve the eigenvectors. inline Mat< std::complex > eigenvectors(); }; diff --git a/src/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp b/src/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp index 15c8d294..a7582059 100644 --- a/src/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp +++ b/src/armadillo_bits/newarp_UpperHessenbergEigen_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,7 @@ namespace newarp template inline UpperHessenbergEigen::UpperHessenbergEigen() - : n(0) + : n_rows(0) , computed(false) { arma_extra_debug_sigprint(); @@ -32,7 +34,7 @@ UpperHessenbergEigen::UpperHessenbergEigen() template inline UpperHessenbergEigen::UpperHessenbergEigen(const Mat& mat_obj) - : n(mat_obj.n_rows) + : n_rows(mat_obj.n_rows) , computed(false) { arma_extra_debug_sigprint(); @@ -51,11 +53,11 @@ UpperHessenbergEigen::compute(const Mat& mat_obj) arma_debug_check( (mat_obj.is_square() == false), "newarp::UpperHessenbergEigen::compute(): matrix must be square" ); - n = blas_int(mat_obj.n_rows); + n_rows = mat_obj.n_rows; - mat_Z.set_size(n, n); - mat_T.set_size(n, n); - evals.set_size(n); + mat_Z.set_size(n_rows, n_rows); + mat_T.set_size(n_rows, n_rows); + evals.set_size(n_rows); mat_Z.eye(); mat_T = mat_obj; @@ -63,35 +65,34 @@ UpperHessenbergEigen::compute(const Mat& mat_obj) blas_int want_T = blas_int(1); blas_int want_Z = blas_int(1); + blas_int n = blas_int(n_rows); blas_int ilo = blas_int(1); - blas_int ihi = blas_int(n); + blas_int ihi = blas_int(n_rows); blas_int iloz = blas_int(1); - blas_int ihiz = blas_int(n); + blas_int ihiz = blas_int(n_rows); blas_int info = blas_int(0); - podarray wr(static_cast(n)); - podarray wi(static_cast(n)); - + podarray wr(n_rows); + podarray wi(n_rows); + arma_extra_debug_print("lapack::lahqr()"); lapack::lahqr(&want_T, &want_Z, &n, &ilo, &ihi, mat_T.memptr(), &n, wr.memptr(), wi.memptr(), &iloz, &ihiz, mat_Z.memptr(), &n, &info); - for(blas_int i = 0; i < n; i++) - { - evals(i) = std::complex(wr[i], wi[i]); - } + if(info != 0) { arma_stop_runtime_error("lapack::lahqr(): failed to compute all eigenvalues"); return; } - if(info > 0) { arma_stop_runtime_error("lapack::lahqr(): failed to compute all eigenvalues"); return; } + for(uword i=0; i < n_rows; i++) { evals(i) = std::complex(wr[i], wi[i]); } char side = 'R'; char howmny = 'B'; blas_int m = blas_int(0); - podarray work(static_cast(3 * n)); + podarray work(3*n); + arma_extra_debug_print("lapack::trevc()"); lapack::trevc(&side, &howmny, (blas_int*) NULL, &n, mat_T.memptr(), &n, (eT*) NULL, &n, mat_Z.memptr(), &n, &n, &m, work.memptr(), &info); - if(info < 0) { arma_stop_logic_error("lapack::trevc(): illegal value"); return; } + if(info != 0) { arma_stop_runtime_error("lapack::trevc(): illegal value"); return; } computed = true; } @@ -106,7 +107,7 @@ UpperHessenbergEigen::eigenvalues() arma_extra_debug_sigprint(); arma_debug_check( (computed == false), "newarp::UpperHessenbergEigen::eigenvalues(): need to call compute() first" ); - + return evals; } @@ -120,46 +121,46 @@ UpperHessenbergEigen::eigenvectors() arma_extra_debug_sigprint(); arma_debug_check( (computed == false), "newarp::UpperHessenbergEigen::eigenvectors(): need to call compute() first" ); - + // Lapack will set the imaginary parts of real eigenvalues to be exact zero - Mat< std::complex > evecs(n, n); + Mat< std::complex > evecs(n_rows, n_rows, arma_zeros_indicator()); std::complex* col_ptr = evecs.memptr(); - for(blas_int i = 0; i < n; i++) + for(uword i=0; i < n_rows; i++) { if(cx_attrib::is_real(evals(i), eT(0))) { // for real eigenvector, normalise and copy - eT z_norm = norm(mat_Z.col(i)); + const eT z_norm = norm(mat_Z.col(i)); - for(blas_int j = 0; j < n; j++) + for(uword j=0; j < n_rows; j++) { col_ptr[j] = std::complex(mat_Z(j, i) / z_norm, eT(0)); } - - col_ptr += n; + + col_ptr += n_rows; } else { // complex eigenvectors are stored in consecutive columns - eT r2 = dot(mat_Z.col(i), mat_Z.col(i)); - eT i2 = dot(mat_Z.col(i + 1), mat_Z.col(i + 1)); + const eT r2 = dot(mat_Z.col(i ), mat_Z.col(i )); + const eT i2 = dot(mat_Z.col(i+1), mat_Z.col(i+1)); - eT z_norm = std::sqrt(r2 + i2); - eT* z_ptr = mat_Z.colptr(i); + const eT z_norm = std::sqrt(r2 + i2); + const eT* z_ptr = mat_Z.colptr(i); - for(blas_int j = 0; j < n; j++) + for(uword j=0; j < n_rows; j++) { - col_ptr[j ] = std::complex(z_ptr[j] / z_norm, z_ptr[j + n] / z_norm); - col_ptr[j + n] = std::conj(col_ptr[j]); + col_ptr[j ] = std::complex(z_ptr[j] / z_norm, z_ptr[j + n_rows] / z_norm); + col_ptr[j + n_rows] = std::conj(col_ptr[j]); } - + i++; - col_ptr += 2 * n; + col_ptr += 2 * n_rows; } } - + return evecs; } diff --git a/src/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp b/src/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp index 02c646c0..4d07f8c0 100644 --- a/src/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp +++ b/src/armadillo_bits/newarp_UpperHessenbergQR_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp b/src/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp index 3d4cda8b..c3a6fa8c 100644 --- a/src/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp +++ b/src/armadillo_bits/newarp_UpperHessenbergQR_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -273,7 +275,7 @@ TridiagQR::matrix_RQ() arma_debug_check( (this->computed == false), "newarp::TridiagQR::matrix_RQ(): need to call compute() first" ); // Make a copy of the R matrix - Mat RQ(this->n, this->n, fill::zeros); + Mat RQ(this->n, this->n, arma_zeros_indicator()); RQ.diag() = this->mat_T.diag(); RQ.diag(1) = this->mat_T.diag(1); diff --git a/src/armadillo_bits/newarp_cx_attrib.hpp b/src/armadillo_bits/newarp_cx_attrib.hpp index 64147295..e654dc47 100644 --- a/src/armadillo_bits/newarp_cx_attrib.hpp +++ b/src/armadillo_bits/newarp_cx_attrib.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_all_bones.hpp b/src/armadillo_bits/op_all_bones.hpp index 07015389..b8faf9a2 100644 --- a/src/armadillo_bits/op_all_bones.hpp +++ b/src/armadillo_bits/op_all_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -46,8 +48,8 @@ class op_all all_vec_helper ( const mtOp& X, - const typename arma_op_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0 + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr ); @@ -56,9 +58,9 @@ class op_all all_vec_helper ( const mtGlue& X, - const typename arma_glue_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0, - const typename arma_not_cx::result junk3 = 0 + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr ); diff --git a/src/armadillo_bits/op_all_meat.hpp b/src/armadillo_bits/op_all_meat.hpp index 02302b5e..5dff3eca 100644 --- a/src/armadillo_bits/op_all_meat.hpp +++ b/src/armadillo_bits/op_all_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -117,8 +119,8 @@ bool op_all::all_vec_helper ( const mtOp& X, - const typename arma_op_rel_only::result junk1, - const typename arma_not_cx::result junk2 + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 ) { arma_extra_debug_sigprint(); @@ -189,9 +191,9 @@ bool op_all::all_vec_helper ( const mtGlue& X, - const typename arma_glue_rel_only::result junk1, - const typename arma_not_cx::result junk2, - const typename arma_not_cx::result junk3 + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/op_any_bones.hpp b/src/armadillo_bits/op_any_bones.hpp index 3c893e4b..ffb197bd 100644 --- a/src/armadillo_bits/op_any_bones.hpp +++ b/src/armadillo_bits/op_any_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -46,8 +48,8 @@ class op_any any_vec_helper ( const mtOp& X, - const typename arma_op_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0 + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr ); @@ -56,9 +58,9 @@ class op_any any_vec_helper ( const mtGlue& X, - const typename arma_glue_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0, - const typename arma_not_cx::result junk3 = 0 + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr ); diff --git a/src/armadillo_bits/op_any_meat.hpp b/src/armadillo_bits/op_any_meat.hpp index 4a1a5b04..3356ec7c 100644 --- a/src/armadillo_bits/op_any_meat.hpp +++ b/src/armadillo_bits/op_any_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -112,8 +114,8 @@ bool op_any::any_vec_helper ( const mtOp& X, - const typename arma_op_rel_only::result junk1, - const typename arma_not_cx::result junk2 + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 ) { arma_extra_debug_sigprint(); @@ -183,9 +185,9 @@ bool op_any::any_vec_helper ( const mtGlue& X, - const typename arma_glue_rel_only::result junk1, - const typename arma_not_cx::result junk2, - const typename arma_not_cx::result junk3 + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 ) { arma_extra_debug_sigprint(); diff --git a/src/armadillo_bits/op_chi2rnd_bones.hpp b/src/armadillo_bits/op_chi2rnd_bones.hpp index fb873e96..540bdfbd 100644 --- a/src/armadillo_bits/op_chi2rnd_bones.hpp +++ b/src/armadillo_bits/op_chi2rnd_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,8 +37,6 @@ class op_chi2rnd -#if defined(ARMA_USE_CXX11) - template class op_chi2rnd_varying_df { @@ -50,7 +50,5 @@ class op_chi2rnd_varying_df inline eT operator()(const eT df); }; -#endif - //! @} diff --git a/src/armadillo_bits/op_chi2rnd_meat.hpp b/src/armadillo_bits/op_chi2rnd_meat.hpp index f7b884b6..1b681ae8 100644 --- a/src/armadillo_bits/op_chi2rnd_meat.hpp +++ b/src/armadillo_bits/op_chi2rnd_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -53,46 +55,36 @@ op_chi2rnd::apply_noalias(Mat& out, const Proxy& P) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) + typedef typename T1::elem_type eT; + + op_chi2rnd_varying_df generator; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) { - typedef typename T1::elem_type eT; - - op_chi2rnd_varying_df generator; + const uword N = P.get_n_elem(); - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + typename Proxy::ea_type Pea = P.get_ea(); - out.set_size(n_rows, n_cols); - - eT* out_mem = out.memptr(); - - if(Proxy::use_at == false) + for(uword i=0; i::ea_type Pea = P.get_ea(); - - for(uword i=0; i& out, const eT df) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) + if(df > eT(0)) { - if(df > eT(0)) - { - typedef std::mt19937_64 motor_type; - typedef std::mt19937_64::result_type seed_type; - typedef std::chi_squared_distribution distr_type; - - motor_type motor; motor.seed( seed_type(arma_rng::randi()) ); - distr_type distr(df); - - const uword N = out.n_elem; - - eT* out_mem = out.memptr(); - - for(uword i=0; i distr_type; + + motor_type motor; motor.seed( seed_type(arma_rng::randi()) ); + distr_type distr(df); + + const uword N = out.n_elem; + + eT* out_mem = out.memptr(); + + for(uword i=0; i::nan ); + out_mem[i] = eT( distr(motor) ); } } - #else + else { - out.reset(); - arma_ignore(df); - - arma_stop_logic_error("chi2rnd(): C++11 compiler required"); + out.fill( Datum::nan ); } - #endif } @@ -145,8 +126,6 @@ op_chi2rnd::fill_constant_df(Mat& out, const eT df) -#if defined(ARMA_USE_CXX11) - template inline op_chi2rnd_varying_df::~op_chi2rnd_varying_df() @@ -192,8 +171,6 @@ op_chi2rnd_varying_df::operator()(const eT df) } } -#endif - //! @} diff --git a/src/armadillo_bits/op_chol_bones.hpp b/src/armadillo_bits/op_chol_bones.hpp index b5bdfa4d..e3b3a9c0 100644 --- a/src/armadillo_bits/op_chol_bones.hpp +++ b/src/armadillo_bits/op_chol_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_chol_meat.hpp b/src/armadillo_bits/op_chol_meat.hpp index 0a49a30c..ebc6448b 100644 --- a/src/armadillo_bits/op_chol_meat.hpp +++ b/src/armadillo_bits/op_chol_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -48,26 +50,19 @@ op_chol::apply_direct(Mat& out, const Base::no ) { arma_debug_warn("chol(): given matrix is not symmetric"); } - // if(is_cx::yes) { arma_debug_warn("chol(): given matrix is not hermitian"); } - // return false; - // } - if((arma_config::debug) && (auxlib::rudimentary_sym_check(out) == false)) { - if(is_cx::no ) { arma_debug_warn("chol(): given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn("chol(): given matrix is not hermitian"); } + if(is_cx::no ) { arma_debug_warn_level(1, "chol(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "chol(): given matrix is not hermitian"); } } uword KD = 0; - const bool is_band = (auxlib::crippled_lapack(out)) ? false : ((layout == 0) ? band_helper::is_band_upper(KD, out, uword(32)) : band_helper::is_band_lower(KD, out, uword(32))); + const bool is_band = arma_config::optimise_band && ((auxlib::crippled_lapack(out)) ? false : ((layout == 0) ? band_helper::is_band_upper(KD, out, uword(32)) : band_helper::is_band_lower(KD, out, uword(32)))); const bool status = (is_band) ? auxlib::chol_band(out, KD, layout) : auxlib::chol(out, layout); diff --git a/src/armadillo_bits/op_clamp_bones.hpp b/src/armadillo_bits/op_clamp_bones.hpp index 6141dcc7..89e53424 100644 --- a/src/armadillo_bits/op_clamp_bones.hpp +++ b/src/armadillo_bits/op_clamp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,17 +31,42 @@ class op_clamp template inline static void apply(Mat& out, const mtOp& in); - template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); - template inline static void apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val); + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + // cubes template inline static void apply(Cube& out, const mtOpCube& in); + template inline static void apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val); + template inline static void apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + }; + + + +class op_clamp_cx + : public traits_op_passthru + { + public: + + // matrices + + template inline static void apply(Mat& out, const mtOp& in); + + template inline static void apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); + + + // cubes + + template inline static void apply(Cube& out, const mtOpCube& in); template inline static void apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val); + + template inline static void apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val); }; diff --git a/src/armadillo_bits/op_clamp_meat.hpp b/src/armadillo_bits/op_clamp_meat.hpp index f05189d8..a14cd788 100644 --- a/src/armadillo_bits/op_clamp_meat.hpp +++ b/src/armadillo_bits/op_clamp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,17 +29,68 @@ op_clamp::apply(Mat& out, const mtOp P(in.m); + typedef typename T1::elem_type eT; + + const eT min_val = in.aux; + const eT max_val = in.aux_out_eT; + + arma_debug_check( (min_val > max_val), "clamp(): min_val must be less than max_val" ); - if(is_Mat::stored_type>::value || P.is_alias(out)) + if(is_Mat::value) + { + const unwrap U(in.m); + + op_clamp::apply_direct(out, U.M, min_val, max_val); + } + else { - const unwrap::stored_type> U(P.Q); + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_clamp::apply_proxy_noalias(tmp, P, min_val, max_val); + + out.steal_mem(tmp); + } + else + { + op_clamp::apply_proxy_noalias(out, P, min_val, max_val); + } + } + } + + + +template +inline +void +op_clamp::apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(&out != &X) + { + out.set_size(X.n_rows, X.n_cols); + + const uword N = out.n_elem; - op_clamp::apply_direct(out, U.M, in.aux, in.aux_out_eT); + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val) ? max_val : val); + } } else { - op_clamp::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); + arma_extra_debug_print("op_clamp::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); } } @@ -65,40 +118,68 @@ op_clamp::apply_proxy_noalias(Mat& out, const Proxy& typename Proxy::ea_type A = P.get_ea(); - uword j; - for(j=1; j max_val) ? max_val : val_i); - val_j = (val_j < min_val) ? min_val : ((val_j > max_val) ? max_val : val_j); + const eT val = A[i]; - (*out_mem) = val_i; out_mem++; - (*out_mem) = val_j; out_mem++; + out_mem[i] = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); } - - const uword i = j-1; - - if(i < N) + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) { - eT val_i = A[i]; + const eT val = P.at(row,col); - val_i = (val_i < min_val) ? min_val : ((val_i > max_val) ? max_val : val_i); + (*out_mem) = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); - (*out_mem) = val_i; + out_mem++; } } + } + + + +// + + + +template +inline +void +op_clamp::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const eT min_val = in.aux; + const eT max_val = in.aux_out_eT; + + arma_debug_check( (min_val > max_val), "clamp(): min_val must be less than max_val" ); + + if(is_Cube::value) + { + const unwrap_cube U(in.m); + + op_clamp::apply_direct(out, U.M, min_val, max_val); + } else { - for(uword col=0; col P(in.m); + + if(P.is_alias(out)) { - eT val = P.at(row,col); + Cube tmp; - val = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + op_clamp::apply_proxy_noalias(tmp, P, min_val, max_val); - (*out_mem) = val; out_mem++; + out.steal_mem(tmp); + } + else + { + op_clamp::apply_proxy_noalias(out, P, min_val, max_val); } } } @@ -108,29 +189,77 @@ op_clamp::apply_proxy_noalias(Mat& out, const Proxy& template inline void -op_clamp::apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val) +op_clamp::apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val) { arma_extra_debug_sigprint(); if(&out != &X) { - const Proxy< Mat > P(X); + out.set_size(X.n_rows, X.n_cols, X.n_slices); + + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); - op_clamp::apply_proxy_noalias(out, P, min_val, max_val); + for(uword i=0; i max_val) ? max_val : val); + } } else { - arma_extra_debug_print("inplace operation"); + arma_extra_debug_print("op_clamp::apply_direct(): inplace operation"); - const uword N = out.n_elem; + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp::apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + const uword N = P.get_n_elem(); - eT* out_mem = out.memptr(); + typename ProxyCube::ea_type A = P.get_ea(); for(uword i=0; i max_val) ? max_val : out_val ); + out_mem[i] = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + } + } + else + { + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + const eT val = P.at(r,c,s); + + (*out_mem) = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + + out_mem++; } } } @@ -144,21 +273,84 @@ op_clamp::apply_direct(Mat& out, const Mat& X, const eT min_val, const e template inline void -op_clamp::apply(Cube& out, const mtOpCube& in) +op_clamp_cx::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Mat::value) + { + const unwrap U(in.m); + + op_clamp_cx::apply_direct(out, U.M, in.aux, in.aux_out_eT); + } + else + { + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_clamp_cx::apply_proxy_noalias(tmp, P, in.aux, in.aux_out_eT); + + out.steal_mem(tmp); + } + else + { + op_clamp_cx::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); + } + } + } + + + +template +inline +void +op_clamp_cx::apply_direct(Mat& out, const Mat& X, const eT min_val, const eT max_val) { arma_extra_debug_sigprint(); - const ProxyCube P(in.m); + typedef typename get_pod_type::result T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); - if((is_Cube::stored_type>::value) || P.is_alias(out)) + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + if(&out != &X) { - const unwrap_cube::stored_type> U(P.Q); + out.set_size(X.n_rows, X.n_cols); + + const uword N = out.n_elem; - op_clamp::apply_direct(out, U.M, in.aux, in.aux_out_eT); + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } } else { - op_clamp::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); + arma_extra_debug_print("op_clamp_cx::apply_direct(): inplace operation"); + + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); } } @@ -167,61 +359,100 @@ op_clamp::apply(Cube& out, const mtOpCube inline void -op_clamp::apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) +op_clamp_cx::apply_proxy_noalias(Mat& out, const Proxy& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_slices = P.get_n_slices(); + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); - out.set_size(n_rows, n_cols, n_slices); + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + out.set_size(n_rows, n_cols); eT* out_mem = out.memptr(); - if(ProxyCube::use_at == false) + if(Proxy::use_at == false) { const uword N = P.get_n_elem(); - typename ProxyCube::ea_type A = P.get_ea(); + typename Proxy::ea_type A = P.get_ea(); - uword j; - for(j=1; j max_val) ? max_val : val_i); - val_j = (val_j < min_val) ? min_val : ((val_j > max_val) ? max_val : val_j); + val_real = (val_real < min_val_real) ? min_val_real : ((val_real > max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); - (*out_mem) = val_i; out_mem++; - (*out_mem) = val_j; out_mem++; + out_mem[i] = std::complex(val_real,val_imag); } - - const uword i = j-1; - - if(i < N) + } + else + { + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) { - eT val_i = A[i]; + const eT val = P.at(row,col); - val_i = (val_i < min_val) ? min_val : ((val_i > max_val) ? max_val : val_i); + T val_real = std::real(val); + T val_imag = std::imag(val); - (*out_mem) = val_i; + val_real = (val_real < min_val_real) ? min_val_real : ((val_real > max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + (*out_mem) = std::complex(val_real,val_imag); out_mem++; } } + } + + + +// + + + +template +inline +void +op_clamp_cx::apply(Cube& out, const mtOpCube& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(is_Cube::value) + { + const unwrap_cube U(in.m); + + op_clamp_cx::apply_direct(out, U.M, in.aux, in.aux_out_eT); + } else { - for(uword k=0; k < n_slices; ++k) - for(uword j=0; j < n_cols; ++j) - for(uword i=0; i < n_rows; ++i) + const ProxyCube P(in.m); + + if(P.is_alias(out)) { - eT val = P.at(i,j,k); + Cube tmp; - val = (val < min_val) ? min_val : ((val > max_val) ? max_val : val); + op_clamp_cx::apply_proxy_noalias(tmp, P, in.aux, in.aux_out_eT); - (*out_mem) = val; out_mem++; + out.steal_mem(tmp); + } + else + { + op_clamp_cx::apply_proxy_noalias(out, P, in.aux, in.aux_out_eT); } } } @@ -231,29 +462,112 @@ op_clamp::apply_proxy_noalias(Cube& out, const ProxyCube template inline void -op_clamp::apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val) +op_clamp_cx::apply_direct(Cube& out, const Cube& X, const eT min_val, const eT max_val) { arma_extra_debug_sigprint(); + typedef typename get_pod_type::result T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + if(&out != &X) { - const ProxyCube< Cube > P(X); + out.set_size(X.n_rows, X.n_cols, X.n_slices); - op_clamp::apply_proxy_noalias(out, P, min_val, max_val); + const uword N = out.n_elem; + + const eT* X_mem = X.memptr(); + eT* out_mem = out.memptr(); + + for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } } else { - arma_extra_debug_print("inplace operation"); + arma_extra_debug_print("op_clamp_cx::apply_direct(): inplace operation"); - const uword N = out.n_elem; + arrayops::clamp(out.memptr(), out.n_elem, min_val, max_val); + } + } + + + +template +inline +void +op_clamp_cx::apply_proxy_noalias(Cube& out, const ProxyCube& P, const typename T1::elem_type min_val, const typename T1::elem_type max_val) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const T min_val_real = std::real(min_val); + const T min_val_imag = std::imag(min_val); + + const T max_val_real = std::real(max_val); + const T max_val_imag = std::imag(max_val); + + arma_debug_check( (min_val_real > max_val_real), "clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (min_val_imag > max_val_imag), "clamp(): imag(min_val) must be less than imag(max_val)" ); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_slices = P.get_n_slices(); + + out.set_size(n_rows, n_cols, n_slices); + + eT* out_mem = out.memptr(); + + if(ProxyCube::use_at == false) + { + const uword N = P.get_n_elem(); - eT* out_mem = out.memptr(); + typename ProxyCube::ea_type A = P.get_ea(); for(uword i=0; i max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); + + out_mem[i] = std::complex(val_real,val_imag); + } + } + else + { + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + for(uword r=0; r < n_rows; ++r) + { + const eT val = P.at(r,c,s); + + T val_real = std::real(val); + T val_imag = std::imag(val); + + val_real = (val_real < min_val_real) ? min_val_real : ((val_real > max_val_real) ? max_val_real : val_real); + val_imag = (val_imag < min_val_imag) ? min_val_imag : ((val_imag > max_val_imag) ? max_val_imag : val_imag); - out_val = (out_val < min_val) ? min_val : ( (out_val > max_val) ? max_val : out_val ); + (*out_mem) = std::complex(val_real,val_imag); out_mem++; } } } diff --git a/src/armadillo_bits/spglue_elem_helper_bones.hpp b/src/armadillo_bits/op_col_as_mat_bones.hpp similarity index 68% rename from src/armadillo_bits/spglue_elem_helper_bones.hpp rename to src/armadillo_bits/op_col_as_mat_bones.hpp index d7b66f54..6e653ea4 100644 --- a/src/armadillo_bits/spglue_elem_helper_bones.hpp +++ b/src/armadillo_bits/op_col_as_mat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,24 +16,18 @@ // ------------------------------------------------------------------------ -//! \addtogroup spglue_elem_helper +//! \addtogroup op_col_as_mat //! @{ - -class spglue_elem_helper - : public traits_glue_default +class op_col_as_mat + : public traits_op_default { public: - template - arma_hot inline static uword max_n_nonzero_plus(const SpProxy& pa, const SpProxy& pb); - - template - arma_hot inline static uword max_n_nonzero_schur(const SpProxy& pa, const SpProxy& pb); + template inline static void apply(Mat& out, const CubeToMatOp& expr); }; //! @} - diff --git a/src/armadillo_bits/op_col_as_mat_meat.hpp b/src/armadillo_bits/op_col_as_mat_meat.hpp new file mode 100644 index 00000000..2e0f0cdf --- /dev/null +++ b/src/armadillo_bits/op_col_as_mat_meat.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_col_as_mat +//! @{ + + + +template +inline +void +op_col_as_mat::apply(Mat& out, const CubeToMatOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube U(expr.m); + const Cube& A = U.M; + + const uword in_col = expr.aux_uword; + + arma_debug_check_bounds( (in_col >= A.n_cols), "Cube::col_as_mat(): index out of bounds" ); + + const uword A_n_rows = A.n_rows; + const uword A_n_slices = A.n_slices; + + out.set_size(A_n_rows, A_n_slices); + + for(uword s=0; s < A_n_slices; ++s) + { + arrayops::copy(out.colptr(s), A.slice_colptr(s, in_col), A_n_rows); + } + } + + + +//! @} diff --git a/src/armadillo_bits/op_cond_bones.hpp b/src/armadillo_bits/op_cond_bones.hpp index 8ba28c5b..b7648987 100644 --- a/src/armadillo_bits/op_cond_bones.hpp +++ b/src/armadillo_bits/op_cond_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,8 +25,11 @@ class op_cond { public: - template static inline typename T1::pod_type cond(const Base& X); - template static inline typename T1::pod_type rcond(const Base& X); + template static inline typename T1::pod_type apply(const Base& X); + + template static inline typename get_pod_type::result apply_diag(const Mat& A); + template static inline typename get_pod_type::result apply_sym ( Mat& A); + template static inline typename get_pod_type::result apply_gen ( Mat& A); }; diff --git a/src/armadillo_bits/op_cond_meat.hpp b/src/armadillo_bits/op_cond_meat.hpp index 77369bb0..f73ef9a3 100644 --- a/src/armadillo_bits/op_cond_meat.hpp +++ b/src/armadillo_bits/op_cond_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,82 +24,149 @@ template inline typename T1::pod_type -op_cond::cond(const Base& X) +op_cond::apply(const Base& X) { arma_extra_debug_sigprint(); - typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; - Col S; + Mat A(X.get_ref()); - const bool status = auxlib::svd_dc(S, X); + if(A.n_elem == 0) { return T(0); } - if(status == false) + if(is_op_diagmat::value || A.is_diagmat()) { - arma_debug_warn("cond(): svd failed"); + arma_extra_debug_print("op_cond::apply(): detected diagonal matrix"); - return T(0); + return op_cond::apply_diag(A); } - return (S.n_elem > 0) ? T( max(S) / min(S) ) : T(0); + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + const bool do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); + + if(do_sym) + { + arma_extra_debug_print("op_cond: symmetric/hermitian optimisation"); + + return op_cond::apply_sym(A); + } + + return op_cond::apply_gen(A); } -template +template inline -typename T1::pod_type -op_cond::rcond(const Base& X) +typename get_pod_type::result +op_cond::apply_diag(const Mat& A) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - typedef typename T1::pod_type T; + typedef typename get_pod_type::result T; + + const uword N = (std::min)(A.n_rows, A.n_cols); - if(strip_trimat::do_trimat) + T abs_min = Datum::inf; + T abs_max = T(0); + + for(uword i=0; i < N; ++i) { - const strip_trimat T(X.get_ref()); - - arma_debug_check( (T.M.is_square() == false), "rcond(): matrix must be square sized" ); + const T abs_val = std::abs(A.at(i,i)); - const uword layout = (T.do_triu) ? uword(0) : uword(1); + if(arma_isnan(abs_val)) + { + arma_debug_warn_level(3, "cond(): failed"); + + return Datum::nan; + } - return auxlib::rcond_trimat(T.M, layout); + abs_min = (abs_val < abs_min) ? abs_val : abs_min; + abs_max = (abs_val > abs_max) ? abs_val : abs_max; } - Mat A = X.get_ref(); + if((abs_min == T(0)) || (abs_max == T(0))) { return Datum::inf; } - arma_debug_check( (A.is_square() == false), "rcond(): matrix must be square sized" ); + return T(abs_max / abs_min); + } + + + +template +inline +typename get_pod_type::result +op_cond::apply_sym(Mat& A) + { + arma_extra_debug_sigprint(); - if(A.is_empty()) { return Datum::inf; } + typedef typename get_pod_type::result T; - const bool is_triu = trimat_helper::is_triu(A); - const bool is_tril = (is_triu) ? false : trimat_helper::is_tril(A); + Col eigval; - if(is_triu || is_tril) + const bool status = auxlib::eig_sym(eigval, A); + + if(status == false) { - const uword layout = (is_triu) ? uword(0) : uword(1); + arma_debug_warn_level(3, "cond(): failed"); - return auxlib::rcond_trimat(A, layout); + return Datum::nan; } - const bool try_sympd = auxlib::crippled_lapack(A) ? false : sympd_helper::guess_sympd(A); + if(eigval.n_elem == 0) { return T(0); } + + const T* eigval_mem = eigval.memptr(); + + T abs_min = std::abs(eigval_mem[0]); + T abs_max = abs_min; - if(try_sympd) + for(uword i=1; i < eigval.n_elem; ++i) { - bool calc_ok = false; + const T abs_val = std::abs(eigval_mem[i]); - const T out_val = auxlib::rcond_sympd(A, calc_ok); - - if(calc_ok) { return out_val; } + abs_min = (abs_val < abs_min) ? abs_val : abs_min; + abs_max = (abs_val > abs_max) ? abs_val : abs_max; + } + + if((abs_min == T(0)) || (abs_max == T(0))) { return Datum::inf; } + + return T(abs_max / abs_min); + } + + + +template +inline +typename get_pod_type::result +op_cond::apply_gen(Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col S; + + const bool status = auxlib::svd_dc(S, A); + + if(status == false) + { + arma_debug_warn_level(3, "cond(): failed"); - // auxlib::rcond_sympd() may have failed because A isn't really sympd - // restore A, as auxlib::rcond_sympd() may have destroyed it - A = X.get_ref(); - // fallthrough to the next return statement + return Datum::nan; } - return auxlib::rcond(A); + if(S.n_elem == 0) { return T(0); } + + const T S_max = S[0]; + const T S_min = S[S.n_elem-1]; + + if((S_max == T(0)) || (S_min == T(0))) { return Datum::inf; } + + return T(S_max / S_min); } diff --git a/src/armadillo_bits/op_cor_bones.hpp b/src/armadillo_bits/op_cor_bones.hpp index fbb831a9..7a506c3a 100644 --- a/src/armadillo_bits/op_cor_bones.hpp +++ b/src/armadillo_bits/op_cor_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cor_meat.hpp b/src/armadillo_bits/op_cor_meat.hpp index 1255cba1..6763964d 100644 --- a/src/armadillo_bits/op_cor_meat.hpp +++ b/src/armadillo_bits/op_cor_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cov_bones.hpp b/src/armadillo_bits/op_cov_bones.hpp index eb759071..5de43ba8 100644 --- a/src/armadillo_bits/op_cov_bones.hpp +++ b/src/armadillo_bits/op_cov_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cov_meat.hpp b/src/armadillo_bits/op_cov_meat.hpp index a1a4d93c..30944b75 100644 --- a/src/armadillo_bits/op_cov_meat.hpp +++ b/src/armadillo_bits/op_cov_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cumprod_bones.hpp b/src/armadillo_bits/op_cumprod_bones.hpp index 6f3c5a4a..ce3b686f 100644 --- a/src/armadillo_bits/op_cumprod_bones.hpp +++ b/src/armadillo_bits/op_cumprod_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cumprod_meat.hpp b/src/armadillo_bits/op_cumprod_meat.hpp index fa5ba1bd..14dc2245 100644 --- a/src/armadillo_bits/op_cumprod_meat.hpp +++ b/src/armadillo_bits/op_cumprod_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cumsum_bones.hpp b/src/armadillo_bits/op_cumsum_bones.hpp index 3d190e1a..007d3f36 100644 --- a/src/armadillo_bits/op_cumsum_bones.hpp +++ b/src/armadillo_bits/op_cumsum_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cumsum_meat.hpp b/src/armadillo_bits/op_cumsum_meat.hpp index 23c9aa13..d46eda2b 100644 --- a/src/armadillo_bits/op_cumsum_meat.hpp +++ b/src/armadillo_bits/op_cumsum_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cx_scalar_bones.hpp b/src/armadillo_bits/op_cx_scalar_bones.hpp index f145aac3..c1c28477 100644 --- a/src/armadillo_bits/op_cx_scalar_bones.hpp +++ b/src/armadillo_bits/op_cx_scalar_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_cx_scalar_meat.hpp b/src/armadillo_bits/op_cx_scalar_meat.hpp index e2cf0794..9552c01a 100644 --- a/src/armadillo_bits/op_cx_scalar_meat.hpp +++ b/src/armadillo_bits/op_cx_scalar_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_det_bones.hpp b/src/armadillo_bits/op_det_bones.hpp new file mode 100644 index 00000000..59575faa --- /dev/null +++ b/src/armadillo_bits/op_det_bones.hpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_det +//! @{ + + + +class op_det + : public traits_op_default + { + public: + + template + struct pos + { + static constexpr uword n2 = row + col*2; + static constexpr uword n3 = row + col*3; + }; + + template + inline static bool apply_direct(typename T1::elem_type& out_val, const Base& expr); + + template + inline static typename T1::elem_type apply_diagmat(const Base& expr); + + template + inline static typename T1::elem_type apply_trimat(const Base& expr); + + template + arma_cold inline static eT apply_tiny_2x2(const Mat& X); + + template + arma_cold inline static eT apply_tiny_3x3(const Mat& X); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_det_meat.hpp b/src/armadillo_bits/op_det_meat.hpp new file mode 100644 index 00000000..81c9b399 --- /dev/null +++ b/src/armadillo_bits/op_det_meat.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_det +//! @{ + + + +template +inline +bool +op_det::apply_direct(typename T1::elem_type& out_val, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(strip_diagmat::do_diagmat) + { + const strip_diagmat strip(expr.get_ref()); + + out_val = op_det::apply_diagmat(strip.M); + + return true; + } + + if(strip_trimat::do_trimat) + { + const strip_trimat strip(expr.get_ref()); + + out_val = op_det::apply_trimat(strip.M); + + return true; + } + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "det(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(N == 0) { out_val = eT(1); return true; } + if(N == 1) { out_val = A[0]; return true; } + + if((is_cx::no) && (N <= 3)) + { + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT det_val = eT(0); + + if(N == 2) { det_val = op_det::apply_tiny_2x2(A); } + if(N == 3) { det_val = op_det::apply_tiny_3x3(A); } + + const T abs_det_val = std::abs(det_val); + + if((abs_det_val > det_min) && (abs_det_val < det_max)) { out_val = det_val; return true; } + + // fallthrough if det_val is suspect + } + + if(A.is_diagmat()) { out_val = op_det::apply_diagmat(A); return true; } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = is_triu ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) { out_val = op_det::apply_trimat(A); return true; } + + return auxlib::det(out_val, A); + } + + + +template +inline +typename T1::elem_type +op_det::apply_diagmat(const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const diagmat_proxy A(expr.get_ref()); + + arma_debug_check( (A.n_rows != A.n_cols), "det(): given matrix must be square sized" ); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + eT val = eT(1); + + for(uword i=0; i +inline +typename T1::elem_type +op_det::apply_trimat(const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const Proxy P(expr.get_ref()); + + const uword N = P.get_n_rows(); + + arma_debug_check( (N != P.get_n_cols()), "det(): given matrix must be square sized" ); + + eT val = eT(1); + + for(uword i=0; i +inline +eT +op_det::apply_tiny_2x2(const Mat& X) + { + arma_extra_debug_sigprint(); + + const eT* Xm = X.memptr(); + + return ( Xm[pos<0,0>::n2]*Xm[pos<1,1>::n2] - Xm[pos<0,1>::n2]*Xm[pos<1,0>::n2] ); + } + + + +template +inline +eT +op_det::apply_tiny_3x3(const Mat& X) + { + arma_extra_debug_sigprint(); + + const eT* Xm = X.memptr(); + + // const double tmp1 = X.at(0,0) * X.at(1,1) * X.at(2,2); + // const double tmp2 = X.at(0,1) * X.at(1,2) * X.at(2,0); + // const double tmp3 = X.at(0,2) * X.at(1,0) * X.at(2,1); + // const double tmp4 = X.at(2,0) * X.at(1,1) * X.at(0,2); + // const double tmp5 = X.at(2,1) * X.at(1,2) * X.at(0,0); + // const double tmp6 = X.at(2,2) * X.at(1,0) * X.at(0,1); + // return (tmp1+tmp2+tmp3) - (tmp4+tmp5+tmp6); + + const eT val1 = Xm[pos<0,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]); + const eT val2 = Xm[pos<1,0>::n3]*(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]); + const eT val3 = Xm[pos<2,0>::n3]*(Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]); + + return ( val1 - val2 + val3 ); + } + + + +//! @} diff --git a/src/armadillo_bits/op_diagmat_bones.hpp b/src/armadillo_bits/op_diagmat_bones.hpp index 7d1580e4..50f7dc7b 100644 --- a/src/armadillo_bits/op_diagmat_bones.hpp +++ b/src/armadillo_bits/op_diagmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,7 +29,17 @@ class op_diagmat template inline static void apply(Mat& out, const Op& X); - // TODO: implement specialised handling of Op,op_diagmat> + template + inline static void apply(Mat& out, const Proxy& P); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagmat>& X); + + template + inline static void apply_times(Mat& out, const T1& X, const T2& Y, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply_times(Mat& out, const T1& X, const T2& Y, const typename arma_cx_only::result* junk = nullptr); }; @@ -38,10 +50,10 @@ class op_diagmat2 public: template - inline static void apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset); + inline static void apply(Mat& out, const Op& X); template - inline static void apply(Mat& out, const Op& X); + inline static void apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset); }; diff --git a/src/armadillo_bits/op_diagmat_meat.hpp b/src/armadillo_bits/op_diagmat_meat.hpp index c19c3186..727da7bb 100644 --- a/src/armadillo_bits/op_diagmat_meat.hpp +++ b/src/armadillo_bits/op_diagmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,184 +30,647 @@ op_diagmat::apply(Mat& out, const Op& X) typedef typename T1::elem_type eT; - const Proxy P(X.m); + if(is_Mat::value) + { + // allow detection of in-place operation + + const unwrap U(X.m); + const Mat& A = U.M; + + if(&out != &A) // no aliasing + { + const Proxy< Mat > P(A); + + op_diagmat::apply(out, P); + } + else // we have aliasing + { + const uword n_rows = out.n_rows; + const uword n_cols = out.n_cols; + + if((n_rows == 1) || (n_cols == 1)) // create diagonal matrix from vector + { + const eT* out_mem = out.memptr(); + const uword N = out.n_elem; + + Mat tmp(N,N, arma_zeros_indicator()); + + for(uword i=0; i P(X.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_diagmat::apply(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_diagmat::apply(out, P); + } + } + } + + + +template +inline +void +op_diagmat::apply(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); const uword n_rows = P.get_n_rows(); const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); - const bool P_is_vec = (n_rows == 1) || (n_cols == 1); + if(n_elem == 0) { out.reset(); return; } + const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1); - if(P.is_alias(out) == false) + if(P_is_vec) { - if(P_is_vec) // generate a diagonal matrix out of a vector + out.zeros(n_elem, n_elem); + + if(Proxy::use_at == false) { - const uword N = (n_rows == 1) ? n_cols : n_rows; - - out.zeros(N, N); + typename Proxy::ea_type Pea = P.get_ea(); - if(Proxy::use_at == false) + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = Pea[i]; } + } + else + { + if(n_rows == 1) { - typename Proxy::ea_type P_ea = P.get_ea(); - - for(uword i=0; i < N; ++i) { out.at(i,i) = P_ea[i]; } + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(0,i); } } else { - if(n_rows == 1) - { - for(uword i=0; i < N; ++i) { out.at(i,i) = P.at(0,i); } - } - else - { - for(uword i=0; i < N; ++i) { out.at(i,i) = P.at(i,0); } - } + for(uword i=0; i < n_elem; ++i) { out.at(i,i) = P.at(i,0); } } } - else // generate a diagonal matrix out of a matrix + } + else // P represents a matrix + { + out.zeros(n_rows, n_cols); + + const uword N = (std::min)(n_rows, n_cols); + + for(uword i=0; i +inline +void +op_diagmat::apply(Mat& out, const Op< Glue, op_diagmat>& X) + { + arma_extra_debug_sigprint(); + + op_diagmat::apply_times(out, X.m.A, X.m.B); + } + + + +template +inline +void +op_diagmat::apply_times(Mat& actual_out, const T1& X, const T2& Y, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X); + const partial_unwrap UB(Y); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + // check if the multiplication results in a vector + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + if((A_n_rows == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = false; vector result"); + + const Mat C = A*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + if((A_n_cols == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = false; vector result"); + + const Mat C = trans(A)*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_rows == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = true; vector result"); + + const Mat C = A*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_cols == 1) || (B_n_rows == 1)) { - out.zeros(n_rows, n_cols); + arma_extra_debug_print("trans_A = true; trans_B = true; vector result"); + + const Mat C = trans(A)*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; - const uword N = (std::min)(n_rows, n_cols); + actual_out.zeros(N,N); - for(uword i=0; i < N; ++i) { out.at(i,i) = P.at(i,i); } + for(uword i=0; i tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) { - if(P_is_vec) // generate a diagonal matrix out of a vector + arma_extra_debug_print("trans_A = false; trans_B = false; matrix result"); + + out.zeros(A_n_rows, B_n_cols); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + for(uword k=0; k < N; ++k) { - const uword N = (n_rows == 1) ? n_cols : n_rows; + eT acc1 = eT(0); + eT acc2 = eT(0); + + const eT* B_colptr = B.colptr(k); - podarray tmp(N); - eT* tmp_mem = tmp.memptr(); + // condition: A_n_cols = B_n_rows - if(Proxy::use_at == false) + uword j; + + for(j=1; j < A_n_cols; j+=2) { - typename Proxy::ea_type P_ea = P.get_ea(); + const uword i = (j-1); + + const eT tmp_i = B_colptr[i]; + const eT tmp_j = B_colptr[j]; - for(uword i=0; i < N; ++i) { tmp_mem[i] = P_ea[i]; } + acc1 += A.at(k, i) * tmp_i; + acc2 += A.at(k, j) * tmp_j; } - else + + const uword i = (j-1); + + if(i < A_n_cols) { - if(n_rows == 1) - { - for(uword i=0; i < N; ++i) { tmp_mem[i] = P.at(0,i); } - } - else - { - for(uword i=0; i < N; ++i) { tmp_mem[i] = P.at(i,0); } - } + acc1 += A.at(k, i) * B_colptr[i]; } - out.zeros(N, N); + const eT acc = acc1 + acc2; + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false; matrix result"); + + out.zeros(A_n_cols, B_n_cols); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + const eT acc = op_dot::direct_dot(A_n_rows, A_colptr, B_colptr); - for(uword i=0; i < N; ++i) { out.at(i,i) = tmp_mem[i]; } + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); } - else // generate a diagonal matrix out of a matrix + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true; matrix result"); + + out.zeros(A_n_rows, B_n_rows); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + for(uword k=0; k < N; ++k) { - const uword N = (std::min)(n_rows, n_cols); + eT acc = eT(0); + + // condition: A_n_cols = B_n_cols - if( (Proxy::has_subview == false) && (Proxy::fake_mat == false) ) + for(uword i=0; i < A_n_cols; ++i) { - // NOTE: we have aliasing and it's not due to a subview, hence we're assuming that the output matrix already has the correct size - - for(uword i=0; i < n_cols; ++i) - { - if(i < N) - { - const eT val = P.at(i,i); - - arrayops::fill_zeros(out.colptr(i), n_rows); - - out.at(i,i) = val; - } - else - { - arrayops::fill_zeros(out.colptr(i), n_rows); - } - } + acc += A.at(k,i) * B.at(k,i); } - else + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true; matrix result"); + + out.zeros(A_n_cols, B_n_rows); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + for(uword k=0; k < N; ++k) + { + eT acc = eT(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) { - podarray tmp(N); - eT* tmp_mem = tmp.memptr(); - - for(uword i=0; i < N; ++i) { tmp_mem[i] = P.at(i,i); } - - out.zeros(n_rows, n_cols); - - for(uword i=0; i < N; ++i) { out.at(i,i) = tmp_mem[i]; } + acc += A_colptr[i] * B.at(k,i); } + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); } } + + if(is_alias) { actual_out.steal_mem(tmp); } } -template +template inline void -op_diagmat2::apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset) +op_diagmat::apply_times(Mat& actual_out, const T1& X, const T2& Y, const typename arma_cx_only::result* junk) { arma_extra_debug_sigprint(); + arma_ignore(junk); - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_elem = P.get_n_elem(); + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; - if(n_elem == 0) { out.reset(); return; } + const partial_unwrap UA(X); + const partial_unwrap UB(Y); - const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1); + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; - if(P_is_vec) + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + // check if the multiplication results in a vector + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) { - const uword n_pad = (std::max)(row_offset, col_offset); + if((A_n_rows == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = false; vector result"); + + const Mat C = A*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + if((A_n_cols == 1) || (B_n_cols == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = false; vector result"); + + const Mat C = trans(A)*B; + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == false) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_rows == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = false; trans_B = true; vector result"); + + const Mat C = A*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + if((A_n_cols == 1) || (B_n_rows == 1)) + { + arma_extra_debug_print("trans_A = true; trans_B = true; vector result"); + + const Mat C = trans(A)*trans(B); + const eT* C_mem = C.memptr(); + const uword N = C.n_elem; + + actual_out.zeros(N,N); + + for(uword i=0; i tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false; matrix result"); - out.zeros(n_elem + n_pad, n_elem + n_pad); + out.zeros(A_n_rows, B_n_cols); - if(Proxy::use_at == false) + const uword N = (std::min)(A_n_rows, B_n_cols); + + for(uword k=0; k < N; ++k) { - typename Proxy::ea_type Pea = P.get_ea(); + T acc_real = T(0); + T acc_imag = T(0); + + const eT* B_colptr = B.colptr(k); - for(uword i=0; i < n_elem; ++i) + // condition: A_n_cols = B_n_rows + + for(uword i=0; i < A_n_cols; ++i) { - out.at(row_offset + i, col_offset + i) = Pea[i]; + // acc += A.at(k, i) * B_colptr[i]; + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); } - else + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false; matrix result"); + + out.zeros(A_n_cols, B_n_cols); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + for(uword k=0; k < N; ++k) { - const unwrap::stored_type> U(P.Q); + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * B_colptr[i]; + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + // take into account the complex conjugate of xx + + acc_real += (a*c) + (b*d); + acc_imag += (a*d) - (b*c); + } - const Proxy::stored_type>::stored_type> PP(U.M); + const eT acc = std::complex(acc_real, acc_imag); - op_diagmat2::apply(out, PP, row_offset, col_offset); + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); } } - else // P represents a matrix + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true) ) { - arma_debug_check - ( - ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), - "diagmat(): requested diagonal out of bounds" - ); + arma_extra_debug_print("trans_A = false; trans_B = true; matrix result"); - out.zeros(n_rows, n_cols); + out.zeros(A_n_rows, B_n_rows); - const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + const uword N = (std::min)(A_n_rows, B_n_rows); - for(uword i=0; i& xx = A.at(k, i); + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true; matrix result"); + + out.zeros(A_n_cols, B_n_rows); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i)); + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = -xx.imag(); // take the conjugate + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out.at(k,k) = (use_alpha) ? eT(alpha * acc) : eT(acc); } } + + if(is_alias) { actual_out.steal_mem(tmp); } } +// +// +// + + + template inline void @@ -236,4 +701,67 @@ op_diagmat2::apply(Mat& out, const Op& +template +inline +void +op_diagmat2::apply(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); + + if(n_elem == 0) { out.reset(); return; } + + const bool P_is_vec = (T1::is_row) || (T1::is_col) || (n_rows == 1) || (n_cols == 1); + + if(P_is_vec) + { + const uword n_pad = (std::max)(row_offset, col_offset); + + out.zeros(n_elem + n_pad, n_elem + n_pad); + + if(Proxy::use_at == false) + { + typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = Pea[i]; } + } + else + { + if(n_rows == 1) + { + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(0,i); } + } + else + { + for(uword i=0; i < n_elem; ++i) { out.at(row_offset + i, col_offset + i) = P.at(i,0); } + } + } + } + else // P represents a matrix + { + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diagmat(): requested diagonal out of bounds" + ); + + out.zeros(n_rows, n_cols); + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i& out, const Op& X); template - arma_hot inline static void apply_unwrap(Mat& out, const T1& X, const uword row_offset, const uword col_offset, const uword len); + inline static void apply_proxy(Mat& out, const Proxy& P); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagvec>& X, const typename arma_not_cx::result* junk = nullptr); + + template + inline static void apply(Mat& out, const Op< Glue, op_diagvec>& X, const typename arma_cx_only::result* junk = nullptr); + }; + + + +class op_diagvec2 + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const Op& X); template - arma_hot inline static void apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset, const uword len); + inline static void apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset); }; diff --git a/src/armadillo_bits/op_diagvec_meat.hpp b/src/armadillo_bits/op_diagvec_meat.hpp index 9b82a758..f3371928 100644 --- a/src/armadillo_bits/op_diagvec_meat.hpp +++ b/src/armadillo_bits/op_diagvec_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,93 +30,487 @@ op_diagvec::apply(Mat& out, const Op& X) typedef typename T1::elem_type eT; - const uword a = X.aux_uword_a; - const uword b = X.aux_uword_b; + const Proxy P(X.m); - const uword row_offset = (b > 0) ? a : 0; - const uword col_offset = (b == 0) ? a : 0; + if(P.is_alias(out) == false) + { + op_diagvec::apply_proxy(out, P); + } + else + { + Mat tmp; + + op_diagvec::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +op_diagvec::apply_proxy(Mat& out, const Proxy& P) + { + arma_extra_debug_sigprint(); - const Proxy P(X.m); + typedef typename T1::elem_type eT; const uword n_rows = P.get_n_rows(); const uword n_cols = P.get_n_cols(); - arma_debug_check - ( - ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), - "diagvec(): requested diagonal is out of bounds" - ); + const uword len = (std::min)(n_rows, n_cols); - const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + out.set_size(len, 1); + + eT* out_mem = out.memptr(); + + uword i,j; + for(i=0, j=1; j < len; i+=2, j+=2) + { + const eT tmp_i = P.at(i, i); + const eT tmp_j = P.at(j, j); + + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; + } - if( (is_Mat::stored_type>::value) && (Proxy::fake_mat == false) ) + if(i < len) { - op_diagvec::apply_unwrap(out, P.Q, row_offset, col_offset, len); + out_mem[i] = P.at(i, i); + } + } + + + +template +inline +void +op_diagvec::apply(Mat& actual_out, const Op< Glue, op_diagvec>& X, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename T1::elem_type eT; + + const partial_unwrap UA(X.m.A); + const partial_unwrap UB(X.m.B); + + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; + + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); + + if( (A.n_elem == 0) || (B.n_elem == 0) ) { actual_out.reset(); return; } + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out)); + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = false; trans_B = false;"); + + const uword N = (std::min)(A_n_rows, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + eT acc1 = eT(0); + eT acc2 = eT(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + uword j; + + for(j=1; j < A_n_cols; j+=2) + { + const uword i = (j-1); + + const eT tmp_i = B_colptr[i]; + const eT tmp_j = B_colptr[j]; + + acc1 += A.at(k, i) * tmp_i; + acc2 += A.at(k, j) * tmp_j; + } + + const uword i = (j-1); + + if(i < A_n_cols) + { + acc1 += A.at(k, i) * B_colptr[i]; + } + + const eT acc = acc1 + acc2; + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false;"); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + const eT acc = op_dot::direct_dot(A_n_rows, A_colptr, B_colptr); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } } else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true ) ) { - if(P.is_alias(out) == false) + arma_extra_debug_print("trans_A = false; trans_B = true;"); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) { - op_diagvec::apply_proxy(out, P, row_offset, col_offset, len); + eT acc = eT(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + acc += A.at(k,i) * B.at(k,i); + } + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); } - else + } + else + if( (partial_unwrap::do_trans == true ) && (partial_unwrap::do_trans == true ) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true;"); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) { - Mat tmp; + eT acc = eT(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols - op_diagvec::apply_proxy(tmp, P, row_offset, col_offset, len); + for(uword i=0; i < A_n_rows; ++i) + { + acc += A_colptr[i] * B.at(k,i); + } - out.steal_mem(tmp); + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); } } + + if(is_alias) { actual_out.steal_mem(tmp); } } -template -arma_hot +template inline void -op_diagvec::apply_unwrap(Mat& out, const T1& X, const uword row_offset, const uword col_offset, const uword len) +op_diagvec::apply(Mat& actual_out, const Op< Glue, op_diagvec>& X, const typename arma_cx_only::result* junk) { arma_extra_debug_sigprint(); + arma_ignore(junk); + typedef typename T1::pod_type T; typedef typename T1::elem_type eT; - const unwrap_check tmp_A(X, out); - const Mat& A = tmp_A.M; + const partial_unwrap UA(X.m.A); + const partial_unwrap UB(X.m.B); - out.set_size(len, 1); + const typename partial_unwrap::stored_type& A = UA.M; + const typename partial_unwrap::stored_type& B = UB.M; - eT* out_mem = out.memptr(); + arma_debug_assert_trans_mul_size< partial_unwrap::do_trans, partial_unwrap::do_trans >(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); - uword i,j; - for(i=0, j=1; j < len; i+=2, j+=2) + if( (A.n_elem == 0) || (B.n_elem == 0) ) { actual_out.reset(); return; } + + const bool use_alpha = partial_unwrap::do_times || partial_unwrap::do_times; + const eT alpha = use_alpha ? (UA.get_val() * UB.get_val()) : eT(0); + + const bool is_alias = (UA.is_alias(actual_out) || UB.is_alias(actual_out)); + + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; + + const uword A_n_rows = A.n_rows; + const uword A_n_cols = A.n_cols; + + const uword B_n_rows = B.n_rows; + const uword B_n_cols = B.n_cols; + + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == false) ) { - const eT tmp_i = A.at( i + row_offset, i + col_offset ); - const eT tmp_j = A.at( j + row_offset, j + col_offset ); + arma_extra_debug_print("trans_A = false; trans_B = false;"); - out_mem[i] = tmp_i; - out_mem[j] = tmp_j; + const uword N = (std::min)(A_n_rows, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* B_colptr = B.colptr(k); + + // condition: A_n_cols = B_n_rows + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k, i) * B_colptr[i]; + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == false) ) + { + arma_extra_debug_print("trans_A = true; trans_B = false;"); + + const uword N = (std::min)(A_n_cols, B_n_cols); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + const eT* B_colptr = B.colptr(k); + + // condition: A_n_rows = B_n_rows + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * B_colptr[i]; + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B_colptr[i]; + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = yy.imag(); + + // take into account the complex conjugate of xx + + acc_real += (a*c) + (b*d); + acc_imag += (a*d) - (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == false) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = false; trans_B = true;"); + + const uword N = (std::min)(A_n_rows, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + // condition: A_n_cols = B_n_cols + + for(uword i=0; i < A_n_cols; ++i) + { + // acc += A.at(k,i) * std::conj(B.at(k,i)); + + const std::complex& xx = A.at(k, i); + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = xx.imag(); + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } + } + else + if( (partial_unwrap::do_trans == true) && (partial_unwrap::do_trans == true) ) + { + arma_extra_debug_print("trans_A = true; trans_B = true;"); + + const uword N = (std::min)(A_n_cols, B_n_rows); + + out.set_size(N,1); + + eT* out_mem = out.memptr(); + + for(uword k=0; k < N; ++k) + { + T acc_real = T(0); + T acc_imag = T(0); + + const eT* A_colptr = A.colptr(k); + + // condition: A_n_rows = B_n_cols + + for(uword i=0; i < A_n_rows; ++i) + { + // acc += std::conj(A_colptr[i]) * std::conj(B.at(k,i)); + + const std::complex& xx = A_colptr[i]; + const std::complex& yy = B.at(k, i); + + const T a = xx.real(); + const T b = -xx.imag(); // take the conjugate + + const T c = yy.real(); + const T d = -yy.imag(); // take the conjugate + + acc_real += (a*c) - (b*d); + acc_imag += (a*d) + (b*c); + } + + const eT acc = std::complex(acc_real, acc_imag); + + out_mem[k] = (use_alpha) ? eT(alpha * acc) : eT(acc); + } } - if(i < len) + if(is_alias) { actual_out.steal_mem(tmp); } + } + + + +// +// +// + + + +template +inline +void +op_diagvec2::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const uword a = X.aux_uword_a; + const uword b = X.aux_uword_b; + + const uword row_offset = (b > 0) ? a : 0; + const uword col_offset = (b == 0) ? a : 0; + + const Proxy P(X.m); + + if(P.is_alias(out) == false) + { + op_diagvec2::apply_proxy(out, P, row_offset, col_offset); + } + else { - out_mem[i] = A.at( i + row_offset, i + col_offset ); + Mat tmp; + + op_diagvec2::apply_proxy(tmp, P, row_offset, col_offset); + + out.steal_mem(tmp); } } template -arma_hot inline void -op_diagvec::apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset, const uword len) +op_diagvec2::apply_proxy(Mat& out, const Proxy& P, const uword row_offset, const uword col_offset) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + + arma_debug_check_bounds + ( + ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), + "diagvec(): requested diagonal is out of bounds" + ); + + const uword len = (std::min)(n_rows - row_offset, n_cols - col_offset); + out.set_size(len, 1); eT* out_mem = out.memptr(); diff --git a/src/armadillo_bits/op_diff_bones.hpp b/src/armadillo_bits/op_diff_bones.hpp index 6833a623..a6844abc 100644 --- a/src/armadillo_bits/op_diff_bones.hpp +++ b/src/armadillo_bits/op_diff_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_diff_meat.hpp b/src/armadillo_bits/op_diff_meat.hpp index 2a4db1d0..a5b309b6 100644 --- a/src/armadillo_bits/op_diff_meat.hpp +++ b/src/armadillo_bits/op_diff_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_dot_bones.hpp b/src/armadillo_bits/op_dot_bones.hpp index 82c8fc41..23068de8 100644 --- a/src/armadillo_bits/op_dot_bones.hpp +++ b/src/armadillo_bits/op_dot_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,7 +28,7 @@ class op_dot public: template - arma_hot arma_inline static + arma_inline static typename arma_not_cx::result direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B); diff --git a/src/armadillo_bits/op_dot_meat.hpp b/src/armadillo_bits/op_dot_meat.hpp index 0fc7c026..e94c76d8 100644 --- a/src/armadillo_bits/op_dot_meat.hpp +++ b/src/armadillo_bits/op_dot_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,14 +23,13 @@ //! for two arrays, generic version for non-complex values template -arma_hot arma_inline typename arma_not_cx::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) { arma_extra_debug_sigprint(); - #if defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0) + #if defined(__FAST_MATH__) { eT val = eT(0); @@ -66,7 +67,6 @@ op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B //! for two arrays, generic version for complex values template -arma_hot inline typename arma_cx_only::result op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B) @@ -100,7 +100,6 @@ op_dot::direct_dot_arma(const uword n_elem, const eT* const A, const eT* const B //! for two arrays, float and double version template -arma_hot inline typename arma_real_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) @@ -138,7 +137,6 @@ op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) //! for two arrays, complex version template inline -arma_hot typename arma_cx_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) { @@ -172,7 +170,6 @@ op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) //! for two arrays, integral version template -arma_hot inline typename arma_integral_only::result op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) @@ -185,7 +182,6 @@ op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B) //! for three arrays template -arma_hot inline eT op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, const eT* C) @@ -205,7 +201,6 @@ op_dot::direct_dot(const uword n_elem, const eT* const A, const eT* const B, con template -arma_hot inline typename T1::elem_type op_dot::apply(const T1& X, const T2& Y) @@ -265,7 +260,6 @@ op_dot::apply(const T1& X, const T2& Y) template -arma_hot inline typename arma_not_cx::result op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) @@ -303,7 +297,6 @@ op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) template -arma_hot inline typename arma_cx_only::result op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) @@ -350,7 +343,6 @@ op_dot::apply_proxy(const Proxy& PA, const Proxy& PB) template -arma_hot inline typename T1::elem_type op_norm_dot::apply(const T1& X, const T2& Y) @@ -381,7 +373,6 @@ op_norm_dot::apply(const T1& X, const T2& Y) template -arma_hot inline eT op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const B) @@ -414,7 +405,6 @@ op_cdot::direct_cdot_arma(const uword n_elem, const eT* const A, const eT* const template -arma_hot inline eT op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) @@ -450,12 +440,6 @@ op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) return result[0]; } - #elif defined(ARMA_USE_ATLAS) - { - // TODO: use dedicated atlas functions cblas_cdotc_sub() and cblas_zdotc_sub() and retune threshold - - return op_cdot::direct_cdot_arma(n_elem, A, B); - } #else { return op_cdot::direct_cdot_arma(n_elem, A, B); @@ -467,14 +451,13 @@ op_cdot::direct_cdot(const uword n_elem, const eT* const A, const eT* const B) template -arma_hot inline typename T1::elem_type op_cdot::apply(const T1& X, const T2& Y) { arma_extra_debug_sigprint(); - if( (is_Mat::value == true) && (is_Mat::value == true) ) + if(is_Mat::value && is_Mat::value) { return op_cdot::apply_unwrap(X,Y); } @@ -487,7 +470,6 @@ op_cdot::apply(const T1& X, const T2& Y) template -arma_hot inline typename T1::elem_type op_cdot::apply_unwrap(const T1& X, const T2& Y) @@ -510,7 +492,6 @@ op_cdot::apply_unwrap(const T1& X, const T2& Y) template -arma_hot inline typename T1::elem_type op_cdot::apply_proxy(const T1& X, const T2& Y) @@ -566,7 +547,6 @@ op_cdot::apply_proxy(const T1& X, const T2& Y) template -arma_hot inline typename promote_type::result op_dot_mixed::apply(const T1& A, const T2& B) diff --git a/src/armadillo_bits/op_dotext_bones.hpp b/src/armadillo_bits/op_dotext_bones.hpp index 4d74c227..dc3b7b81 100644 --- a/src/armadillo_bits/op_dotext_bones.hpp +++ b/src/armadillo_bits/op_dotext_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_dotext_meat.hpp b/src/armadillo_bits/op_dotext_meat.hpp index d681bbf9..c190b2cb 100644 --- a/src/armadillo_bits/op_dotext_meat.hpp +++ b/src/armadillo_bits/op_dotext_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_expmat_bones.hpp b/src/armadillo_bits/op_expmat_bones.hpp index 11902817..5b5cb150 100644 --- a/src/armadillo_bits/op_expmat_bones.hpp +++ b/src/armadillo_bits/op_expmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_expmat_meat.hpp b/src/armadillo_bits/op_expmat_meat.hpp index 7f5db69d..d45fb36d 100644 --- a/src/armadillo_bits/op_expmat_meat.hpp +++ b/src/armadillo_bits/op_expmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,64 +60,106 @@ op_expmat::apply_direct(Mat& out, const Base A = expr.get_ref(); + + arma_debug_check( (A.is_square() == false), "expmat(): given matrix must be square sized" ); + + if(A.is_diagmat()) { - Mat A = expr.get_ref(); + arma_extra_debug_print("op_expmat: detected diagonal matrix"); - arma_debug_check( (A.is_square() == false), "expmat(): given matrix must be square sized" ); + const uword N = (std::min)(A.n_rows, A.n_cols); - const T norm_val = arma::norm(A, "inf"); + out.zeros(N,N); - const double log2_val = (norm_val > T(0)) ? double(eop_aux::log2(norm_val)) : double(0); + for(uword i=0; i::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); + } + + if(do_sym) + { + arma_extra_debug_print("op_expmat: symmetric/hermitian optimisation"); - T c = T(0.5); + Col< T> eigval; + Mat eigvec; - Mat E(A.n_rows, A.n_rows, fill::eye); E += c * A; - Mat D(A.n_rows, A.n_rows, fill::eye); D -= c * A; + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "expmat()"); - Mat X = A; + if(eig_status == false) { return false; } - bool positive = true; + eigval = exp(eigval); - const uword N = 6; + out = eigvec * diagmat(eigval) * eigvec.t(); - for(uword i = 2; i <= N; ++i) - { - c = c * T(N - i + 1) / T(i * (2*N - i + 1)); - - X = A * X; - - E += c * X; - - if(positive) { D += c * X; } else { D -= c * X; } - - positive = (positive) ? false : true; - } + return true; + } + + const T norm_val = arma::norm(A, "inf"); + + if(arma_isfinite(norm_val) == false) { return false; } + + const double log2_val = (norm_val > T(0)) ? double(eop_aux::log2(norm_val)) : double(0); + + int exponent = int(0); std::frexp(log2_val, &exponent); + + const uword s = uword( (std::max)(int(0), exponent + int(1)) ); + + A /= eT(eop_aux::pow(double(2), double(s))); + + T c = T(0.5); + + Mat E(A.n_rows, A.n_rows, fill::eye); E += c * A; + Mat D(A.n_rows, A.n_rows, fill::eye); D -= c * A; + + Mat X = A; + + bool positive = true; + + const uword N = 6; + + for(uword i = 2; i <= N; ++i) + { + c = c * T(N - i + 1) / T(i * (2*N - i + 1)); - if( (D.is_finite() == false) || (E.is_finite() == false) ) { return false; } + X = A * X; - const bool status = solve(out, D, E); + E += c * X; - if(status == false) { return false; } + if(positive) { D += c * X; } else { D -= c * X; } - for(uword i=0; i < s; ++i) { out = out * out; } + positive = (positive) ? false : true; } + if( (D.internal_has_nonfinite()) || (E.internal_has_nonfinite()) ) { return false; } + + const bool status = solve(out, D, E, solve_opts::no_approx); + + if(status == false) { return false; } + + for(uword i=0; i < s; ++i) { out = out * out; } + return true; } @@ -148,14 +192,42 @@ op_expmat_sym::apply_direct(Mat& out, const Base U(expr.get_ref()); const Mat& X = U.M; arma_debug_check( (X.is_square() == false), "expmat_sym(): given matrix must be square sized" ); + if((arma_config::debug) && (arma_config::warn_level > 0) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_expmat_sym: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; Mat eigvec; diff --git a/src/armadillo_bits/op_fft_bones.hpp b/src/armadillo_bits/op_fft_bones.hpp index 4efdfbae..b0dcbfdf 100644 --- a/src/armadillo_bits/op_fft_bones.hpp +++ b/src/armadillo_bits/op_fft_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -39,12 +41,8 @@ class op_fft_cx template inline static void apply( Mat& out, const Op& in ); - template - inline static void apply_noalias(Mat& out, const Proxy& P, const uword a, const uword b); - - template arma_hot inline static void copy_vec (typename Proxy::elem_type* dest, const Proxy& P, const uword N); - template arma_hot inline static void copy_vec_proxy (typename Proxy::elem_type* dest, const Proxy& P, const uword N); - template arma_hot inline static void copy_vec_unwrap(typename Proxy::elem_type* dest, const Proxy& P, const uword N); + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword a, const uword b); }; diff --git a/src/armadillo_bits/op_fft_meat.hpp b/src/armadillo_bits/op_fft_meat.hpp index ed69f2ea..4f5d93af 100644 --- a/src/armadillo_bits/op_fft_meat.hpp +++ b/src/armadillo_bits/op_fft_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,12 +21,56 @@ //! @{ +#if defined(ARMA_USE_FFTW3) + +template +class fft_engine_wrapper + { + public: + + static constexpr uword threshold = 512; + + fft_engine_kissfft* worker_kissfft = nullptr; + fft_engine_fftw3 * worker_fftw3 = nullptr; + + inline + ~fft_engine_wrapper() + { + arma_extra_debug_sigprint(); + + if(worker_kissfft != nullptr) { delete worker_kissfft; } + if(worker_fftw3 != nullptr) { delete worker_fftw3; } + } + + inline + fft_engine_wrapper(const uword N_samples, const uword N_exec) + { + arma_extra_debug_sigprint(); + + const bool use_fftw3 = N_samples >= (threshold / N_exec); + + worker_kissfft = (use_fftw3 == false) ? new fft_engine_kissfft(N_samples) : nullptr; + worker_fftw3 = (use_fftw3 == true ) ? new fft_engine_fftw3 (N_samples) : nullptr; + } + + inline + void + run(cx_type* Y, const cx_type* X) + { + arma_extra_debug_sigprint(); + + if(worker_kissfft != nullptr) { (*worker_kissfft).run(Y,X); } + else if(worker_fftw3 != nullptr) { (*worker_fftw3).run(Y,X); } + } + }; + +#endif + // // op_fft_real - template inline void @@ -35,62 +81,43 @@ op_fft_real::apply( Mat< std::complex >& out, const mtOp< typedef typename T1::pod_type in_eT; typedef typename std::complex out_eT; - const Proxy P(in.m); + // no need to worry about aliasing, as we're going from a real object to complex complex, which by definition cannot alias + + const quasi_unwrap U(in.m); + const Mat& X = U.M; - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_elem = P.get_n_elem(); + const uword n_rows = X.n_rows; + const uword n_cols = X.n_cols; + const uword n_elem = X.n_elem; const bool is_vec = ( (n_rows == 1) || (n_cols == 1) ); const uword N_orig = (is_vec) ? n_elem : n_rows; const uword N_user = (in.aux_uword_b == 0) ? in.aux_uword_a : N_orig; - fft_engine worker(N_user); - - // no need to worry about aliasing, as we're going from a real object to complex complex, which by definition cannot alias + #if defined(ARMA_USE_FFTW3) + const uword N_exec = (is_vec) ? uword(1) : n_cols; + fft_engine_wrapper worker(N_user, N_exec); + #else + fft_engine_kissfft worker(N_user); + #endif if(is_vec) { (n_cols == 1) ? out.set_size(N_user, 1) : out.set_size(1, N_user); - if( (out.n_elem == 0) || (N_orig == 0) ) - { - out.zeros(); - return; - } + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } - if( (N_user == 1) && (N_orig >= 1) ) - { - out[0] = out_eT( P[0] ); - return; - } - - podarray data(N_user); + if( (N_user == 1) && (N_orig >= 1) ) { out[0] = out_eT( X[0] ); return; } - out_eT* data_mem = data.memptr(); + podarray data(N_user, arma_zeros_indicator()); - if(N_user > N_orig) { arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); } + out_eT* data_mem = data.memptr(); + const in_eT* X_mem = X.memptr(); const uword N = (std::min)(N_user, N_orig); - if(Proxy::use_at == false) - { - typename Proxy::ea_type X = P.get_ea(); - - for(uword i=0; i < N; ++i) { data_mem[i] = out_eT( X[i], in_eT(0) ); } - } - else - { - if(n_cols == 1) - { - for(uword i=0; i < N; ++i) { data_mem[i] = out_eT( P.at(i,0), in_eT(0) ); } - } - else - { - for(uword i=0; i < N; ++i) { data_mem[i] = out_eT( P.at(0,i), in_eT(0) ); } - } - } + for(uword i=0; i < N; ++i) { data_mem[i].real(X_mem[i]); } worker.run( out.memptr(), data_mem ); } @@ -100,30 +127,24 @@ op_fft_real::apply( Mat< std::complex >& out, const mtOp< out.set_size(N_user, n_cols); - if( (out.n_elem == 0) || (N_orig == 0) ) - { - out.zeros(); - return; - } + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } if( (N_user == 1) && (N_orig >= 1) ) { - for(uword col=0; col < n_cols; ++col) { out.at(0,col) = out_eT( P.at(0,col) ); } + for(uword col=0; col < n_cols; ++col) { out.at(0,col).real( X.at(0,col) ); } return; } - podarray data(N_user); + podarray data(N_user, arma_zeros_indicator()); out_eT* data_mem = data.memptr(); - if(N_user > N_orig) { arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); } - const uword N = (std::min)(N_user, N_orig); for(uword col=0; col < n_cols; ++col) { - for(uword i=0; i < N; ++i) { data_mem[i] = P.at(i, col); } + for(uword i=0; i < N; ++i) { data_mem[i].real( X.at(i, col) ); } worker.run( out.colptr(col), data_mem ); } @@ -145,77 +166,70 @@ op_fft_cx::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; - const Proxy P(in.m); + const quasi_unwrap U(in.m); - if(P.is_alias(out) == false) - { - op_fft_cx::apply_noalias(out, P, in.aux_uword_a, in.aux_uword_b); - } - else + if(U.is_alias(out)) { Mat tmp; - op_fft_cx::apply_noalias(tmp, P, in.aux_uword_a, in.aux_uword_b); + op_fft_cx::apply_noalias(tmp, U.M, in.aux_uword_a, in.aux_uword_b); out.steal_mem(tmp); } + else + { + op_fft_cx::apply_noalias(out, U.M, in.aux_uword_a, in.aux_uword_b); + } } -template +template inline void -op_fft_cx::apply_noalias(Mat& out, const Proxy& P, const uword a, const uword b) +op_fft_cx::apply_noalias(Mat& out, const Mat& X, const uword a, const uword b) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_elem = P.get_n_elem(); + const uword n_rows = X.n_rows; + const uword n_cols = X.n_cols; + const uword n_elem = X.n_elem; const bool is_vec = ( (n_rows == 1) || (n_cols == 1) ); const uword N_orig = (is_vec) ? n_elem : n_rows; const uword N_user = (b == 0) ? a : N_orig; - fft_engine worker(N_user); + #if defined(ARMA_USE_FFTW3) + const uword N_exec = (is_vec) ? uword(1) : n_cols; + fft_engine_wrapper worker(N_user, N_exec); + #else + fft_engine_kissfft worker(N_user); + #endif if(is_vec) { (n_cols == 1) ? out.set_size(N_user, 1) : out.set_size(1, N_user); - if( (out.n_elem == 0) || (N_orig == 0) ) - { - out.zeros(); - return; - } + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } - if( (N_user == 1) && (N_orig >= 1) ) - { - out[0] = P[0]; - return; - } + if( (N_user == 1) && (N_orig >= 1) ) { out[0] = X[0]; return; } - if( (N_user > N_orig) || (is_Mat::stored_type>::value == false) ) + if(N_user > N_orig) { podarray data(N_user); eT* data_mem = data.memptr(); - if(N_user > N_orig) { arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); } + arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); - op_fft_cx::copy_vec( data_mem, P, (std::min)(N_user, N_orig) ); + arrayops::copy(data_mem, X.memptr(), (std::min)(N_user, N_orig)); worker.run( out.memptr(), data_mem ); } else { - const unwrap< typename Proxy::stored_type > tmp(P.Q); - - worker.run( out.memptr(), tmp.M.memptr() ); + worker.run( out.memptr(), X.memptr() ); } } else @@ -224,50 +238,44 @@ op_fft_cx::apply_noalias(Mat& out, const Proxy& P, c out.set_size(N_user, n_cols); - if( (out.n_elem == 0) || (N_orig == 0) ) - { - out.zeros(); - return; - } + if( (out.n_elem == 0) || (N_orig == 0) ) { out.zeros(); return; } if( (N_user == 1) && (N_orig >= 1) ) { - for(uword col=0; col < n_cols; ++col) { out.at(0,col) = P.at(0,col); } + for(uword col=0; col < n_cols; ++col) { out.at(0,col) = X.at(0,col); } return; } - if( (N_user > N_orig) || (is_Mat::stored_type>::value == false) ) + if(N_user > N_orig) { podarray data(N_user); eT* data_mem = data.memptr(); - if(N_user > N_orig) { arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); } + arrayops::fill_zeros( &data_mem[N_orig], (N_user - N_orig) ); const uword N = (std::min)(N_user, N_orig); for(uword col=0; col < n_cols; ++col) { - for(uword i=0; i < N; ++i) { data_mem[i] = P.at(i, col); } + arrayops::copy(data_mem, X.colptr(col), N); worker.run( out.colptr(col), data_mem ); } } else { - const unwrap< typename Proxy::stored_type > tmp(P.Q); - for(uword col=0; col < n_cols; ++col) { - worker.run( out.colptr(col), tmp.M.colptr(col) ); + worker.run( out.colptr(col), X.colptr(col) ); } } } // correct the scaling for the inverse transform - if(inverse == true) + if(inverse) { typedef typename get_pod_type::result T; @@ -283,70 +291,6 @@ op_fft_cx::apply_noalias(Mat& out, const Proxy& P, c -template -arma_hot -inline -void -op_fft_cx::copy_vec(typename Proxy::elem_type* dest, const Proxy& P, const uword N) - { - arma_extra_debug_sigprint(); - - if(is_Mat< typename Proxy::stored_type >::value == true) - { - op_fft_cx::copy_vec_unwrap(dest, P, N); - } - else - { - op_fft_cx::copy_vec_proxy(dest, P, N); - } - } - - - -template -arma_hot -inline -void -op_fft_cx::copy_vec_unwrap(typename Proxy::elem_type* dest, const Proxy& P, const uword N) - { - arma_extra_debug_sigprint(); - - const unwrap< typename Proxy::stored_type > tmp(P.Q); - - arrayops::copy(dest, tmp.M.memptr(), N); - } - - - -template -arma_hot -inline -void -op_fft_cx::copy_vec_proxy(typename Proxy::elem_type* dest, const Proxy& P, const uword N) - { - arma_extra_debug_sigprint(); - - if(Proxy::use_at == false) - { - typename Proxy::ea_type X = P.get_ea(); - - for(uword i=0; i < N; ++i) { dest[i] = X[i]; } - } - else - { - if(P.get_n_cols() == 1) - { - for(uword i=0; i < N; ++i) { dest[i] = P.at(i,0); } - } - else - { - for(uword i=0; i < N; ++i) { dest[i] = P.at(0,i); } - } - } - } - - - // // op_ifft_cx @@ -360,20 +304,20 @@ op_ifft_cx::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; - const Proxy P(in.m); + const quasi_unwrap U(in.m); - if(P.is_alias(out) == false) - { - op_fft_cx::apply_noalias(out, P, in.aux_uword_a, in.aux_uword_b); - } - else + if(U.is_alias(out)) { Mat tmp; - op_fft_cx::apply_noalias(tmp, P, in.aux_uword_a, in.aux_uword_b); + op_fft_cx::apply_noalias(tmp, U.M, in.aux_uword_a, in.aux_uword_b); out.steal_mem(tmp); } + else + { + op_fft_cx::apply_noalias(out, U.M, in.aux_uword_a, in.aux_uword_b); + } } diff --git a/src/armadillo_bits/op_find_bones.hpp b/src/armadillo_bits/op_find_bones.hpp index dbbbbaf3..6e7c9cc5 100644 --- a/src/armadillo_bits/op_find_bones.hpp +++ b/src/armadillo_bits/op_find_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -39,8 +41,8 @@ class op_find ( Mat& indices, const mtOp& X, - const typename arma_op_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0 + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr ); template @@ -49,8 +51,8 @@ class op_find ( Mat& indices, const mtOp& X, - const typename arma_op_rel_only::result junk1 = 0, - const typename arma_cx_only::result junk2 = 0 + const typename arma_op_rel_only::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr ); template @@ -59,9 +61,9 @@ class op_find ( Mat& indices, const mtGlue& X, - const typename arma_glue_rel_only::result junk1 = 0, - const typename arma_not_cx::result junk2 = 0, - const typename arma_not_cx::result junk3 = 0 + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_not_cx::result* junk2 = nullptr, + const typename arma_not_cx::result* junk3 = nullptr ); template @@ -70,9 +72,9 @@ class op_find ( Mat& indices, const mtGlue& X, - const typename arma_glue_rel_only::result junk1 = 0, - const typename arma_cx_only::result junk2 = 0, - const typename arma_cx_only::result junk3 = 0 + const typename arma_glue_rel_only::result* junk1 = nullptr, + const typename arma_cx_only::result* junk2 = nullptr, + const typename arma_cx_only::result* junk3 = nullptr ); template @@ -114,4 +116,15 @@ class op_find_nonfinite +class op_find_nan + : public traits_op_col + { + public: + + template + inline static void apply(Mat& out, const mtOp& X); + }; + + + //! @} diff --git a/src/armadillo_bits/op_find_meat.hpp b/src/armadillo_bits/op_find_meat.hpp index 33991266..b2fd6dde 100644 --- a/src/armadillo_bits/op_find_meat.hpp +++ b/src/armadillo_bits/op_find_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -79,8 +81,8 @@ op_find::helper ( Mat& indices, const mtOp& X, - const typename arma_op_rel_only::result junk1, - const typename arma_not_cx::result junk2 + const typename arma_op_rel_only::result* junk1, + const typename arma_not_cx::result* junk2 ) { arma_extra_debug_sigprint(); @@ -91,6 +93,11 @@ op_find::helper const eT val = X.aux; + if((is_same_type::yes || is_same_type::yes) && arma_config::debug && arma_isnan(val)) + { + arma_debug_warn_level(1, "find(): NaN is not equal to anything; suggest to use find_nonfinite() instead"); + } + const Proxy A(X.m); const uword n_elem = A.get_n_elem(); @@ -137,8 +144,8 @@ op_find::helper else if(is_same_type::yes) { not_zero_j = (tpj != val); } else { not_zero_j = false; } - if(not_zero_i == true) { indices_mem[n_nz] = i; ++n_nz; } - if(not_zero_j == true) { indices_mem[n_nz] = j; ++n_nz; } + if(not_zero_i) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero_j) { indices_mem[n_nz] = j; ++n_nz; } } if(i < n_elem) @@ -159,7 +166,7 @@ op_find::helper else if(is_same_type::yes) { not_zero = (tmp != val); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } } } else @@ -188,7 +195,7 @@ op_find::helper else if(is_same_type::yes) { not_zero = (tmp != val); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } ++i; } @@ -206,8 +213,8 @@ op_find::helper ( Mat& indices, const mtOp& X, - const typename arma_op_rel_only::result junk1, - const typename arma_cx_only::result junk2 + const typename arma_op_rel_only::result* junk1, + const typename arma_cx_only::result* junk2 ) { arma_extra_debug_sigprint(); @@ -219,9 +226,13 @@ op_find::helper const eT val = X.aux; + if((is_same_type::yes || is_same_type::yes) && arma_config::debug && arma_isnan(val)) + { + arma_debug_warn_level(1, "find(): NaN is not equal to anything; suggest to use find_nonfinite() instead"); + } + const Proxy A(X.m); - ea_type PA = A.get_ea(); const uword n_elem = A.get_n_elem(); indices.set_size(n_elem, 1); @@ -232,6 +243,8 @@ op_find::helper if(Proxy::use_at == false) { + ea_type PA = A.get_ea(); + for(uword i=0; i::yes) { not_zero = (tmp != val); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } } } else @@ -263,7 +276,7 @@ op_find::helper else if(is_same_type::yes) { not_zero = (tmp != val); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } i++; } @@ -281,9 +294,9 @@ op_find::helper ( Mat& indices, const mtGlue& X, - const typename arma_glue_rel_only::result junk1, - const typename arma_not_cx::result junk2, - const typename arma_not_cx::result junk3 + const typename arma_glue_rel_only::result* junk1, + const typename arma_not_cx::result* junk2, + const typename arma_not_cx::result* junk3 ) { arma_extra_debug_sigprint(); @@ -302,34 +315,67 @@ op_find::helper arma_debug_assert_same_size(A, B, "relational operator"); - ea_type1 PA = A.get_ea(); - ea_type2 PB = B.get_ea(); - - const uword n_elem = B.get_n_elem(); + const uword n_elem = A.get_n_elem(); indices.set_size(n_elem, 1); uword* indices_mem = indices.memptr(); uword n_nz = 0; - for(uword i=0; i::use_at == false) && (Proxy::use_at == false)) { - const eT1 tmp1 = PA[i]; - const eT2 tmp2 = PB[i]; + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); - bool not_zero; + for(uword i=0; i::yes) { not_zero = (tmp1 < tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 > tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 <= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 >= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 == tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 != tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 && tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 || tmp2); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + } + } + else + { + const uword n_rows = A.get_n_rows(); + const uword n_cols = A.get_n_cols(); - if(is_same_type::yes) { not_zero = (tmp1 < tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 > tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 <= tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 >= tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 == tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 != tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 && tmp2); } - else if(is_same_type::yes) { not_zero = (tmp1 || tmp2); } - else { not_zero = false; } + uword i = 0; - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + const eT1 tmp1 = A.at(row,col); + const eT2 tmp2 = B.at(row,col); + + bool not_zero; + + if(is_same_type::yes) { not_zero = (tmp1 < tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 > tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 <= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 >= tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 == tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 != tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 && tmp2); } + else if(is_same_type::yes) { not_zero = (tmp1 || tmp2); } + else { not_zero = false; } + + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } + + i++; + } } return n_nz; @@ -344,9 +390,9 @@ op_find::helper ( Mat& indices, const mtGlue& X, - const typename arma_glue_rel_only::result junk1, - const typename arma_cx_only::result junk2, - const typename arma_cx_only::result junk3 + const typename arma_glue_rel_only::result* junk1, + const typename arma_cx_only::result* junk2, + const typename arma_cx_only::result* junk3 ) { arma_extra_debug_sigprint(); @@ -362,19 +408,18 @@ op_find::helper arma_debug_assert_same_size(A, B, "relational operator"); - ea_type1 PA = A.get_ea(); - ea_type2 PB = B.get_ea(); - - const uword n_elem = B.get_n_elem(); + const uword n_elem = A.get_n_elem(); indices.set_size(n_elem, 1); uword* indices_mem = indices.memptr(); uword n_nz = 0; - - if(Proxy::use_at == false) + if((Proxy::use_at == false) && (Proxy::use_at == false)) { + ea_type1 PA = A.get_ea(); + ea_type2 PB = B.get_ea(); + for(uword i=0; i::yes) { not_zero = (PA[i] != PB[i]); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } } } else @@ -402,11 +447,11 @@ op_find::helper else if(is_same_type::yes) { not_zero = (A.at(row,col) != B.at(row,col)); } else { not_zero = false; } - if(not_zero == true) { indices_mem[n_nz] = i; ++n_nz; } + if(not_zero) { indices_mem[n_nz] = i; ++n_nz; } i++; } - } + } return n_nz; } @@ -475,11 +520,13 @@ op_find_finite::apply(Mat& out, const mtOp& X) { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "find_finite(): detection of non-finite values is not reliable in fast math mode"); } + const Proxy P(X.m); const uword n_elem = P.get_n_elem(); - Mat indices(n_elem,1); + Mat indices(n_elem, 1, arma_nozeros_indicator()); uword* indices_mem = indices.memptr(); uword count = 0; @@ -521,11 +568,13 @@ op_find_nonfinite::apply(Mat& out, const mtOp P(X.m); const uword n_elem = P.get_n_elem(); - Mat indices(n_elem,1); + Mat indices(n_elem, 1, arma_nozeros_indicator()); uword* indices_mem = indices.memptr(); uword count = 0; @@ -560,4 +609,52 @@ op_find_nonfinite::apply(Mat& out, const mtOp +inline +void +op_find_nan::apply(Mat& out, const mtOp& X) + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "find_nan(): detection of non-finite values is not reliable in fast math mode"); } + + const Proxy P(X.m); + + const uword n_elem = P.get_n_elem(); + + Mat indices(n_elem, 1, arma_nozeros_indicator()); + + uword* indices_mem = indices.memptr(); + uword count = 0; + + if(Proxy::use_at == false) + { + const typename Proxy::ea_type Pea = P.get_ea(); + + for(uword i=0; i& out, const Proxy& P, const bool asc if(n_elem == 0) { out.set_size(0,1); return true; } if(n_elem == 1) { out.set_size(1,1); out[0] = 0; return true; } - uvec indices(n_elem); + uvec indices(n_elem, arma_nozeros_indicator()); std::vector< arma_find_unique_packet > packet_vec(n_elem); diff --git a/src/armadillo_bits/op_flip_bones.hpp b/src/armadillo_bits/op_flip_bones.hpp index 8997ccd2..c81eca15 100644 --- a/src/armadillo_bits/op_flip_bones.hpp +++ b/src/armadillo_bits/op_flip_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_flip_meat.hpp b/src/armadillo_bits/op_flip_meat.hpp index 1a8cce44..470ae923 100644 --- a/src/armadillo_bits/op_flip_meat.hpp +++ b/src/armadillo_bits/op_flip_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,17 +28,32 @@ op_flipud::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); - const Proxy P(in.m); + typedef typename T1::elem_type eT; - if(is_Mat::stored_type>::value || P.is_alias(out)) + if(is_Mat::value) { - const unwrap::stored_type> U(P.Q); + // allow detection of in-place operation + + const unwrap U(in.m); op_flipud::apply_direct(out, U.M); } else { - op_flipud::apply_proxy_noalias(out, P); + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_flipud::apply_proxy_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_flipud::apply_proxy_noalias(out, P); + } } } @@ -121,6 +138,17 @@ op_flipud::apply_proxy_noalias(Mat& out, const Proxy typedef typename T1::elem_type eT; + typedef typename Proxy::stored_type P_stored_type; + + if(is_Mat::value) + { + const unwrap U(P.Q); + + op_flipud::apply_direct(out, U.M); + + return; + } + const uword P_n_rows = P.get_n_rows(); const uword P_n_cols = P.get_n_cols(); @@ -166,17 +194,32 @@ op_fliplr::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); - const Proxy P(in.m); + typedef typename T1::elem_type eT; - if(is_Mat::stored_type>::value || P.is_alias(out)) + if(is_Mat::value) { - const unwrap::stored_type> U(P.Q); + // allow detection of in-place operation + + const unwrap U(in.m); op_fliplr::apply_direct(out, U.M); } else { - op_fliplr::apply_proxy_noalias(out, P); + const Proxy P(in.m); + + if(P.is_alias(out)) + { + Mat tmp; + + op_fliplr::apply_proxy_noalias(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_fliplr::apply_proxy_noalias(out, P); + } } } @@ -250,6 +293,17 @@ op_fliplr::apply_proxy_noalias(Mat& out, const Proxy typedef typename T1::elem_type eT; + typedef typename Proxy::stored_type P_stored_type; + + if(is_Mat::value) + { + const unwrap U(P.Q); + + op_fliplr::apply_direct(out, U.M); + + return; + } + const uword P_n_rows = P.get_n_rows(); const uword P_n_cols = P.get_n_cols(); diff --git a/src/armadillo_bits/op_hist_bones.hpp b/src/armadillo_bits/op_hist_bones.hpp index 478dd1c8..c014ba26 100644 --- a/src/armadillo_bits/op_hist_bones.hpp +++ b/src/armadillo_bits/op_hist_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_hist_meat.hpp b/src/armadillo_bits/op_hist_meat.hpp index c363afc8..04c5ed8f 100644 --- a/src/armadillo_bits/op_hist_meat.hpp +++ b/src/armadillo_bits/op_hist_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,15 +60,21 @@ op_hist::apply_noalias(Mat& out, const Mat& A, const uword n_bins, co if(max_val < val_i) { max_val = val_i; } } + if(min_val == max_val) + { + min_val -= (n_bins/2); + max_val += (n_bins/2); + } + if(arma_isfinite(min_val) == false) { min_val = priv::most_neg(); } if(arma_isfinite(max_val) == false) { max_val = priv::most_pos(); } - Col c(n_bins); + Col c(n_bins, arma_nozeros_indicator()); eT* c_mem = c.memptr(); for(uword ii=0; ii < n_bins; ++ii) { - c_mem[ii] = (0.5 + ii) / double(n_bins); // TODO: may need to be modified for integer matrices + c_mem[ii] = (0.5 + ii) / double(n_bins); } c = ((max_val - min_val) * c) + min_val; @@ -89,17 +97,26 @@ op_hist::apply(Mat& out, const mtOp& X) const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); - if(U.is_alias(out)) + if(is_non_integral::value) { - Mat tmp; - - op_hist::apply_noalias(tmp, U.M, n_bins, dim); - - out.steal_mem(tmp); + if(U.is_alias(out)) + { + Mat tmp; + + op_hist::apply_noalias(tmp, U.M, n_bins, dim); + + out.steal_mem(tmp); + } + else + { + op_hist::apply_noalias(out, U.M, n_bins, dim); + } } else { - op_hist::apply_noalias(out, U.M, n_bins, dim); + Mat converted = conv_to< Mat >::from(U.M); + + op_hist::apply_noalias(out, converted, n_bins, dim); } } diff --git a/src/armadillo_bits/op_htrans_bones.hpp b/src/armadillo_bits/op_htrans_bones.hpp index f084f498..c10f624f 100644 --- a/src/armadillo_bits/op_htrans_bones.hpp +++ b/src/armadillo_bits/op_htrans_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,16 +29,16 @@ class op_htrans template struct traits { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; }; template - arma_hot arma_inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = 0); + arma_hot inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = nullptr); template - arma_hot inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = 0); + arma_hot inline static void apply_mat_noalias(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = nullptr); // @@ -49,36 +51,34 @@ class op_htrans // template - arma_hot arma_inline static void apply_mat_inplace(Mat& out, const typename arma_not_cx::result* junk = 0); + arma_hot inline static void apply_mat_inplace(Mat& out, const typename arma_not_cx::result* junk = nullptr); template - arma_hot inline static void apply_mat_inplace(Mat& out, const typename arma_cx_only::result* junk = 0); + arma_hot inline static void apply_mat_inplace(Mat& out, const typename arma_cx_only::result* junk = nullptr); // template - arma_hot arma_inline static void apply_mat(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = 0); + inline static void apply_mat(Mat& out, const Mat& A, const typename arma_not_cx::result* junk = nullptr); template - arma_hot inline static void apply_mat(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = 0); + inline static void apply_mat(Mat& out, const Mat& A, const typename arma_cx_only::result* junk = nullptr); // template - arma_hot inline static void apply_proxy(Mat& out, const T1& X); + inline static void apply_proxy(Mat& out, const Proxy& P); // template - arma_hot inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = 0); + inline static void apply_direct(Mat& out, const T1& X); template - arma_hot inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = 0); - - // + inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = nullptr); template - arma_hot inline static void apply(Mat& out, const Op< Op, op_htrans>& in); + inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = nullptr); }; @@ -90,29 +90,16 @@ class op_htrans2 template struct traits { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; }; - template - arma_hot inline static void apply_noalias(Mat& out, const Mat& A, const eT val); - - template - arma_hot inline static void apply(Mat& out, const Mat& A, const eT val); - - // - - template - arma_hot inline static void apply_proxy(Mat& out, const T1& X, const typename T1::elem_type val); - - // - template - arma_hot inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = 0); + inline static void apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk = nullptr); template - arma_hot inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = 0); + inline static void apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk = nullptr); }; diff --git a/src/armadillo_bits/op_htrans_meat.hpp b/src/armadillo_bits/op_htrans_meat.hpp index 1ccd2cf1..e03893c8 100644 --- a/src/armadillo_bits/op_htrans_meat.hpp +++ b/src/armadillo_bits/op_htrans_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,8 +22,7 @@ template -arma_hot -arma_inline +inline void op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma_not_cx::result* junk) { @@ -34,7 +35,6 @@ op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma template -arma_hot inline void op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma_cx_only::result* junk) @@ -86,7 +86,6 @@ op_htrans::apply_mat_noalias(Mat& out, const Mat& A, const typename arma template -arma_hot inline void op_htrans::block_worker(std::complex* Y, const std::complex* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols) @@ -107,7 +106,6 @@ op_htrans::block_worker(std::complex* Y, const std::complex* X, const uwor template -arma_hot inline void op_htrans::apply_mat_noalias_large(Mat< std::complex >& out, const Mat< std::complex >& A) @@ -163,8 +161,7 @@ op_htrans::apply_mat_noalias_large(Mat< std::complex >& out, const Mat< std:: template -arma_hot -arma_inline +inline void op_htrans::apply_mat_inplace(Mat& out, const typename arma_not_cx::result* junk) { @@ -177,7 +174,6 @@ op_htrans::apply_mat_inplace(Mat& out, const typename arma_not_cx::resul template -arma_hot inline void op_htrans::apply_mat_inplace(Mat& out, const typename arma_cx_only::result* junk) @@ -221,8 +217,7 @@ op_htrans::apply_mat_inplace(Mat& out, const typename arma_cx_only::resu template -arma_hot -arma_inline +inline void op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_not_cx::result* junk) { @@ -235,7 +230,6 @@ op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_not_cx< template -arma_hot inline void op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_cx_only::result* junk) @@ -256,205 +250,44 @@ op_htrans::apply_mat(Mat& out, const Mat& A, const typename arma_cx_only template -arma_hot inline void -op_htrans::apply_proxy(Mat& out, const T1& X) +op_htrans::apply_proxy(Mat& out, const Proxy& P) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(X); + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); - // allow detection of in-place transpose - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) - { - const unwrap::stored_type> tmp(P.Q); - - op_htrans::apply_mat(out, tmp.M); - } - else + if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + out.set_size(n_cols, n_rows); - const bool is_alias = P.is_alias(out); + eT* out_mem = out.memptr(); - if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) - { - if(is_alias == false) - { - out.set_size(n_cols, n_rows); - - eT* out_mem = out.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - out_mem[i] = std::conj(Pea[i]); - } - } - else // aliasing - { - Mat out2(n_cols, n_rows); - - eT* out_mem = out2.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - out_mem[i] = std::conj(Pea[i]); - } - - out.steal_mem(out2); - } - } - else - { - if(is_alias == false) - { - out.set_size(n_cols, n_rows); - - eT* outptr = out.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - for(uword j=0; j < n_cols; ++j) - { - (*outptr) = std::conj(P.at(k,j)); - - outptr++; - } - } - } - else // aliasing - { - Mat out2(n_cols, n_rows); - - eT* out2ptr = out2.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - for(uword j=0; j < n_cols; ++j) - { - (*out2ptr) = std::conj(P.at(k,j)); - - out2ptr++; - } - } - - out.steal_mem(out2); - } - } - } - } - - - -template -arma_hot -inline -void -op_htrans::apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk) - { - arma_extra_debug_sigprint(); - arma_ignore(junk); - - op_strans::apply_proxy(out, in.m); - } - - - -template -arma_hot -inline -void -op_htrans::apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk) - { - arma_extra_debug_sigprint(); - arma_ignore(junk); - - op_htrans::apply_proxy(out, in.m); - } - - - -template -arma_hot -inline -void -op_htrans::apply(Mat& out, const Op< Op, op_htrans>& in) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const unwrap tmp(in.m.m); - const Mat& A = tmp.M; - - const bool upper = in.m.aux_uword_a; - - op_trimat::apply_htrans(out, A, upper); - } - - - -// -// op_htrans2 - - - -template -arma_hot -arma_inline -void -op_htrans2::apply_noalias(Mat& out, const Mat& A, const eT val) - { - arma_extra_debug_sigprint(); - - const uword A_n_rows = A.n_rows; - const uword A_n_cols = A.n_cols; - - out.set_size(A_n_cols, A_n_rows); - - if( (A_n_cols == 1) || (A_n_rows == 1) ) - { - const uword n_elem = A.n_elem; + const uword n_elem = P.get_n_elem(); - const eT* A_mem = A.memptr(); - eT* out_mem = out.memptr(); + typename Proxy::ea_type Pea = P.get_ea(); for(uword i=0; i < n_elem; ++i) { - out_mem[i] = val * std::conj(A_mem[i]); + out_mem[i] = std::conj(Pea[i]); } } - else - if( (A_n_rows >= 512) && (A_n_cols >= 512) ) - { - op_htrans::apply_mat_noalias_large(out, A); - arrayops::inplace_mul( out.memptr(), val, out.n_elem ); - } else { + out.set_size(n_cols, n_rows); + eT* outptr = out.memptr(); - for(uword k=0; k < A_n_rows; ++k) + for(uword k=0; k < n_rows; ++k) { - const eT* Aptr = &(A.at(k,0)); - - for(uword j=0; j < A_n_cols; ++j) + for(uword j=0; j < n_cols; ++j) { - (*outptr) = val * std::conj(*Aptr); + (*outptr) = std::conj(P.at(k,j)); - Aptr += A_n_rows; outptr++; } } @@ -463,155 +296,58 @@ op_htrans2::apply_noalias(Mat& out, const Mat& A, const eT val) -template -arma_hot -inline -void -op_htrans2::apply(Mat& out, const Mat& A, const eT val) - { - arma_extra_debug_sigprint(); - - if(&out != &A) - { - op_htrans2::apply_noalias(out, A, val); - } - else - { - const uword n_rows = out.n_rows; - const uword n_cols = out.n_cols; - - if(n_rows == n_cols) - { - arma_extra_debug_print("doing in-place hermitian transpose of a square matrix"); - - // TODO: do multiplication while swapping - - for(uword col=0; col < n_cols; ++col) - { - eT* coldata = out.colptr(col); - - out.at(col,col) = std::conj( out.at(col,col) ); - - for(uword row=(col+1); row < n_rows; ++row) - { - const eT val1 = std::conj(coldata[row]); - const eT val2 = std::conj(out.at(col,row)); - - out.at(col,row) = val1; - coldata[row] = val2; - } - } - - arrayops::inplace_mul( out.memptr(), val, out.n_elem ); - } - else - { - Mat tmp; - op_htrans2::apply_noalias(tmp, A, val); - - out.steal_mem(tmp); - } - } - } - - - template -arma_hot inline void -op_htrans2::apply_proxy(Mat& out, const T1& X, const typename T1::elem_type val) +op_htrans::apply_direct(Mat& out, const T1& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(X); - // allow detection of in-place transpose - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) { - const unwrap::stored_type> tmp(P.Q); + const unwrap U(X); - op_htrans2::apply(out, tmp.M, val); + op_htrans::apply_mat(out, U.M); } else { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + const Proxy P(X); const bool is_alias = P.is_alias(out); - if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) + if(is_Mat::stored_type>::value) { - if(is_alias == false) + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) { - out.set_size(n_cols, n_rows); - - eT* out_mem = out.memptr(); + Mat tmp; - const uword n_elem = P.get_n_elem(); + op_htrans::apply_mat_noalias(tmp, U.M); - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - out_mem[i] = val * std::conj(Pea[i]); - } + out.steal_mem(tmp); } - else // aliasing + else { - Mat out2(n_cols, n_rows); - - eT* out_mem = out2.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i < n_elem; ++i) - { - out_mem[i] = val * std::conj(Pea[i]); - } - - out.steal_mem(out2); + op_htrans::apply_mat_noalias(out, U.M); } } else { - if(is_alias == false) + if(is_alias) { - out.set_size(n_cols, n_rows); + Mat tmp; - eT* outptr = out.memptr(); + op_htrans::apply_proxy(tmp, P); - for(uword k=0; k < n_rows; ++k) - { - for(uword j=0; j < n_cols; ++j) - { - (*outptr) = val * std::conj(P.at(k,j)); - - outptr++; - } - } + out.steal_mem(tmp); } - else // aliasing + else { - Mat out2(n_cols, n_rows); - - eT* out2ptr = out2.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - for(uword j=0; j < n_cols; ++j) - { - (*out2ptr) = val * std::conj(P.at(k,j)); - - out2ptr++; - } - } - - out.steal_mem(out2); + op_htrans::apply_proxy(out, P); } } } @@ -620,7 +356,37 @@ op_htrans2::apply_proxy(Mat& out, const T1& X, const typ template -arma_hot +inline +void +op_htrans::apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_strans::apply_direct(out, in.m); + } + + + +template +inline +void +op_htrans::apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + op_htrans::apply_direct(out, in.m); + } + + + +// +// op_htrans2 + + + +template inline void op_htrans2::apply(Mat& out, const Op& in, const typename arma_not_cx::result* junk) @@ -628,13 +394,14 @@ op_htrans2::apply(Mat& out, const Op& in, arma_extra_debug_sigprint(); arma_ignore(junk); - op_strans2::apply_proxy(out, in.m, in.aux); + op_strans::apply_direct(out, in.m); + + arrayops::inplace_mul(out.memptr(), in.aux, out.n_elem); } template -arma_hot inline void op_htrans2::apply(Mat& out, const Op& in, const typename arma_cx_only::result* junk) @@ -642,7 +409,9 @@ op_htrans2::apply(Mat& out, const Op& in, arma_extra_debug_sigprint(); arma_ignore(junk); - op_htrans2::apply_proxy(out, in.m, in.aux); + op_htrans::apply_direct(out, in.m); + + arrayops::inplace_mul(out.memptr(), in.aux, out.n_elem); } diff --git a/src/armadillo_bits/op_index_max_bones.hpp b/src/armadillo_bits/op_index_max_bones.hpp index 6584b569..d226f22d 100644 --- a/src/armadillo_bits/op_index_max_bones.hpp +++ b/src/armadillo_bits/op_index_max_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -38,10 +40,10 @@ class op_index_max inline static void apply(Cube& out, const mtOpCube& in); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // sparse matrices diff --git a/src/armadillo_bits/op_index_max_meat.hpp b/src/armadillo_bits/op_index_max_meat.hpp index 3b7d1010..5034921f 100644 --- a/src/armadillo_bits/op_index_max_meat.hpp +++ b/src/armadillo_bits/op_index_max_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ op_index_max::apply(Mat& out, const mtOp& in) typedef typename T1::elem_type eT; const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1"); + arma_debug_check( (dim > 1), "index_max(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); const Mat& X = U.M; @@ -88,7 +90,7 @@ op_index_max::apply_noalias(Mat& out, const Mat& X, const uword dim) uword* out_mem = out.memptr(); - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); @@ -98,7 +100,7 @@ op_index_max::apply_noalias(Mat& out, const Mat& X, const uword dim) for(uword row=0; row < X_n_rows; ++row) { - tmp_mem[row] = std::abs(col_mem[row]); + tmp_mem[row] = eop_aux::arma_abs(col_mem[row]); } } else @@ -113,7 +115,7 @@ op_index_max::apply_noalias(Mat& out, const Mat& X, const uword dim) for(uword row=0; row < X_n_rows; ++row) { T& max_val = tmp_mem[row]; - T col_val = (is_cx::yes) ? T(std::abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); + T col_val = (is_cx::yes) ? T(eop_aux::arma_abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); if(max_val < col_val) { @@ -195,7 +197,7 @@ op_index_max::apply_noalias(Cube& out, const Cube& X, const uword dim if(out.is_empty() || X.is_empty()) { return; } - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); eT* tmp_mem = tmp.memptr(); @@ -299,7 +301,7 @@ op_index_max::apply_noalias(Cube& out, const Cube& X, const uword dim if(out.is_empty() || X.is_empty()) { return; } - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); @@ -342,7 +344,7 @@ op_index_max::apply_noalias(Cube& out, const Cube& X, const uword dim uword* out_mem = out.memptr(); - Mat tmp(X_n_rows, X_n_cols); + Mat tmp(X_n_rows, X_n_cols, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); const eT* X_slice0_mem = X.slice_memptr(0); diff --git a/src/armadillo_bits/op_index_min_bones.hpp b/src/armadillo_bits/op_index_min_bones.hpp index 236ba990..050b8c0f 100644 --- a/src/armadillo_bits/op_index_min_bones.hpp +++ b/src/armadillo_bits/op_index_min_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -38,10 +40,10 @@ class op_index_min inline static void apply(Cube& out, const mtOpCube& in); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // sparse matrices diff --git a/src/armadillo_bits/op_index_min_meat.hpp b/src/armadillo_bits/op_index_min_meat.hpp index f33476de..13162ab5 100644 --- a/src/armadillo_bits/op_index_min_meat.hpp +++ b/src/armadillo_bits/op_index_min_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ op_index_min::apply(Mat& out, const mtOp& in) typedef typename T1::elem_type eT; const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "index_min(): parameter 'dim' must be 0 or 1"); + arma_debug_check( (dim > 1), "index_min(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); const Mat& X = U.M; @@ -88,7 +90,7 @@ op_index_min::apply_noalias(Mat& out, const Mat& X, const uword dim) uword* out_mem = out.memptr(); - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); @@ -98,7 +100,7 @@ op_index_min::apply_noalias(Mat& out, const Mat& X, const uword dim) for(uword row=0; row < X_n_rows; ++row) { - tmp_mem[row] = std::abs(col_mem[row]); + tmp_mem[row] = eop_aux::arma_abs(col_mem[row]); } } else @@ -113,7 +115,7 @@ op_index_min::apply_noalias(Mat& out, const Mat& X, const uword dim) for(uword row=0; row < X_n_rows; ++row) { T& min_val = tmp_mem[row]; - T col_val = (is_cx::yes) ? T(std::abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); + T col_val = (is_cx::yes) ? T(eop_aux::arma_abs(col_mem[row])) : T(access::tmp_real(col_mem[row])); if(min_val > col_val) { @@ -195,7 +197,7 @@ op_index_min::apply_noalias(Cube& out, const Cube& X, const uword dim if(out.is_empty() || X.is_empty()) { return; } - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); eT* tmp_mem = tmp.memptr(); @@ -299,7 +301,7 @@ op_index_min::apply_noalias(Cube& out, const Cube& X, const uword dim if(out.is_empty() || X.is_empty()) { return; } - Col tmp(X_n_rows); + Col tmp(X_n_rows, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); @@ -342,7 +344,7 @@ op_index_min::apply_noalias(Cube& out, const Cube& X, const uword dim uword* out_mem = out.memptr(); - Mat tmp(X_n_rows, X_n_cols); + Mat tmp(X_n_rows, X_n_cols, arma_nozeros_indicator()); T* tmp_mem = tmp.memptr(); const eT* X_slice0_mem = X.slice_memptr(0); diff --git a/src/armadillo_bits/op_inv_gen_bones.hpp b/src/armadillo_bits/op_inv_gen_bones.hpp new file mode 100644 index 00000000..fe952ed1 --- /dev/null +++ b/src/armadillo_bits/op_inv_gen_bones.hpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_gen +//! @{ + + + +class op_inv_gen_default + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const char* caller_sig); + }; + + + +class op_inv_gen_full + : public traits_op_default + { + public: + + template + struct pos + { + static constexpr uword n2 = row + col*2; + static constexpr uword n3 = row + col*3; + }; + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const char* caller_sig, const uword flags); + + template + arma_cold inline static bool apply_tiny_2x2(Mat& X); + + template + arma_cold inline static bool apply_tiny_3x3(Mat& X); + }; + + + +template +struct op_inv_gen_state + { + uword size = uword(0); + T rcond = T(0); + bool is_diag = false; + bool is_sym = false; + }; + + + +class op_inv_gen_rcond + : public traits_op_default + { + public: + + template + inline static bool apply_direct(Mat& out_inv, op_inv_gen_state& out_state, const Base& expr); + }; + + + +namespace inv_opts + { + struct opts + { + const uword flags; + + inline constexpr explicit opts(const uword in_flags); + + inline const opts operator+(const opts& rhs) const; + }; + + inline + constexpr + opts::opts(const uword in_flags) + : flags(in_flags) + {} + + inline + const opts + opts::operator+(const opts& rhs) const + { + const opts result( flags | rhs.flags ); + + return result; + } + + // The values below (eg. 1u << 1) are for internal Armadillo use only. + // The values can change without notice. + + static constexpr uword flag_none = uword(0 ); + static constexpr uword flag_fast = uword(1u << 0); + static constexpr uword flag_tiny = uword(1u << 0); // deprecated + static constexpr uword flag_allow_approx = uword(1u << 1); + static constexpr uword flag_likely_sympd = uword(1u << 2); // deprecated + static constexpr uword flag_no_sympd = uword(1u << 3); // deprecated + static constexpr uword flag_no_ugly = uword(1u << 4); + + struct opts_none : public opts { inline constexpr opts_none() : opts(flag_none ) {} }; + struct opts_fast : public opts { inline constexpr opts_fast() : opts(flag_fast ) {} }; + struct opts_tiny : public opts { inline constexpr opts_tiny() : opts(flag_tiny ) {} }; + struct opts_allow_approx : public opts { inline constexpr opts_allow_approx() : opts(flag_allow_approx) {} }; + struct opts_likely_sympd : public opts { inline constexpr opts_likely_sympd() : opts(flag_likely_sympd) {} }; + struct opts_no_sympd : public opts { inline constexpr opts_no_sympd() : opts(flag_no_sympd ) {} }; + struct opts_no_ugly : public opts { inline constexpr opts_no_ugly() : opts(flag_no_ugly ) {} }; + + static constexpr opts_none none; + static constexpr opts_fast fast; + static constexpr opts_tiny tiny; + static constexpr opts_allow_approx allow_approx; + static constexpr opts_likely_sympd likely_sympd; + static constexpr opts_no_sympd no_sympd; + static constexpr opts_no_ugly no_ugly; + } + + + +//! @} diff --git a/src/armadillo_bits/op_inv_gen_meat.hpp b/src/armadillo_bits/op_inv_gen_meat.hpp new file mode 100644 index 00000000..a7585d71 --- /dev/null +++ b/src/armadillo_bits/op_inv_gen_meat.hpp @@ -0,0 +1,428 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_gen +//! @{ + + + +template +inline +void +op_inv_gen_default::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_gen_default::apply_direct(out, X.m, "inv()"); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv(): matrix is singular"); + } + } + + + +template +inline +bool +op_inv_gen_default::apply_direct(Mat& out, const Base& expr, const char* caller_sig) + { + arma_extra_debug_sigprint(); + + return op_inv_gen_full::apply_direct(out, expr, caller_sig, uword(0)); + } + + + +// + + + +template +inline +void +op_inv_gen_full::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const uword flags = X.aux_uword_a; + + const bool status = op_inv_gen_full::apply_direct(out, X.m, "inv()", flags); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv(): matrix is singular"); + } + } + + + +template +inline +bool +op_inv_gen_full::apply_direct(Mat& out, const Base& expr, const char* caller_sig, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(has_user_flags == true ) { arma_extra_debug_print("op_inv_gen_full: has_user_flags = true"); } + if(has_user_flags == false) { arma_extra_debug_print("op_inv_gen_full: has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & inv_opts::flag_fast ); + const bool allow_approx = has_user_flags && bool(flags & inv_opts::flag_allow_approx); + const bool no_ugly = has_user_flags && bool(flags & inv_opts::flag_no_ugly ); + + if(has_user_flags) + { + arma_extra_debug_print("op_inv_gen_full: enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(allow_approx) { arma_extra_debug_print("allow_approx"); } + if(no_ugly ) { arma_extra_debug_print("no_ugly"); } + + arma_debug_check( (fast && allow_approx), "inv(): options 'fast' and 'allow_approx' are mutually exclusive" ); + arma_debug_check( (fast && no_ugly ), "inv(): options 'fast' and 'no_ugly' are mutually exclusive" ); + arma_debug_check( (no_ugly && allow_approx), "inv(): options 'no_ugly' and 'allow_approx' are mutually exclusive" ); + } + + if(no_ugly) + { + op_inv_gen_state inv_state; + + const bool status = op_inv_gen_rcond::apply_direct(out, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { return false; } + + return true; + } + + if(allow_approx) + { + op_inv_gen_state inv_state; + + Mat tmp; + + const bool status = op_inv_gen_rcond::apply_direct(tmp, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) + { + Mat A = expr.get_ref(); + + if(inv_state.is_diag) { return op_pinv::apply_diag(out, A, T(0) ); } + if(inv_state.is_sym ) { return op_pinv::apply_sym (out, A, T(0), uword(0)); } + + return op_pinv::apply_gen(out, A, T(0), uword(0)); + } + + out.steal_mem(tmp); + + return true; + } + + out = expr.get_ref(); + + arma_debug_check( (out.is_square() == false), caller_sig, ": given matrix must be square sized", [&](){ out.soft_reset(); } ); + + const uword N = out.n_rows; + + if(N == 0) { return true; } + + if(is_cx::no) + { + if(N == 1) + { + const eT a = out[0]; + + out[0] = eT(1) / a; + + return (a != eT(0)); + } + else + if(N == 2) + { + const bool status = op_inv_gen_full::apply_tiny_2x2(out); + + if(status) { return true; } + } + else + if(N == 3) + { + const bool status = op_inv_gen_full::apply_tiny_3x3(out); + + if(status) { return true; } + } + + // fallthrough if optimisation failed + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_gen_full: detected diagonal matrix"); + + eT* colmem = out.memptr(); + + for(uword i=0; i strip(expr.get_ref()); + + const bool is_triu_expr = strip.do_triu; + const bool is_tril_expr = strip.do_tril; + + const bool is_triu_mat = (is_triu_expr || is_tril_expr) ? false : ( trimat_helper::is_triu(out)); + const bool is_tril_mat = (is_triu_expr || is_tril_expr) ? false : ((is_triu_mat) ? false : trimat_helper::is_tril(out)); + + if(is_triu_expr || is_tril_expr || is_triu_mat || is_tril_mat) + { + return auxlib::inv_tr(out, ((is_triu_expr || is_triu_mat) ? uword(0) : uword(1))); + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(out); + + if(try_sympd) + { + arma_extra_debug_print("op_inv_gen_full: attempting sympd optimisation"); + + Mat tmp = out; + + bool sympd_state = false; + + const bool status = auxlib::inv_sympd(tmp, sympd_state); + + if(status) { out.steal_mem(tmp); return true; } + + if((status == false) && (sympd_state == true)) { return false; } + + arma_extra_debug_print("op_inv_gen_full: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + return auxlib::inv(out); + } + + + +template +inline +bool +op_inv_gen_full::apply_tiny_2x2(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT* Xm = X.memptr(); + + const eT a = Xm[pos<0,0>::n2]; + const eT b = Xm[pos<0,1>::n2]; + const eT c = Xm[pos<1,0>::n2]; + const eT d = Xm[pos<1,1>::n2]; + + const eT det_val = (a*d - b*c); + const T abs_det_val = std::abs(det_val); + + if((abs_det_val < det_min) || (abs_det_val > det_max) || arma_isnan(det_val)) { return false; } + + Xm[pos<0,0>::n2] = d / det_val; + Xm[pos<0,1>::n2] = -b / det_val; + Xm[pos<1,0>::n2] = -c / det_val; + Xm[pos<1,1>::n2] = a / det_val; + + return true; + } + + + +template +inline +bool +op_inv_gen_full::apply_tiny_3x3(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + Mat Y(3, 3, arma_nozeros_indicator()); + + eT* Xm = X.memptr(); + eT* Ym = Y.memptr(); + + const eT det_val = op_det::apply_tiny_3x3(X); + const T abs_det_val = std::abs(det_val); + + if((abs_det_val < det_min) || (abs_det_val > det_max) || arma_isnan(det_val)) { return false; } + + Ym[pos<0,0>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<1,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<1,2>::n3]) / det_val; + Ym[pos<1,0>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,2>::n3]) / det_val; + Ym[pos<2,0>::n3] = (Xm[pos<2,1>::n3]*Xm[pos<1,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<1,1>::n3]) / det_val; + + Ym[pos<0,1>::n3] = -(Xm[pos<2,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<2,1>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<1,1>::n3] = (Xm[pos<2,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<2,1>::n3] = -(Xm[pos<2,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<2,0>::n3]*Xm[pos<0,1>::n3]) / det_val; + + Ym[pos<0,2>::n3] = (Xm[pos<1,2>::n3]*Xm[pos<0,1>::n3] - Xm[pos<1,1>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<1,2>::n3] = -(Xm[pos<1,2>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,2>::n3]) / det_val; + Ym[pos<2,2>::n3] = (Xm[pos<1,1>::n3]*Xm[pos<0,0>::n3] - Xm[pos<1,0>::n3]*Xm[pos<0,1>::n3]) / det_val; + + const eT check_val = Xm[pos<0,0>::n3]*Ym[pos<0,0>::n3] + Xm[pos<0,1>::n3]*Ym[pos<1,0>::n3] + Xm[pos<0,2>::n3]*Ym[pos<2,0>::n3]; + + const T max_diff = (is_float::value) ? T(1e-4) : T(1e-10); // empirically determined; may need tuning + + if(std::abs(T(1) - check_val) >= max_diff) { return false; } + + arrayops::copy(Xm, Ym, uword(3*3)); + + return true; + } + + + +template +inline +bool +op_inv_gen_rcond::apply_direct(Mat& out, op_inv_gen_state& out_state, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out = expr.get_ref(); + out_state.size = out.n_rows; + out_state.rcond = T(0); + + arma_debug_check( (out.is_square() == false), "inv(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_gen_rcond: detected diagonal matrix"); + + out_state.is_diag = true; + + eT* colmem = out.memptr(); + + T max_abs_src_val = T(0); + T max_abs_inv_val = T(0); + + const uword N = out.n_rows; + + for(uword i=0; i max_abs_src_val) ? abs_src_val : max_abs_src_val; + max_abs_inv_val = (abs_inv_val > max_abs_inv_val) ? abs_inv_val : max_abs_inv_val; + + colmem += N; + } + + out_state.rcond = T(1) / (max_abs_src_val * max_abs_inv_val); + + return true; + } + + const strip_trimat strip(expr.get_ref()); + + const bool is_triu_expr = strip.do_triu; + const bool is_tril_expr = strip.do_tril; + + const bool is_triu_mat = (is_triu_expr || is_tril_expr) ? false : ( trimat_helper::is_triu(out)); + const bool is_tril_mat = (is_triu_expr || is_tril_expr) ? false : ((is_triu_mat) ? false : trimat_helper::is_tril(out)); + + if(is_triu_expr || is_tril_expr || is_triu_mat || is_tril_mat) + { + return auxlib::inv_tr_rcond(out, out_state.rcond, ((is_triu_expr || is_triu_mat) ? uword(0) : uword(1))); + } + + const bool try_sympd = arma_config::optimise_sym && ((auxlib::crippled_lapack(out)) ? false : sym_helper::guess_sympd(out)); + + if(try_sympd) + { + arma_extra_debug_print("op_inv_gen_rcond: attempting sympd optimisation"); + + out_state.is_sym = true; + + Mat tmp = out; + + bool sympd_state = false; + + const bool status = auxlib::inv_sympd_rcond(tmp, sympd_state, out_state.rcond); + + if(status) { out.steal_mem(tmp); return true; } + + if((status == false) && (sympd_state == true)) { return false; } + + arma_extra_debug_print("op_inv_gen_rcond: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + return auxlib::inv_rcond(out, out_state.rcond); + } + + + +//! @} diff --git a/src/armadillo_bits/op_inv_meat.hpp b/src/armadillo_bits/op_inv_meat.hpp deleted file mode 100644 index 983c8f3d..00000000 --- a/src/armadillo_bits/op_inv_meat.hpp +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -//! \addtogroup op_inv -//! @{ - - - -template -inline -void -op_inv::apply(Mat& out, const Op& X) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const strip_diagmat strip(X.m); - - if(strip.do_diagmat) - { - op_inv::apply_diagmat(out, strip.M); - } - else - { - const quasi_unwrap U(X.m); - - if(U.is_alias(out)) - { - Mat tmp; - - op_inv::apply_noalias(tmp, U.M); - - out.steal_mem(tmp); - } - else - { - op_inv::apply_noalias(out, U.M); - } - } - } - - - -template -inline -void -op_inv::apply_noalias(Mat& out, const Mat& A) - { - arma_extra_debug_sigprint(); - - arma_debug_check( (A.n_rows != A.n_cols), "inv(): given matrix must be square sized" ); - - bool status = false; - - if(A.n_rows <= 4) - { - status = auxlib::inv_tiny(out, A); - } - else - if(sympd_helper::guess_sympd(A)) - { - status = auxlib::inv_sympd(out, A); - } - - if(status == false) - { - status = auxlib::inv(out, A); - } - - if(status == false) - { - out.soft_reset(); - arma_stop_runtime_error("inv(): matrix seems singular"); - } - } - - - -template -inline -void -op_inv::apply_diagmat(Mat& out, const T1& X) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const diagmat_proxy A(X); - - arma_debug_check( (A.n_rows != A.n_cols), "inv(): given matrix must be square sized" ); - - const uword N = (std::min)(A.n_rows, A.n_cols); - - bool status = true; - - if(A.is_alias(out) == false) - { - out.zeros(N,N); - - for(uword i=0; i tmp(N, N, fill::zeros); - - for(uword i=0; i -inline -void -op_inv_tr::apply(Mat& out, const Op& X) - { - arma_extra_debug_sigprint(); - - const bool status = auxlib::inv_tr(out, X.m, X.aux_uword_a); - - if(status == false) - { - out.soft_reset(); - arma_stop_runtime_error("inv(): matrix seems singular"); - } - } - - - -template -inline -void -op_inv_sympd::apply(Mat& out, const Op& X) - { - arma_extra_debug_sigprint(); - - const bool status = auxlib::inv_sympd(out, X.m); - - if(status == false) - { - out.soft_reset(); - arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); - } - } - - - -//! @} diff --git a/src/armadillo_bits/op_inv_spd_bones.hpp b/src/armadillo_bits/op_inv_spd_bones.hpp new file mode 100644 index 00000000..85a50132 --- /dev/null +++ b/src/armadillo_bits/op_inv_spd_bones.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_spd +//! @{ + + + +class op_inv_spd_default + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + +class op_inv_spd_full + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + + template + inline static bool apply_direct(Mat& out, const Base& expr, const uword flags); + + template + arma_cold inline static bool apply_tiny_2x2(Mat& X); + }; + + + +template +struct op_inv_spd_state + { + uword size = uword(0); + T rcond = T(0); + bool is_diag = false; + }; + + + +class op_inv_spd_rcond + : public traits_op_default + { + public: + + template + inline static bool apply_direct(Mat& out_inv, op_inv_spd_state& out_state, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_inv_spd_meat.hpp b/src/armadillo_bits/op_inv_spd_meat.hpp new file mode 100644 index 00000000..0c609745 --- /dev/null +++ b/src/armadillo_bits/op_inv_spd_meat.hpp @@ -0,0 +1,365 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_inv_spd +//! @{ + + + +template +inline +void +op_inv_spd_default::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const bool status = op_inv_spd_default::apply_direct(out, X.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); + } + } + + + +template +inline +bool +op_inv_spd_default::apply_direct(Mat& out, const Base& expr) + { + arma_extra_debug_sigprint(); + + return op_inv_spd_full::apply_direct(out, expr, uword(0)); + } + + + +// + + + +template +inline +void +op_inv_spd_full::apply(Mat& out, const Op& X) + { + arma_extra_debug_sigprint(); + + const uword flags = X.aux_uword_a; + + const bool status = op_inv_spd_full::apply_direct(out, X.m, flags); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("inv_sympd(): matrix is singular or not positive definite"); + } + } + + + +template +inline +bool +op_inv_spd_full::apply_direct(Mat& out, const Base& expr, const uword flags) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(has_user_flags == true ) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = true"); } + if(has_user_flags == false) { arma_extra_debug_print("op_inv_spd_full: has_user_flags = false"); } + + const bool fast = has_user_flags && bool(flags & inv_opts::flag_fast ); + const bool allow_approx = has_user_flags && bool(flags & inv_opts::flag_allow_approx); + const bool no_ugly = has_user_flags && bool(flags & inv_opts::flag_no_ugly ); + + if(has_user_flags) + { + arma_extra_debug_print("op_inv_spd_full: enabled flags:"); + + if(fast ) { arma_extra_debug_print("fast"); } + if(allow_approx) { arma_extra_debug_print("allow_approx"); } + if(no_ugly ) { arma_extra_debug_print("no_ugly"); } + + arma_debug_check( (fast && allow_approx), "inv_sympd(): options 'fast' and 'allow_approx' are mutually exclusive" ); + arma_debug_check( (fast && no_ugly ), "inv_sympd(): options 'fast' and 'no_ugly' are mutually exclusive" ); + arma_debug_check( (no_ugly && allow_approx), "inv_sympd(): options 'no_ugly' and 'allow_approx' are mutually exclusive" ); + } + + if(no_ugly) + { + op_inv_spd_state inv_state; + + const bool status = op_inv_spd_rcond::apply_direct(out, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) { return false; } + + return true; + } + + if(allow_approx) + { + op_inv_spd_state inv_state; + + Mat tmp; + + const bool status = op_inv_spd_rcond::apply_direct(tmp, inv_state, expr); + + // workaround for bug in gcc 4.8 + const uword local_size = inv_state.size; + const T local_rcond = inv_state.rcond; + + if((status == false) || (local_rcond < ((std::max)(local_size, uword(1)) * std::numeric_limits::epsilon())) || arma_isnan(local_rcond)) + { + const Mat A = expr.get_ref(); + + if(inv_state.is_diag) { return op_pinv::apply_diag(out, A, T(0)); } + + return op_pinv::apply_sym(out, A, T(0), uword(0)); + } + + out.steal_mem(tmp); + + return true; + } + + out = expr.get_ref(); + + arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if((arma_config::debug) && (arma_config::warn_level > 0)) + { + if(auxlib::rudimentary_sym_check(out) == false) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + else + if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + } + + const uword N = out.n_rows; + + if(N == 0) { return true; } + + if(is_cx::no) + { + if(N == 1) + { + const T a = access::tmp_real(out[0]); + + out[0] = eT(T(1) / a); + + return (a > T(0)); + } + else + if(N == 2) + { + const bool status = op_inv_spd_full::apply_tiny_2x2(out); + + if(status) { return true; } + } + + // fallthrough if optimisation failed + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_spd_full: detected diagonal matrix"); + + eT* colmem = out.memptr(); + + for(uword i=0; i +inline +bool +op_inv_spd_full::apply_tiny_2x2(Mat& X) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + // NOTE: assuming matrix X is square sized + // NOTE: assuming matrix X is symmetric + // NOTE: assuming matrix X is real + + constexpr T det_min = std::numeric_limits::epsilon(); + constexpr T det_max = T(1) / std::numeric_limits::epsilon(); + + eT* Xm = X.memptr(); + + T a = access::tmp_real(Xm[0]); + T c = access::tmp_real(Xm[1]); + T d = access::tmp_real(Xm[3]); + + const T det_val = (a*d - c*c); + + // positive definite iff all leading principal minors are positive + // a = first leading principal minor (top-left 1x1 submatrix) + // det_val = second leading principal minor (top-left 2x2 submatrix) + + if(a <= T(0)) { return false; } + + // NOTE: since det_min is positive, this also checks whether det_val is positive + if((det_val < det_min) || (det_val > det_max) || arma_isnan(det_val)) { return false; } + + d /= det_val; + c /= det_val; + a /= det_val; + + Xm[0] = d; + Xm[1] = -c; + Xm[2] = -c; + Xm[3] = a; + + return true; + } + + + +// + + + +template +inline +bool +op_inv_spd_rcond::apply_direct(Mat& out, op_inv_spd_state& out_state, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + out = expr.get_ref(); + out_state.size = out.n_rows; + out_state.rcond = T(0); + + arma_debug_check( (out.is_square() == false), "inv_sympd(): given matrix must be square sized", [&](){ out.soft_reset(); } ); + + if((arma_config::debug) && (arma_config::warn_level > 0)) + { + if(auxlib::rudimentary_sym_check(out) == false) + { + if(is_cx::no ) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "inv_sympd(): given matrix is not hermitian"); } + } + else + if((is_cx::yes) && (sym_helper::check_diag_imag(out) == false)) + { + arma_debug_warn_level(1, "inv_sympd(): imaginary components on diagonal are non-zero"); + } + } + + if(is_op_diagmat::value || out.is_diagmat()) + { + arma_extra_debug_print("op_inv_spd_rcond: detected diagonal matrix"); + + out_state.is_diag = true; + + eT* colmem = out.memptr(); + + T max_abs_src_val = T(0); + T max_abs_inv_val = T(0); + + const uword N = out.n_rows; + + for(uword i=0; i max_abs_src_val) ? abs_src_val : max_abs_src_val; + max_abs_inv_val = (abs_inv_val > max_abs_inv_val) ? abs_inv_val : max_abs_inv_val; + + colmem += N; + } + + out_state.rcond = T(1) / (max_abs_src_val * max_abs_inv_val); + + return true; + } + + if(auxlib::crippled_lapack(out)) + { + arma_extra_debug_print("op_inv_spd_rcond: workaround for crippled lapack"); + + Mat tmp = out; + + bool sympd_state = false; + + auxlib::inv_sympd(out, sympd_state); + + if(sympd_state == false) { out.soft_reset(); out_state.rcond = T(0); return false; } + + out_state.rcond = auxlib::rcond(tmp); + + if(out_state.rcond == T(0)) { out.soft_reset(); return false; } + + return true; + } + + bool is_sympd_junk = false; + + return auxlib::inv_sympd_rcond(out, is_sympd_junk, out_state.rcond); + } + + + +//! @} diff --git a/src/armadillo_bits/op_log_det_bones.hpp b/src/armadillo_bits/op_log_det_bones.hpp new file mode 100644 index 00000000..e2f3daf0 --- /dev/null +++ b/src/armadillo_bits/op_log_det_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_log_det +//! @{ + + + +class op_log_det + : public traits_op_default + { + public: + + template + inline static bool apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + + template + inline static bool apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + + template + inline static bool apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr); + }; + + + +class op_log_det_sympd + : public traits_op_default + { + public: + + template + inline static bool apply_direct(typename T1::pod_type& out_val, const Base& expr); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_log_det_meat.hpp b/src/armadillo_bits/op_log_det_meat.hpp new file mode 100644 index 00000000..7b888592 --- /dev/null +++ b/src/armadillo_bits/op_log_det_meat.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_log_det +//! @{ + + + +template +inline +bool +op_log_det::apply_direct(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + // typedef typename T1::pod_type T; + + if(strip_diagmat::do_diagmat) + { + const strip_diagmat strip(expr.get_ref()); + + return op_log_det::apply_diagmat(out_val, out_sign, strip.M); + } + + if(strip_trimat::do_trimat) + { + const strip_trimat strip(expr.get_ref()); + + return op_log_det::apply_trimat(out_val, out_sign, strip.M); + } + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "log_det(): given matrix must be square sized" ); + + if(A.is_diagmat()) { return op_log_det::apply_diagmat(out_val, out_sign, A); } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = is_triu ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) { return op_log_det::apply_trimat(out_val, out_sign, A); } + + // const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + // + // if(try_sympd) + // { + // arma_extra_debug_print("op_log_det: attempting sympd optimisation"); + // + // T out_val_real = T(0); + // + // const bool status = auxlib::log_det_sympd(out_val_real, A); + // + // if(status) + // { + // out_val = eT(out_val_real); + // out_sign = T(1); + // + // return true; + // } + // + // arma_extra_debug_print("op_log_det: sympd optimisation failed"); + // + // // restore A as it's destroyed by auxlib::log_det_sympd() + // A = expr.get_ref(); + // + // // fallthrough to the next return statement + // } + + return auxlib::log_det(out_val, out_sign, A); + } + + + +template +inline +bool +op_log_det::apply_diagmat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const diagmat_proxy A(expr.get_ref()); + + arma_debug_check( (A.n_rows != A.n_cols), "log_det(): given matrix must be square sized" ); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + if(N == 0) + { + out_val = eT(0); + out_sign = T(1); + + return true; + } + + eT x = A[0]; + + T sign = (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + eT val = (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + + for(uword i=1; i::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + } + + out_val = val; + out_sign = sign; + + return (arma_isnan(out_val) == false); + } + + + +template +inline +bool +op_log_det::apply_trimat(typename T1::elem_type& out_val, typename T1::pod_type& out_sign, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const Proxy P(expr.get_ref()); + + const uword N = P.get_n_rows(); + + arma_debug_check( (N != P.get_n_cols()), "log_det(): given matrix must be square sized" ); + + if(N == 0) + { + out_val = eT(0); + out_sign = T(1); + + return true; + } + + eT x = P.at(0,0); + + T sign = (is_cx::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + eT val = (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + + for(uword i=1; i::no) ? ( (access::tmp_real(x) < T(0)) ? T(-1) : T(1) ) : T(1); + val += (is_cx::no) ? std::log( (access::tmp_real(x) < T(0)) ? x*T(-1) : x ) : std::log(x); + } + + out_val = val; + out_sign = sign; + + return (arma_isnan(out_val) == false); + } + + + +// + + + +template +inline +bool +op_log_det_sympd::apply_direct(typename T1::pod_type& out_val, const Base& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + Mat A(expr.get_ref()); + + arma_debug_check( (A.is_square() == false), "log_det_sympd(): given matrix must be square sized" ); + + if((arma_config::debug) && (arma_config::warn_level > 0) && (is_cx::yes) && (sym_helper::check_diag_imag(A) == false)) + { + arma_debug_warn_level(1, "log_det_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_log_det_sympd: detected diagonal matrix"); + + eT* colmem = A.memptr(); + + out_val = T(0); + + const uword N = A.n_rows; + + for(uword i=0; i::no ) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "log_det_sympd(): given matrix is not hermitian"); } + } + + return auxlib::log_det_sympd(out_val, A); + } + + + +//! @} diff --git a/src/armadillo_bits/op_logmat_bones.hpp b/src/armadillo_bits/op_logmat_bones.hpp index 1fc37a0c..77e967d7 100644 --- a/src/armadillo_bits/op_logmat_bones.hpp +++ b/src/armadillo_bits/op_logmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_logmat_meat.hpp b/src/armadillo_bits/op_logmat_meat.hpp index 6be7a92f..494f747a 100644 --- a/src/armadillo_bits/op_logmat_meat.hpp +++ b/src/armadillo_bits/op_logmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,7 +60,7 @@ op_logmat::apply_direct(Mat< std::complex >& out, const const uword N = P.n_rows; - out.zeros(N,N); + out.zeros(N,N); // aliasing can't happen as op_logmat is defined as cx_mat = op(mat) for(uword i=0; i >& out, const typedef typename T1::elem_type in_T; typedef typename std::complex out_T; - const Proxy P(expr.get_ref()); + const quasi_unwrap expr_unwrap(expr.get_ref()); + const Mat& A = expr_unwrap.M; - arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "logmat(): given matrix must be square sized" ); + arma_debug_check( (A.is_square() == false), "logmat(): given matrix must be square sized" ); - if(P.get_n_elem() == 0) + if(A.n_elem == 0) { out.reset(); return true; } else - if(P.get_n_elem() == 1) + if(A.n_elem == 1) { out.set_size(1,1); - out[0] = std::log( std::complex( P[0] ) ); + out[0] = std::log( std::complex( A[0] ) ); return true; } - typename Proxy::ea_type Pea = P.get_ea(); + if(A.is_diagmat()) + { + arma_extra_debug_print("op_logmat: detected diagonal matrix"); + + const uword N = A.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_logmat is defined as cx_mat = op(mat) + + for(uword i=0; i= in_T(0)) + { + out.at(i,i) = std::log(val); + } + else + { + out.at(i,i) = std::log( out_T(val) ); + } + } + + return true; + } - Mat U; - Mat S(P.get_n_rows(), P.get_n_rows()); + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); - out_T* Smem = S.memptr(); + if(try_sympd) + { + arma_extra_debug_print("op_logmat: attempting sympd optimisation"); + + // if matrix A is sympd, all its eigenvalues are positive + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "logmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const in_T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >::from( eigvec * diagmat(eigval) * eigvec.t() ); + + return true; + } + } + + arma_extra_debug_print("op_logmat: sympd optimisation failed"); + + // fallthrough if eigen decomposition failed or an eigenvalue is <= 0 + } + + + Mat S(A.n_rows, A.n_cols, arma_nozeros_indicator()); + + const in_T* Amem = A.memptr(); + out_T* Smem = S.memptr(); - const uword n_elem = P.get_n_elem(); + const uword n_elem = A.n_elem; for(uword i=0; i( Pea[i] ); + Smem[i] = std::complex( Amem[i] ); } return op_logmat_cx::apply_common(out, S, n_iters); @@ -204,6 +270,7 @@ op_logmat_cx::apply_direct(Mat& out, const Base S = expr.get_ref(); @@ -223,6 +290,58 @@ op_logmat_cx::apply_direct(Mat& out, const Base eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, S, 'd', "logmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >& out, Mat< std::complex >& if(schur_ok == false) { arma_extra_debug_print("logmat(): schur decomposition failed"); return false; } -//double theta[] = { 1.10e-5, 1.82e-3, 1.62e-2, 5.39e-2, 1.14e-1, 1.87e-1, 2.64e-1 }; - double theta[] = { 0.0, 0.0, 1.6206284795015624e-2, 5.3873532631381171e-2, 1.1352802267628681e-1, 1.8662860613541288e-1, 2.642960831111435e-1 }; - // theta[0] and theta[1] not really used + // NOTE: theta[0] and theta[1] not really used + double theta[] = { 1.10e-5, 1.82e-3, 1.6206284795015624e-2, 5.3873532631381171e-2, 1.1352802267628681e-1, 1.8662860613541288e-1, 2.642960831111435e-1 }; const uword N = S.n_rows; @@ -281,7 +399,7 @@ op_logmat_cx::apply_common(Mat< std::complex >& out, Mat< std::complex >& iter++; } - if(iter >= n_iters) { arma_debug_warn("logmat(): reached max iterations without full convergence"); } + if(iter >= n_iters) { arma_debug_warn_level(2, "logmat(): reached max iterations without full convergence"); } S.diag() -= eT(1); @@ -308,11 +426,11 @@ op_logmat_cx::helper(Mat& A, const uword m) { arma_extra_debug_sigprint(); - if(A.is_finite() == false) { return false; } + if(A.internal_has_nonfinite()) { return false; } const vec indices = regspace(1,m-1); - mat tmp(m,m,fill::zeros); + mat tmp(m, m, arma_zeros_indicator()); tmp.diag(-1) = indices / sqrt(square(2.0*indices) - 1.0); tmp.diag(+1) = indices / sqrt(square(2.0*indices) - 1.0); @@ -329,7 +447,7 @@ op_logmat_cx::helper(Mat& A, const uword m) const uword N = A.n_rows; - Mat B(N,N,fill::zeros); + Mat B(N, N, arma_zeros_indicator()); Mat X; @@ -338,7 +456,7 @@ op_logmat_cx::helper(Mat& A, const uword m) // B += weights(i) * solve( (nodes(i)*A + eye< Mat >(N,N)), A ); //const bool solve_ok = solve( X, (nodes(i)*A + eye< Mat >(N,N)), A, solve_opts::fast ); - const bool solve_ok = solve( X, trimatu(nodes(i)*A + eye< Mat >(N,N)), A ); + const bool solve_ok = solve( X, trimatu(nodes(i)*A + eye< Mat >(N,N)), A, solve_opts::no_approx ); if(solve_ok == false) { arma_extra_debug_print("logmat(): solve() failed"); return false; } @@ -387,6 +505,36 @@ op_logmat_sympd::apply_direct(Mat& out, const Base 0) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "logmat_sympd(): imaginary components on diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_logmat_sympd: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; Mat eigvec; @@ -399,7 +547,7 @@ op_logmat_sympd::apply_direct(Mat& out, const Base& out, const Op& in); template - inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // @@ -43,10 +45,10 @@ class op_max inline static void apply(Cube& out, const OpCube& in); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // diff --git a/src/armadillo_bits/op_max_meat.hpp b/src/armadillo_bits/op_max_meat.hpp index c1b49971..34de86b1 100644 --- a/src/armadillo_bits/op_max_meat.hpp +++ b/src/armadillo_bits/op_max_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ op_max::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "max(): parameter 'dim' must be 0 or 1"); + arma_debug_check( (dim > 1), "max(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); const Mat& X = U.M; @@ -359,7 +361,8 @@ op_max::direct_max(const eT* const X, const uword n_elem) { arma_extra_debug_sigprint(); - eT max_val = priv::most_neg(); + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); uword i,j; for(i=0, j=1; j max_val) { max_val = X_i; } - if(X_j > max_val) { max_val = X_j; } + if(X_i > max_val_i) { max_val_i = X_i; } + if(X_j > max_val_j) { max_val_j = X_j; } } if(i < n_elem) { const eT X_i = X[i]; - if(X_i > max_val) { max_val = X_i; } + if(X_i > max_val_i) { max_val_i = X_i; } } - return max_val; + return (max_val_i > max_val_j) ? max_val_i : max_val_j; } @@ -390,9 +393,11 @@ op_max::direct_max(const eT* const X, const uword n_elem, uword& index_of_max_va { arma_extra_debug_sigprint(); - eT max_val = priv::most_neg(); + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); - uword best_index = 0; + uword best_index_i = 0; + uword best_index_j = 0; uword i,j; for(i=0, j=1; j max_val) - { - max_val = X_i; - best_index = i; - } - - if(X_j > max_val) - { - max_val = X_j; - best_index = j; - } + if(X_i > max_val_i) { max_val_i = X_i; best_index_i = i; } + if(X_j > max_val_j) { max_val_j = X_j; best_index_j = j; } } if(i < n_elem) { const eT X_i = X[i]; - if(X_i > max_val) - { - max_val = X_i; - best_index = i; - } + if(X_i > max_val_i) { max_val_i = X_i; best_index_i = i; } } - index_of_max_val = best_index; + index_of_max_val = (max_val_i > max_val_j) ? best_index_i : best_index_j; - return max_val; + return (max_val_i > max_val_j) ? max_val_i : max_val_j; } @@ -440,7 +432,8 @@ op_max::direct_max(const Mat& X, const uword row) const uword X_n_cols = X.n_cols; - eT max_val = priv::most_neg(); + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); uword i,j; for(i=0, j=1; j < X_n_cols; i+=2, j+=2) @@ -448,18 +441,18 @@ op_max::direct_max(const Mat& X, const uword row) const eT tmp_i = X.at(row,i); const eT tmp_j = X.at(row,j); - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < X_n_cols) { const eT tmp_i = X.at(row,i); - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } - return max_val; + return (max_val_i > max_val_j) ? max_val_i : max_val_j; } @@ -481,10 +474,11 @@ op_max::max(const subview& X) const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; - eT max_val = priv::most_neg(); - if(X_n_rows == 1) { + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + const Mat& A = X.m; const uword start_row = X.aux_row1; @@ -498,23 +492,25 @@ op_max::max(const subview& X) const eT tmp_i = A.at(start_row, i); const eT tmp_j = A.at(start_row, j); - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < end_col_p1) { const eT tmp_i = A.at(start_row, i); - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } + + return (max_val_i > max_val_j) ? max_val_i : max_val_j; } - else + + eT max_val = priv::most_neg(); + + for(uword col=0; col < X_n_cols; ++col) { - for(uword col=0; col < X_n_cols; ++col) - { - max_val = (std::max)(max_val, op_max::direct_max(X.colptr(col), X_n_rows)); - } + max_val = (std::max)(max_val, op_max::direct_max(X.colptr(col), X_n_rows)); } return max_val; @@ -542,7 +538,8 @@ op_max::max(const Base& X) return Datum::nan; } - eT max_val = priv::most_neg(); + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); if(Proxy::use_at == false) { @@ -557,15 +554,15 @@ op_max::max(const Base& X) const eT tmp_i = A[i]; const eT tmp_j = A[j]; - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < n_elem) { const eT tmp_i = A[i]; - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } } else @@ -581,15 +578,15 @@ op_max::max(const Base& X) const eT tmp_i = P.at(0,i); const eT tmp_j = P.at(0,j); - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < n_cols) { const eT tmp_i = P.at(0,i); - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } } else @@ -602,21 +599,21 @@ op_max::max(const Base& X) const eT tmp_i = P.at(i,col); const eT tmp_j = P.at(j,col); - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < n_rows) { const eT tmp_i = P.at(i,col); - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } } } } - return max_val; + return (max_val_i > max_val_j) ? max_val_i : max_val_j; } @@ -645,6 +642,9 @@ op_max::max(const BaseCube& X) if(ProxyCube::use_at == false) { + eT max_val_i = priv::most_neg(); + eT max_val_j = priv::most_neg(); + typedef typename ProxyCube::ea_type ea_type; ea_type A = P.get_ea(); @@ -656,16 +656,18 @@ op_max::max(const BaseCube& X) const eT tmp_i = A[i]; const eT tmp_j = A[j]; - if(tmp_i > max_val) { max_val = tmp_i; } - if(tmp_j > max_val) { max_val = tmp_j; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } + if(tmp_j > max_val_j) { max_val_j = tmp_j; } } if(i < n_elem) { const eT tmp_i = A[i]; - if(tmp_i > max_val) { max_val = tmp_i; } + if(tmp_i > max_val_i) { max_val_i = tmp_i; } } + + max_val = (max_val_i > max_val_j) ? max_val_i : max_val_j; } else { diff --git a/src/armadillo_bits/op_mean_bones.hpp b/src/armadillo_bits/op_mean_bones.hpp index c0194b51..20a86ae2 100644 --- a/src/armadillo_bits/op_mean_bones.hpp +++ b/src/armadillo_bits/op_mean_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_mean_meat.hpp b/src/armadillo_bits/op_mean_meat.hpp index eb7df05d..7e7a49d3 100644 --- a/src/armadillo_bits/op_mean_meat.hpp +++ b/src/armadillo_bits/op_mean_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -134,7 +136,6 @@ op_mean::apply_noalias_unwrap(Mat& out, const Proxy& template -arma_hot inline void op_mean::apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim) @@ -193,7 +194,7 @@ op_mean::apply_noalias_proxy(Mat& out, const Proxy& out /= T(P_n_cols); } - if(out.is_finite() == false) + if(out.internal_has_nonfinite()) { // TODO: replace with dedicated handling to avoid unwrapping op_mean::apply_noalias_unwrap(out, P, dim); @@ -363,7 +364,6 @@ op_mean::apply_noalias_unwrap(Cube& out, const ProxyCube template -arma_hot inline void op_mean::apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim) diff --git a/src/armadillo_bits/op_median_bones.hpp b/src/armadillo_bits/op_median_bones.hpp index 151874b8..8212d148 100644 --- a/src/armadillo_bits/op_median_bones.hpp +++ b/src/armadillo_bits/op_median_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -32,31 +34,33 @@ arma_inline bool operator< (const arma_cx_median_packet& A, const arma_cx_median_packet& B) { - return A.val < B.val; + return (A.val < B.val); } -//! Class for finding median values of a matrix class op_median : public traits_op_xvec { public: template - inline static void apply(Mat& out, const Op& in); + inline static void apply(Mat& out, const Op& expr); + + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); - template - inline static void apply(Mat< std::complex >& out, const Op& in); + template + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // // template - inline static typename T1::elem_type median_vec(const T1& X, const typename arma_not_cx::result* junk = 0); + inline static typename T1::elem_type median_vec(const T1& X, const typename arma_not_cx::result* junk = nullptr); template - inline static typename T1::elem_type median_vec(const T1& X, const typename arma_cx_only::result* junk = 0); + inline static typename T1::elem_type median_vec(const T1& X, const typename arma_cx_only::result* junk = nullptr); // // diff --git a/src/armadillo_bits/op_median_meat.hpp b/src/armadillo_bits/op_median_meat.hpp index c51a2a3d..ae805153 100644 --- a/src/armadillo_bits/op_median_meat.hpp +++ b/src/armadillo_bits/op_median_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,113 +21,83 @@ -//! \brief -//! For each row or for each column, find the median value. -//! The result is stored in a dense matrix that has either one column or one row. -//! The dimension, for which the medians are found, is set via the median() function. template inline void -op_median::apply(Mat& out, const Op& in) +op_median::apply(Mat& out, const Op& expr) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "median(): parameter 'dim' must be 0 or 1" ); - - const Proxy P(in.m); + const quasi_unwrap U(expr.m); - typedef typename Proxy::stored_type P_stored_type; + const uword dim = expr.aux_uword_a; - const bool is_alias = P.is_alias(out); + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + arma_debug_check( (dim > 1), "median(): parameter 'dim' must be 0 or 1" ); - if( (is_Mat::value == true) || is_alias ) + if(U.is_alias(out)) { - const unwrap_check tmp(P.Q, is_alias); + Mat tmp; - const typename unwrap_check::stored_type& X = tmp.M; + op_median::apply_noalias(out, U.M, dim); + + out.steal_mem(tmp); + } + else + { + op_median::apply_noalias(out, U.M, dim); + } + } + + + +template +inline +void +op_median::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) // in each column + { + arma_extra_debug_print("op_median::apply(): dim = 0"); - const uword X_n_rows = X.n_rows; - const uword X_n_cols = X.n_cols; + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); - if(dim == 0) // in each column + if(X_n_rows > 0) { - arma_extra_debug_print("op_median::apply(): dim = 0"); - - out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + std::vector tmp_vec(X_n_rows); - if(X_n_rows > 0) + for(uword col=0; col < X_n_cols; ++col) { - std::vector tmp_vec(X_n_rows); + arrayops::copy( &(tmp_vec[0]), X.colptr(col), X_n_rows ); - for(uword col=0; col < X_n_cols; ++col) - { - arrayops::copy( &(tmp_vec[0]), X.colptr(col), X_n_rows ); - - out[col] = op_median::direct_median(tmp_vec); - } - } - } - else // in each row - { - arma_extra_debug_print("op_median::apply(): dim = 1"); - - out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); - - if(X_n_cols > 0) - { - std::vector tmp_vec(X_n_cols); - - for(uword row=0; row < X_n_rows; ++row) - { - for(uword col=0; col < X_n_cols; ++col) { tmp_vec[col] = X.at(row,col); } - - out[row] = op_median::direct_median(tmp_vec); - } + out[col] = op_median::direct_median(tmp_vec); } } } else + if(dim == 1) // in each row { - const uword P_n_rows = P.get_n_rows(); - const uword P_n_cols = P.get_n_cols(); + arma_extra_debug_print("op_median::apply(): dim = 1"); - if(dim == 0) // in each column + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) { - arma_extra_debug_print("op_median::apply(): dim = 0"); - - out.set_size((P_n_rows > 0) ? 1 : 0, P_n_cols); - - if(P_n_rows > 0) - { - std::vector tmp_vec(P_n_rows); + std::vector tmp_vec(X_n_cols); - for(uword col=0; col < P_n_cols; ++col) - { - for(uword row=0; row < P_n_rows; ++row) { tmp_vec[row] = P.at(row,col); } - - out[col] = op_median::direct_median(tmp_vec); - } - } - } - else // in each row - { - arma_extra_debug_print("op_median::apply(): dim = 1"); - - out.set_size(P_n_rows, (P_n_cols > 0) ? 1 : 0); - - if(P_n_cols > 0) + for(uword row=0; row < X_n_rows; ++row) { - std::vector tmp_vec(P_n_cols); - - for(uword row=0; row < P_n_rows; ++row) - { - for(uword col=0; col < P_n_cols; ++col) { tmp_vec[col] = P.at(row,col); } - - out[row] = op_median::direct_median(tmp_vec); - } + for(uword col=0; col < X_n_cols; ++col) { tmp_vec[col] = X.at(row,col); } + + out[row] = op_median::direct_median(tmp_vec); } } } @@ -133,27 +105,19 @@ op_median::apply(Mat& out, const Op& in) -//! Implementation for complex numbers -template +template inline void -op_median::apply(Mat< std::complex >& out, const Op& in) +op_median::apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk) { arma_extra_debug_sigprint(); + arma_ignore(junk); - typedef typename std::complex eT; - - arma_type_check(( is_same_type::no )); - - const unwrap_check tmp(in.m, out); - const Mat& X = tmp.M; + typedef typename get_pod_type::result T; const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; - const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "median(): parameter 'dim' must be 0 or 1" ); - if(dim == 0) // in each column { arma_extra_debug_print("op_median::apply(): dim = 0"); @@ -174,8 +138,8 @@ op_median::apply(Mat< std::complex >& out, const Op& in) tmp_vec[row].index = row; } - uword index1; - uword index2; + uword index1 = 0; + uword index2 = 0; op_median::direct_cx_median_index(index1, index2, tmp_vec); out[col] = op_mean::robust_mean(colmem[index1], colmem[index2]); @@ -201,8 +165,8 @@ op_median::apply(Mat< std::complex >& out, const Op& in) tmp_vec[col].index = col; } - uword index1; - uword index2; + uword index1 = 0; + uword index2 = 0; op_median::direct_cx_median_index(index1, index2, tmp_vec); out[row] = op_mean::robust_mean( X.at(row,index1), X.at(row,index2) ); @@ -227,11 +191,9 @@ op_median::median_vec typedef typename T1::elem_type eT; - typedef typename Proxy::stored_type P_stored_type; - - const Proxy P(X); + const quasi_unwrap U(X); - const uword n_elem = P.get_n_elem(); + const uword n_elem = U.M.n_elem; if(n_elem == 0) { @@ -240,46 +202,11 @@ op_median::median_vec return Datum::nan; } + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + std::vector tmp_vec(n_elem); - if(is_Mat::value == true) - { - const unwrap tmp(P.Q); - - const typename unwrap::stored_type& Y = tmp.M; - - arrayops::copy( &(tmp_vec[0]), Y.memptr(), n_elem ); - } - else - { - if(Proxy::use_at == false) - { - typedef typename Proxy::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i P(X); + const quasi_unwrap U(X); - const uword n_elem = P.get_n_elem(); + const uword n_elem = U.M.n_elem; if(n_elem == 0) { @@ -312,72 +239,27 @@ op_median::median_vec return Datum::nan; } + arma_debug_check( U.M.internal_has_nan(), "median(): detected NaN" ); + std::vector< arma_cx_median_packet > tmp_vec(n_elem); - if(Proxy::use_at == false) + const eT* A = U.M.memptr(); + + for(uword i=0; i::ea_type ea_type; - - ea_type A = P.get_ea(); - - for(uword i=0; i inline eT diff --git a/src/armadillo_bits/op_min_bones.hpp b/src/armadillo_bits/op_min_bones.hpp index c02674c5..e9f5a621 100644 --- a/src/armadillo_bits/op_min_bones.hpp +++ b/src/armadillo_bits/op_min_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -30,10 +32,10 @@ class op_min inline static void apply(Mat& out, const Op& in); template - inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Mat& out, const Mat& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // @@ -43,10 +45,10 @@ class op_min inline static void apply(Cube& out, const OpCube& in); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_noalias(Cube& out, const Cube& X, const uword dim, const typename arma_cx_only::result* junk = nullptr); // diff --git a/src/armadillo_bits/op_min_meat.hpp b/src/armadillo_bits/op_min_meat.hpp index 844425c9..9879185d 100644 --- a/src/armadillo_bits/op_min_meat.hpp +++ b/src/armadillo_bits/op_min_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ op_min::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "min(): parameter 'dim' must be 0 or 1"); + arma_debug_check( (dim > 1), "min(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); const Mat& X = U.M; @@ -359,7 +361,8 @@ op_min::direct_min(const eT* const X, const uword n_elem) { arma_extra_debug_sigprint(); - eT min_val = priv::most_pos(); + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); uword i,j; for(i=0, j=1; j -inline +inline eT op_min::direct_min(const eT* const X, const uword n_elem, uword& index_of_min_val) { arma_extra_debug_sigprint(); - eT min_val = priv::most_pos(); + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); - uword best_index = 0; + uword best_index_i = 0; + uword best_index_j = 0; uword i,j; for(i=0, j=1; j& X, const uword row) const uword X_n_cols = X.n_cols; - eT min_val = priv::most_pos(); + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); uword i,j; for(i=0, j=1; j < X_n_cols; i+=2, j+=2) @@ -448,18 +441,18 @@ op_min::direct_min(const Mat& X, const uword row) const eT tmp_i = X.at(row,i); const eT tmp_j = X.at(row,j); - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < X_n_cols) { const eT tmp_i = X.at(row,i); - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } - return min_val; + return (min_val_i < min_val_j) ? min_val_i : min_val_j; } @@ -477,14 +470,15 @@ op_min::min(const subview& X) return Datum::nan; } - + const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; - eT min_val = priv::most_pos(); - if(X_n_rows == 1) { + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + const Mat& A = X.m; const uword start_row = X.aux_row1; @@ -498,23 +492,25 @@ op_min::min(const subview& X) const eT tmp_i = A.at(start_row, i); const eT tmp_j = A.at(start_row, j); - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < end_col_p1) { const eT tmp_i = A.at(start_row, i); - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } + + return (min_val_i < min_val_j) ? min_val_i : min_val_j; } - else + + eT min_val = priv::most_pos(); + + for(uword col=0; col < X_n_cols; ++col) { - for(uword col=0; col < X_n_cols; ++col) - { - min_val = (std::min)(min_val, op_min::direct_min(X.colptr(col), X_n_rows)); - } + min_val = (std::min)(min_val, op_min::direct_min(X.colptr(col), X_n_rows)); } return min_val; @@ -542,7 +538,8 @@ op_min::min(const Base& X) return Datum::nan; } - eT min_val = priv::most_pos(); + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); if(Proxy::use_at == false) { @@ -557,15 +554,15 @@ op_min::min(const Base& X) const eT tmp_i = A[i]; const eT tmp_j = A[j]; - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < n_elem) { const eT tmp_i = A[i]; - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } } else @@ -581,15 +578,15 @@ op_min::min(const Base& X) const eT tmp_i = P.at(0,i); const eT tmp_j = P.at(0,j); - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < n_cols) { const eT tmp_i = P.at(0,i); - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } } else @@ -602,21 +599,21 @@ op_min::min(const Base& X) const eT tmp_i = P.at(i,col); const eT tmp_j = P.at(j,col); - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < n_rows) { const eT tmp_i = P.at(i,col); - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } } } } - return min_val; + return (min_val_i < min_val_j) ? min_val_i : min_val_j; } @@ -645,6 +642,9 @@ op_min::min(const BaseCube& X) if(ProxyCube::use_at == false) { + eT min_val_i = priv::most_pos(); + eT min_val_j = priv::most_pos(); + typedef typename ProxyCube::ea_type ea_type; ea_type A = P.get_ea(); @@ -656,16 +656,18 @@ op_min::min(const BaseCube& X) const eT tmp_i = A[i]; const eT tmp_j = A[j]; - if(tmp_i < min_val) { min_val = tmp_i; } - if(tmp_j < min_val) { min_val = tmp_j; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } + if(tmp_j < min_val_j) { min_val_j = tmp_j; } } if(i < n_elem) { const eT tmp_i = A[i]; - if(tmp_i < min_val) { min_val = tmp_i; } + if(tmp_i < min_val_i) { min_val_i = tmp_i; } } + + min_val = (min_val_i < min_val_j) ? min_val_i : min_val_j; } else { diff --git a/src/armadillo_bits/op_misc_bones.hpp b/src/armadillo_bits/op_misc_bones.hpp index fff03be8..5fd15714 100644 --- a/src/armadillo_bits/op_misc_bones.hpp +++ b/src/armadillo_bits/op_misc_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_misc_meat.hpp b/src/armadillo_bits/op_misc_meat.hpp index b21cb168..d1c2f364 100644 --- a/src/armadillo_bits/op_misc_meat.hpp +++ b/src/armadillo_bits/op_misc_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_nonzeros_bones.hpp b/src/armadillo_bits/op_nonzeros_bones.hpp index d7260e9a..8c3fd65e 100644 --- a/src/armadillo_bits/op_nonzeros_bones.hpp +++ b/src/armadillo_bits/op_nonzeros_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_nonzeros_meat.hpp b/src/armadillo_bits/op_nonzeros_meat.hpp index 06c9ed6d..8cf32fa1 100644 --- a/src/armadillo_bits/op_nonzeros_meat.hpp +++ b/src/armadillo_bits/op_nonzeros_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -31,7 +33,7 @@ op_nonzeros::apply_noalias(Mat& out, const Proxy& P) const uword N_max = P.get_n_elem(); - Mat tmp(N_max, 1); + Mat tmp(N_max, 1, arma_nozeros_indicator()); eT* tmp_mem = tmp.memptr(); @@ -111,23 +113,37 @@ op_nonzeros_spmat::apply(Mat& out, const SpToDOp 0) + if(N == 0) { return; } + + if(is_SpMat::stored_type>::value) { - if(is_SpMat::stored_type>::value) - { - const unwrap_spmat::stored_type> U(P.Q); - - arrayops::copy(out.memptr(), U.M.values, N); - } - else + const unwrap_spmat::stored_type> U(P.Q); + + arrayops::copy(out.memptr(), U.M.values, N); + + return; + } + + if(is_SpSubview::stored_type>::value) + { + const SpSubview& sv = reinterpret_cast< const SpSubview& >(P.Q); + + if(sv.n_rows == sv.m.n_rows) { - eT* out_mem = out.memptr(); + const SpMat& m = sv.m; + const uword col = sv.aux_col1; - typename SpProxy::const_iterator_type it = P.begin(); + arrayops::copy(out.memptr(), &(m.values[ m.col_ptrs[col] ]), N); - for(uword i=0; i::const_iterator_type it = P.begin(); + + for(uword i=0; i +struct norm2est_randu_filler + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline norm2est_randu_filler(); + + inline void fill(eT* mem, const uword N); + }; + + +template +struct norm2est_randu_filler< std::complex > + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline norm2est_randu_filler(); + + inline void fill(std::complex* mem, const uword N); + }; + + + +class op_norm2est + : public traits_op_default + { + public: + + template inline static typename T1::pod_type norm2est(const Base& X, const typename T1::pod_type tolerance, const uword max_iter); + template inline static typename T1::pod_type norm2est(const SpBase& X, const typename T1::pod_type tolerance, const uword max_iter); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_norm2est_meat.hpp b/src/armadillo_bits/op_norm2est_meat.hpp new file mode 100644 index 00000000..b809d134 --- /dev/null +++ b/src/armadillo_bits/op_norm2est_meat.hpp @@ -0,0 +1,248 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm2est +//! @{ + + + +template +inline +norm2est_randu_filler::norm2est_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +norm2est_randu_filler::fill(eT* mem, const uword N) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i +inline +norm2est_randu_filler< std::complex >::norm2est_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +norm2est_randu_filler< std::complex >::fill(std::complex* mem, const uword N) + { + arma_extra_debug_sigprint(); + + for(uword i=0; i& mem_i = mem[i]; + + mem_i.real( T(local_u_distr(local_engine)) ); + mem_i.imag( T(local_u_distr(local_engine)) ); + } + } + + + +// +// +// + + + +template +inline +typename T1::pod_type +op_norm2est::norm2est + ( + const Base& X, + const typename T1::pod_type tolerance, + const uword max_iter + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + arma_debug_check( (tolerance < T(0)), "norm2est(): parameter 'tolerance' must be > 0" ); + arma_debug_check( (max_iter == uword(0)), "norm2est(): parameter 'max_iter' must be > 0" ); + + const T tol = (tolerance == T(0)) ? T(1e-6) : T(tolerance); + + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; + + if(A.n_elem == 0) { return T(0); } + + if(A.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm2est(): given matrix has non-finite elements"); } + + if((A.n_rows == 1) || (A.n_cols == 1)) { return op_norm::vec_norm_2( Proxy< Mat >(A) ); } + + norm2est_randu_filler randu_filler; + + Col x(A.n_rows, fill::none); + Col y(A.n_cols, fill::none); + + randu_filler.fill(y.memptr(), y.n_elem); + + T est_old = 0; + T est_cur = 0; + + for(uword i=0; i >(x) ); + + if(x_norm == T(0) || (arma_isfinite(x_norm) == false) || (x.internal_has_nonfinite())) + { + randu_filler.fill(x.memptr(), x.n_elem); + + x_norm = op_norm::vec_norm_2( Proxy< Col >(x) ); + } + + if(x_norm != T(0)) { x /= x_norm; } + + y = A.t() * x; + + est_old = est_cur; + est_cur = op_norm::vec_norm_2( Proxy< Col >(y) ); + + arma_extra_debug_print(arma_str::format("norm2est(): est_old: %e") % est_old); + arma_extra_debug_print(arma_str::format("norm2est(): est_cur: %e") % est_cur); + + if(arma_isfinite(est_cur) == false) { return est_old; } + + if( ((std::abs)(est_cur - est_old)) <= (tol * (std::max)(est_cur,est_old)) ) { break; } + } + + return est_cur; + } + + + +// +// +// + + + +template +inline +typename T1::pod_type +op_norm2est::norm2est + ( + const SpBase& X, + const typename T1::pod_type tolerance, + const uword max_iter + ) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type T; + typedef typename T1::elem_type eT; + + arma_debug_check( (tolerance < T(0)), "norm2est(): parameter 'tolerance' must be > 0" ); + arma_debug_check( (max_iter == uword(0)), "norm2est(): parameter 'max_iter' must be > 0" ); + + const T tol = (tolerance == T(0)) ? T(1e-6) : T(tolerance); + + const unwrap_spmat U(X.get_ref()); + const SpMat& A = U.M; + + if(A.n_nonzero == 0) { return T(0); } + + if(A.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm2est(): given matrix has non-finite elements"); } + + if((A.n_rows == 1) || (A.n_cols == 1)) { return spop_norm::vec_norm_k(A.values, A.n_nonzero, 2); } + + norm2est_randu_filler randu_filler; + + Mat x(A.n_rows, 1, fill::none); + Mat y(A.n_cols, 1, fill::none); + + randu_filler.fill(y.memptr(), y.n_elem); + + T est_old = 0; + T est_cur = 0; + + for(uword i=0; i >(x) ); + + if(x_norm == T(0) || (arma_isfinite(x_norm) == false) || (x.internal_has_nonfinite())) + { + randu_filler.fill(x.memptr(), x.n_elem); + + x_norm = op_norm::vec_norm_2( Proxy< Mat >(x) ); + } + + if(x_norm != T(0)) { x /= x_norm; } + + y = A.t() * x; + + est_old = est_cur; + est_cur = op_norm::vec_norm_2( Proxy< Mat >(y) ); + + arma_extra_debug_print(arma_str::format("norm2est(): est_old: %e") % est_old); + arma_extra_debug_print(arma_str::format("norm2est(): est_cur: %e") % est_cur); + + if(arma_isfinite(est_cur) == false) { return est_old; } + + if( ((std::abs)(est_cur - est_old)) <= (tol * (std::max)(est_cur,est_old)) ) { break; } + } + + return est_cur; + } + + + +//! @} diff --git a/src/armadillo_bits/op_norm_bones.hpp b/src/armadillo_bits/op_norm_bones.hpp index 5ec0b48a..f4023383 100644 --- a/src/armadillo_bits/op_norm_bones.hpp +++ b/src/armadillo_bits/op_norm_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,15 +25,13 @@ class op_norm { public: - // norms for dense vectors and matrices - - template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_not_cx::result* junk = 0); - template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_cx_only::result* junk = 0); + template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_not_cx::result* junk = nullptr); + template arma_hot inline static typename T1::pod_type vec_norm_1(const Proxy& P, const typename arma_cx_only::result* junk = nullptr); template arma_hot inline static eT vec_norm_1_direct_std(const Mat& X); template arma_hot inline static eT vec_norm_1_direct_mem(const uword N, const eT* A); - template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_not_cx::result* junk = 0); - template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_cx_only::result* junk = 0); + template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_not_cx::result* junk = nullptr); + template arma_hot inline static typename T1::pod_type vec_norm_2(const Proxy& P, const typename arma_cx_only::result* junk = nullptr); template arma_hot inline static eT vec_norm_2_direct_std(const Mat& X); template arma_hot inline static eT vec_norm_2_direct_mem(const uword N, const eT* A); template arma_hot inline static eT vec_norm_2_direct_robust(const Mat& X); @@ -41,20 +41,10 @@ class op_norm template arma_hot inline static typename T1::pod_type vec_norm_max(const Proxy& P); template arma_hot inline static typename T1::pod_type vec_norm_min(const Proxy& P); - template inline static typename T1::pod_type mat_norm_1(const Proxy& P); - template inline static typename T1::pod_type mat_norm_2(const Proxy& P); - - template inline static typename T1::pod_type mat_norm_inf(const Proxy& P); - + template inline static typename get_pod_type::result mat_norm_1(const Mat& X); + template inline static typename get_pod_type::result mat_norm_2(const Mat& X); - // norms for sparse matrices - - template inline static typename T1::pod_type mat_norm_1(const SpProxy& P); - - template inline static typename T1::pod_type mat_norm_2(const SpProxy& P, const typename arma_real_only::result* junk = 0); - template inline static typename T1::pod_type mat_norm_2(const SpProxy& P, const typename arma_cx_only::result* junk = 0); - - template inline static typename T1::pod_type mat_norm_inf(const SpProxy& P); + template inline static typename get_pod_type::result mat_norm_inf(const Mat& X); }; diff --git a/src/armadillo_bits/op_norm_meat.hpp b/src/armadillo_bits/op_norm_meat.hpp index 9a4499bb..2f50f14c 100644 --- a/src/armadillo_bits/op_norm_meat.hpp +++ b/src/armadillo_bits/op_norm_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline typename T1::pod_type op_norm::vec_norm_1(const Proxy& P, const typename arma_not_cx::result* junk) @@ -28,9 +29,9 @@ op_norm::vec_norm_1(const Proxy& P, const typename arma_not_cx::stored_type>::value) || (is_subview_col::stored_type>::value); + const bool use_direct_mem = (is_Mat::stored_type>::value) || (is_subview_col::stored_type>::value) || (arma_config::openmp && Proxy::use_mp); - if(have_direct_mem) + if(use_direct_mem) { const quasi_unwrap::stored_type> tmp(P.Q); @@ -107,7 +108,6 @@ op_norm::vec_norm_1(const Proxy& P, const typename arma_not_cx -arma_hot inline typename T1::pod_type op_norm::vec_norm_1(const Proxy& P, const typename arma_cx_only::result* junk) @@ -215,7 +215,6 @@ op_norm::vec_norm_1(const Proxy& P, const typename arma_cx_only -arma_hot inline eT op_norm::vec_norm_1_direct_std(const Mat& X) @@ -250,14 +249,13 @@ op_norm::vec_norm_1_direct_std(const Mat& X) template -arma_hot inline eT op_norm::vec_norm_1_direct_mem(const uword N, const eT* A) { arma_extra_debug_sigprint(); - #if defined(ARMA_SIMPLE_LOOPS) || (defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)) + #if (defined(ARMA_SIMPLE_LOOPS) || defined(__FAST_MATH__)) { eT acc1 = eT(0); @@ -303,7 +301,6 @@ op_norm::vec_norm_1_direct_mem(const uword N, const eT* A) template -arma_hot inline typename T1::pod_type op_norm::vec_norm_2(const Proxy& P, const typename arma_not_cx::result* junk) @@ -311,9 +308,9 @@ op_norm::vec_norm_2(const Proxy& P, const typename arma_not_cx::stored_type>::value) || (is_subview_col::stored_type>::value); + const bool use_direct_mem = (is_Mat::stored_type>::value) || (is_subview_col::stored_type>::value) || (arma_config::openmp && Proxy::use_mp); - if(have_direct_mem) + if(use_direct_mem) { const quasi_unwrap::stored_type> tmp(P.Q); @@ -411,7 +408,6 @@ op_norm::vec_norm_2(const Proxy& P, const typename arma_not_cx -arma_hot inline typename T1::pod_type op_norm::vec_norm_2(const Proxy& P, const typename arma_cx_only::result* junk) @@ -514,7 +510,6 @@ op_norm::vec_norm_2(const Proxy& P, const typename arma_cx_only -arma_hot inline eT op_norm::vec_norm_2_direct_std(const Mat& X) @@ -562,7 +557,6 @@ op_norm::vec_norm_2_direct_std(const Mat& X) template -arma_hot inline eT op_norm::vec_norm_2_direct_mem(const uword N, const eT* A) @@ -571,7 +565,7 @@ op_norm::vec_norm_2_direct_mem(const uword N, const eT* A) eT acc; - #if defined(ARMA_SIMPLE_LOOPS) || (defined(__FINITE_MATH_ONLY__) && (__FINITE_MATH_ONLY__ > 0)) + #if (defined(ARMA_SIMPLE_LOOPS) || defined(__FAST_MATH__)) { eT acc1 = eT(0); @@ -621,7 +615,6 @@ op_norm::vec_norm_2_direct_mem(const uword N, const eT* A) template -arma_hot inline eT op_norm::vec_norm_2_direct_robust(const Mat& X) @@ -686,7 +679,6 @@ op_norm::vec_norm_2_direct_robust(const Mat& X) template -arma_hot inline typename T1::pod_type op_norm::vec_norm_k(const Proxy& P, const int k) @@ -703,15 +695,7 @@ op_norm::vec_norm_k(const Proxy& P, const int k) const uword N = P.get_n_elem(); - uword i,j; - - for(i=0, j=1; j& P, const int k) template -arma_hot inline typename T1::pod_type op_norm::vec_norm_max(const Proxy& P) @@ -810,7 +793,6 @@ op_norm::vec_norm_max(const Proxy& P) template -arma_hot inline typename T1::pod_type op_norm::vec_norm_min(const Proxy& P) @@ -875,139 +857,47 @@ op_norm::vec_norm_min(const Proxy& P) -template -inline -typename T1::pod_type -op_norm::mat_norm_1(const Proxy& P) - { - arma_extra_debug_sigprint(); - - // TODO: this can be sped up with a dedicated implementation - return as_scalar( max( sum(abs(P.Q), 0), 1) ); - } - - - -template -inline -typename T1::pod_type -op_norm::mat_norm_2(const Proxy& P) - { - arma_extra_debug_sigprint(); - - typedef typename T1::pod_type T; - - Col S; - svd(S, P.Q); - - return (S.n_elem > 0) ? max(S) : T(0); - } - - - -template -inline -typename T1::pod_type -op_norm::mat_norm_inf(const Proxy& P) - { - arma_extra_debug_sigprint(); - - // TODO: this can be sped up with a dedicated implementation - return as_scalar( max( sum(abs(P.Q), 1), 0) ); - } - - - -// -// norms for sparse matrices - - - -template +template inline -typename T1::pod_type -op_norm::mat_norm_1(const SpProxy& P) +typename get_pod_type::result +op_norm::mat_norm_1(const Mat& X) { arma_extra_debug_sigprint(); // TODO: this can be sped up with a dedicated implementation - return as_scalar( max( sum(abs(P.Q), 0), 1) ); + return as_scalar( max( sum(abs(X), 0), 1) ); } -template -inline -typename T1::pod_type -op_norm::mat_norm_2(const SpProxy& P, const typename arma_real_only::result* junk) - { - arma_extra_debug_sigprint(); - arma_ignore(junk); - - // norm = sqrt( largest eigenvalue of (A^H)*A ), where ^H is the conjugate transpose - // http://math.stackexchange.com/questions/4368/computing-the-largest-eigenvalue-of-a-very-large-sparse-matrix - - typedef typename T1::elem_type eT; - typedef typename T1::pod_type T; - - const unwrap_spmat::stored_type> tmp(P.Q); - - const SpMat& A = tmp.M; - const SpMat B = trans(A); - - const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); - - Col eigval; - eigs_sym(eigval, C, 1); - - return (eigval.n_elem > 0) ? std::sqrt(eigval[0]) : T(0); - } - - - -template +template inline -typename T1::pod_type -op_norm::mat_norm_2(const SpProxy& P, const typename arma_cx_only::result* junk) +typename get_pod_type::result +op_norm::mat_norm_2(const Mat& X) { arma_extra_debug_sigprint(); - arma_ignore(junk); - typedef typename T1::elem_type eT; - typedef typename T1::pod_type T; + typedef typename get_pod_type::result T; - // we're calling eigs_gen(), which currently requires ARPACK - #if !defined(ARMA_USE_ARPACK) - { - arma_stop_logic_error("norm(): use of ARPACK must be enabled for norm of complex matrices"); - return T(0); - } - #endif - - const unwrap_spmat::stored_type> tmp(P.Q); - - const SpMat& A = tmp.M; - const SpMat B = trans(A); + if(X.internal_has_nonfinite()) { arma_debug_warn_level(1, "norm(): given matrix has non-finite elements"); } - const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); - - Col eigval; - eigs_gen(eigval, C, 1); + Col S; + svd(S, X); - return (eigval.n_elem > 0) ? std::sqrt(std::real(eigval[0])) : T(0); + return (S.n_elem > 0) ? S[0] : T(0); } -template +template inline -typename T1::pod_type -op_norm::mat_norm_inf(const SpProxy& P) +typename get_pod_type::result +op_norm::mat_norm_inf(const Mat& X) { arma_extra_debug_sigprint(); // TODO: this can be sped up with a dedicated implementation - return as_scalar( max( sum(abs(P.Q), 1), 0) ); + return as_scalar( max( sum(abs(X), 1), 0) ); } diff --git a/src/armadillo_bits/op_normalise_bones.hpp b/src/armadillo_bits/op_normalise_bones.hpp index 11d45bdc..4b1932c0 100644 --- a/src/armadillo_bits/op_normalise_bones.hpp +++ b/src/armadillo_bits/op_normalise_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_normalise_meat.hpp b/src/armadillo_bits/op_normalise_meat.hpp index 96088f37..e56c390d 100644 --- a/src/armadillo_bits/op_normalise_meat.hpp +++ b/src/armadillo_bits/op_normalise_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -32,7 +34,7 @@ op_normalise_vec::apply(Mat& out, const Op U(in.m); @@ -65,8 +67,8 @@ op_normalise_mat::apply(Mat& out, const Op 1), "normalise(): parameter 'dim' must be 0 or 1" ); + arma_debug_check( (p == 0), "normalise(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "normalise(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); diff --git a/src/armadillo_bits/op_orth_null_bones.hpp b/src/armadillo_bits/op_orth_null_bones.hpp index 8748f4a3..2cb182ba 100644 --- a/src/armadillo_bits/op_orth_null_bones.hpp +++ b/src/armadillo_bits/op_orth_null_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_orth_null_meat.hpp b/src/armadillo_bits/op_orth_null_meat.hpp index 3ee96ba3..4a776ba4 100644 --- a/src/armadillo_bits/op_orth_null_meat.hpp +++ b/src/armadillo_bits/op_orth_null_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,6 +37,7 @@ op_orth::apply(Mat& out, const Op& expr) if(status == false) { + out.soft_reset(); arma_stop_runtime_error("orth(): svd failed"); } } @@ -53,18 +56,17 @@ op_orth::apply_direct(Mat& out, const Base= 0"); - const unwrap tmp(expr.get_ref()); - const Mat& X = tmp.M; + Mat A(expr.get_ref()); Mat U; Col< T> s; Mat V; - const bool status = auxlib::svd_dc(U, s, V, X); + const bool status = auxlib::svd_dc(U, s, V, A); V.reset(); - if(status == false) { out.soft_reset(); return false; } + if(status == false) { return false; } if(s.is_empty()) { out.reset(); return true; } @@ -72,7 +74,7 @@ op_orth::apply_direct(Mat& out, const Base::epsilon(); } + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } uword count = 0; @@ -84,7 +86,7 @@ op_orth::apply_direct(Mat& out, const Base& out, const Op& expr) if(status == false) { + out.soft_reset(); arma_stop_runtime_error("null(): svd failed"); } } @@ -129,18 +132,17 @@ op_null::apply_direct(Mat& out, const Base= 0"); - const unwrap tmp(expr.get_ref()); - const Mat& X = tmp.M; + Mat A(expr.get_ref()); Mat U; Col< T> s; Mat V; - const bool status = auxlib::svd_dc(U, s, V, X); + const bool status = auxlib::svd_dc(U, s, V, A); U.reset(); - if(status == false) { out.soft_reset(); return false; } + if(status == false) { return false; } if(s.is_empty()) { out.reset(); return true; } @@ -148,15 +150,15 @@ op_null::apply_direct(Mat& out, const Base::epsilon(); } + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } uword count = 0; for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } - if(count < X.n_cols) + if(count < A.n_cols) { - out = V.tail_cols(X.n_cols - count); + out = V.tail_cols(A.n_cols - count); const uword out_n_elem = out.n_elem; eT* out_mem = out.memptr(); @@ -168,7 +170,7 @@ op_null::apply_direct(Mat& out, const Base inline static void apply(Mat& out, const Op& in); + + template inline static bool apply_direct(Mat& out, const Base& expr); + }; + + + class op_pinv : public traits_op_default { public: template inline static void apply(Mat& out, const Op& in); - - template inline static bool apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const bool use_divide_and_conquer); + + template inline static bool apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const uword method_id); + + template inline static bool apply_diag(Mat& out, const Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_sym (Mat& out, const Mat& A, typename get_pod_type::result tol, const uword method_id); + + template inline static bool apply_gen (Mat& out, Mat& A, typename get_pod_type::result tol, const uword method_id); }; diff --git a/src/armadillo_bits/op_pinv_meat.hpp b/src/armadillo_bits/op_pinv_meat.hpp index d2c79b01..326a0bef 100644 --- a/src/armadillo_bits/op_pinv_meat.hpp +++ b/src/armadillo_bits/op_pinv_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,20 +25,59 @@ template inline void -op_pinv::apply(Mat& out, const Op& in) +op_pinv_default::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + const bool status = op_pinv_default::apply_direct(out, in.m); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("pinv(): svd failed"); + } + } + + + +template +inline +bool +op_pinv_default::apply_direct(Mat& out, const Base& expr) { arma_extra_debug_sigprint(); typedef typename T1::pod_type T; - const T tol = access::tmp_real(in.aux); + constexpr T tol = T(0); + constexpr uword method_id = uword(0); - const bool use_divide_and_conquer = (in.aux_uword_a == 1); + return op_pinv::apply_direct(out, expr, tol, method_id); + } + + + +// + + + +template +inline +void +op_pinv::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); - const bool status = op_pinv::apply_direct(out, in.m, tol, use_divide_and_conquer); + typedef typename T1::pod_type T; + + const T tol = access::tmp_real(in.aux); + const uword method_id = in.aux_uword_a; + + const bool status = op_pinv::apply_direct(out, in.m, tol, method_id); if(status == false) { + out.soft_reset(); arma_stop_runtime_error("pinv(): svd failed"); } } @@ -46,7 +87,7 @@ op_pinv::apply(Mat& out, const Op& in) template inline bool -op_pinv::apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const bool use_divide_and_conquer) +op_pinv::apply_direct(Mat& out, const Base& expr, typename T1::pod_type tol, const uword method_id) { arma_extra_debug_sigprint(); @@ -55,126 +96,213 @@ op_pinv::apply_direct(Mat& out, const Base= 0"); - const Proxy P(expr.get_ref()); + // method_id = 0 -> default setting + // method_id = 1 -> use standard algorithm + // method_id = 2 -> use divide and conquer algorithm + + Mat A(expr.get_ref()); - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + if(A.is_empty()) { out.set_size(A.n_cols,A.n_rows); return true; } - if( (n_rows*n_cols) == 0 ) + if(is_op_diagmat::value || A.is_diagmat()) { - out.set_size(n_cols,n_rows); - return true; + arma_extra_debug_print("op_pinv: detected diagonal matrix"); + + return op_pinv::apply_diag(out, A, tol); } + bool do_sym = false; - // economical SVD decomposition - Mat U; - Col< T> s; - Mat V; + const bool is_sym_size_ok = (A.n_rows == A.n_cols) && (A.n_rows > (is_cx::yes ? uword(20) : uword(40))); - bool status = false; + if( (is_sym_size_ok) && (arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) ) + { + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + do_sym = ((is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd)); + } - if(use_divide_and_conquer) + if(do_sym) { - status = (n_cols > n_rows) ? auxlib::svd_dc_econ(U, s, V, trans(P.Q)) : auxlib::svd_dc_econ(U, s, V, P.Q); + arma_extra_debug_print("op_pinv: symmetric/hermitian optimisation"); + + return op_pinv::apply_sym(out, A, tol, method_id); } - else + + return op_pinv::apply_gen(out, A, tol, method_id); + } + + + +template +inline +bool +op_pinv::apply_diag(Mat& out, const Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(A.n_cols, A.n_rows); + + const uword N = (std::min)(A.n_rows, A.n_cols); + + podarray diag_abs_vals(N); + + T max_abs_Aii = T(0); + + for(uword i=0; i n_rows) ? auxlib::svd_econ(U, s, V, trans(P.Q), 'b') : auxlib::svd_econ(U, s, V, P.Q, 'b'); + const eT Aii = A.at(i,i); + const T abs_Aii = std::abs(Aii); + + if(arma_isnan(Aii)) { return false; } + + diag_abs_vals[i] = abs_Aii; + + max_abs_Aii = (abs_Aii > max_abs_Aii) ? abs_Aii : max_abs_Aii; } - if(status == false) + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_Aii * std::numeric_limits::epsilon(); } + + for(uword i=0; i= tol) + { + const eT Aii = A.at(i,i); + + if(Aii != eT(0)) { out.at(i,i) = eT(eT(1) / Aii); } + } } - const uword s_n_elem = s.n_elem; - const T* s_mem = s.memptr(); + return true; + } + + + +template +inline +bool +op_pinv::apply_sym(Mat& out, const Mat& A, typename get_pod_type::result tol, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col< T> eigval; + Mat eigvec; + + const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::eig_sym_dc(eigval, eigvec, A) : auxlib::eig_sym(eigval, eigvec, A); + + if(status == false) { return false; } + + if(eigval.n_elem == 0) { out.zeros(A.n_cols, A.n_rows); return true; } + + Col abs_eigval = arma::abs(eigval); + + const uvec indices = sort_index(abs_eigval, "descend"); + + abs_eigval = abs_eigval.elem(indices); + eigval = eigval.elem(indices); + eigvec = eigvec.cols(indices); // set tolerance to default if it hasn't been specified - if( (tol == T(0)) && (s_n_elem > 0) ) + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * abs_eigval[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < abs_eigval.n_elem; ++i) { count += (abs_eigval[i] >= tol) ? uword(1) : uword(0); } + + if(count == 0) { out.zeros(A.n_cols, A.n_rows); return true; } + + Col eigval2(count, arma_nozeros_indicator()); + + uword count2 = 0; + + for(uword i=0; i < eigval.n_elem; ++i) { - tol = (std::max)(n_rows, n_cols) * s_mem[0] * std::numeric_limits::epsilon(); + const T abs_val = abs_eigval[i]; + const T val = eigval[i]; + + if(abs_val >= tol) { eigval2[count2] = (val != T(0)) ? T(T(1) / val) : T(0); ++count2; } } + const Mat eigvec_use(eigvec.memptr(), eigvec.n_rows, count, false); + + out = (eigvec_use * diagmat(eigval2)).eval() * eigvec_use.t(); + + return true; + } + + + + +template +inline +bool +op_pinv::apply_gen(Mat& out, Mat& A, typename get_pod_type::result tol, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + // economical SVD decomposition + Mat U; + Col< T> s; + Mat V; + + if(n_cols > n_rows) { A = trans(A); } + + const bool status = ((method_id == uword(0)) || (method_id == uword(2))) ? auxlib::svd_dc_econ(U, s, V, A) : auxlib::svd_econ(U, s, V, A, 'b'); + + if(status == false) { return false; } + + // set tolerance to default if it hasn't been specified + if( (tol == T(0)) && (s.n_elem > 0) ) { tol = (std::max)(n_rows, n_cols) * s[0] * std::numeric_limits::epsilon(); } uword count = 0; - for(uword i = 0; i < s_n_elem; ++i) + for(uword i=0; i < s.n_elem; ++i) { count += (s[i] >= tol) ? uword(1) : uword(0); } + + if(count == 0) { out.zeros(n_cols, n_rows); return true; } + + Col s2(count, arma_nozeros_indicator()); + + uword count2 = 0; + + for(uword i=0; i < s.n_elem; ++i) { - count += (s_mem[i] >= tol) ? uword(1) : uword(0); + const T val = s[i]; + + if(val >= tol) { s2[count2] = (val > T(0)) ? T(T(1) / val) : T(0); ++count2; } } + const Mat U_use(U.memptr(), U.n_rows, count, false); + const Mat V_use(V.memptr(), V.n_rows, count, false); + + Mat tmp; - if(count > 0) + if(n_rows >= n_cols) { - Col s2(count); + // out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( (U.n_cols > count) ? U.cols(0,count-1) : U ); - T* s2_mem = s2.memptr(); + tmp = V_use * diagmat(s2); - uword count2 = 0; - - for(uword i=0; i < s_n_elem; ++i) - { - const T val = s_mem[i]; - - if(val >= tol) { s2_mem[count2] = T(1) / val; ++count2; } - } - - - if(n_rows >= n_cols) - { - // out = ( (V.n_cols > count) ? V.cols(0,count-1) : V ) * diagmat(s2) * trans( (U.n_cols > count) ? U.cols(0,count-1) : U ); - - Mat tmp; - - if(count < V.n_cols) - { - tmp = V.cols(0,count-1) * diagmat(s2); - } - else - { - tmp = V * diagmat(s2); - } - - if(count < U.n_cols) - { - out = tmp * trans(U.cols(0,count-1)); - } - else - { - out = tmp * trans(U); - } - } - else - { - // out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( (V.n_cols > count) ? V.cols(0,count-1) : V ); - - Mat tmp; - - if(count < U.n_cols) - { - tmp = U.cols(0,count-1) * diagmat(s2); - } - else - { - tmp = U * diagmat(s2); - } - - if(count < V.n_cols) - { - out = tmp * trans(V.cols(0,count-1)); - } - else - { - out = tmp * trans(V); - } - } + out = tmp * trans(U_use); } else { - out.zeros(n_cols, n_rows); + // out = ( (U.n_cols > count) ? U.cols(0,count-1) : U ) * diagmat(s2) * trans( (V.n_cols > count) ? V.cols(0,count-1) : V ); + + tmp = U_use * diagmat(s2); + + out = tmp * trans(V_use); } return true; diff --git a/src/armadillo_bits/op_inv_bones.hpp b/src/armadillo_bits/op_powmat_bones.hpp similarity index 61% rename from src/armadillo_bits/op_inv_bones.hpp rename to src/armadillo_bits/op_powmat_bones.hpp index 9ef4e2bd..021522b1 100644 --- a/src/armadillo_bits/op_inv_bones.hpp +++ b/src/armadillo_bits/op_powmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,49 +16,39 @@ // ------------------------------------------------------------------------ -//! \addtogroup op_inv + +//! \addtogroup op_powmat //! @{ -//! 'invert matrix' operation (general matrices) -class op_inv +class op_powmat : public traits_op_default { public: template - inline static void apply(Mat& out, const Op& in); - - template - inline static void apply_noalias(Mat& out, const Mat& A); + inline static void apply(Mat& out, const Op& expr); template - inline static void apply_diagmat(Mat& out, const T1& X); + inline static bool apply_direct(Mat& out, const Base& X, const uword y, const bool y_neg); + + template + inline static void apply_direct_positive(Mat& out, const Mat& X, const uword y); }; -//! 'invert matrix' operation (triangular matrices) -class op_inv_tr +class op_powmat_cx : public traits_op_default { public: template - inline static void apply(Mat& out, const Op& in); - }; - - - -//! 'invert matrix' operation (symmetric positive definite matrices) -class op_inv_sympd - : public traits_op_default - { - public: + inline static void apply(Mat< std::complex >& out, const mtOp,T1,op_powmat_cx>& expr); template - inline static void apply(Mat& out, const Op& in); + inline static bool apply_direct(Mat< std::complex >& out, const Base& X, const typename T1::pod_type y); }; diff --git a/src/armadillo_bits/op_powmat_meat.hpp b/src/armadillo_bits/op_powmat_meat.hpp new file mode 100644 index 00000000..323db320 --- /dev/null +++ b/src/armadillo_bits/op_powmat_meat.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_powmat +//! @{ + + +template +inline +void +op_powmat::apply(Mat& out, const Op& expr) + { + arma_extra_debug_sigprint(); + + const uword y = expr.aux_uword_a; + const bool y_neg = (expr.aux_uword_b == uword(1)); + + const bool status = op_powmat::apply_direct(out, expr.m, y, y_neg); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("powmat(): transformation failed"); + } + } + + + +template +inline +bool +op_powmat::apply_direct(Mat& out, const Base& X, const uword y, const bool y_neg) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + if(y_neg) + { + if(y == uword(1)) + { + return op_inv_gen_default::apply_direct(out, X.get_ref(), "powmat()"); + } + else + { + Mat X_inv; + + const bool inv_status = op_inv_gen_default::apply_direct(X_inv, X.get_ref(), "powmat()"); + + if(inv_status == false) { return false; } + + op_powmat::apply_direct_positive(out, X_inv, y); + } + } + else + { + const quasi_unwrap U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "powmat(): given matrix must be square sized" ); + + op_powmat::apply_direct_positive(out, U.M, y); + } + + return true; + } + + + +template +inline +void +op_powmat::apply_direct_positive(Mat& out, const Mat& X, const uword y) + { + arma_extra_debug_sigprint(); + + const uword N = X.n_rows; + + if(y == uword(0)) { out.eye(N,N); return; } + if(y == uword(1)) { out = X; return; } + + if(X.is_diagmat()) + { + arma_extra_debug_print("op_powmat: detected diagonal matrix"); + + podarray tmp(N); // use temporary array in case we have aliasing + + for(uword i=0; i tmp = X*X; out = X*tmp; } + else if(y == uword(4)) { const Mat tmp = X*X; out = tmp*tmp; } + else if(y == uword(5)) { const Mat tmp = X*X; out = X*tmp*tmp; } + else + { + Mat tmp = X; + + out = X; + + uword z = y-1; + + while(z > 0) + { + if(z & 1) { out = tmp * out; } + + z /= uword(2); + + if(z > 0) { tmp = tmp * tmp; } + } + } + } + } + + + +template +inline +void +op_powmat_cx::apply(Mat< std::complex >& out, const mtOp,T1,op_powmat_cx>& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::pod_type in_T; + + const in_T y = std::real(expr.aux_out_eT); + + const bool status = op_powmat_cx::apply_direct(out, expr.m, y); + + if(status == false) + { + out.soft_reset(); + arma_stop_runtime_error("powmat(): transformation failed"); + } + } + + + +template +inline +bool +op_powmat_cx::apply_direct(Mat< std::complex >& out, const Base& X, const typename T1::pod_type y) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type in_T; + typedef std::complex out_eT; + + if( y == in_T(int(y)) ) + { + arma_extra_debug_print("op_powmat_cx::apply_direct(): integer exponent detected; redirecting to op_powmat"); + + const uword y_val = (y < int(0)) ? uword(-y) : uword(y); + const bool y_neg = (y < int(0)); + + Mat tmp; + + const bool status = op_powmat::apply_direct(tmp, X.get_ref(), y_val, y_neg); + + if(status == false) { return false; } + + out = conv_to< Mat >::from(tmp); + + return true; + } + + const quasi_unwrap U(X.get_ref()); + const Mat& A = U.M; + + arma_debug_check( (A.is_square() == false), "powmat(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_powmat_cx: detected diagonal matrix"); + + podarray tmp(N); // use temporary array in case we have aliasing + + for(uword i=0; i(A.at(i,i)), y) ; } + + out.zeros(N,N); + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym(eigval, eigvec, A); + + if(eig_status) + { + eigval = pow(eigval, y); + + const Mat tmp = diagmat(eigval) * eigvec.t(); + + out = conv_to< Mat >::from(eigvec * tmp); + + return true; + } + + arma_extra_debug_print("op_powmat_cx: sympd optimisation failed"); + + // fallthrough if optimisation failed + } + + bool powmat_status = false; + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_gen(eigval, eigvec, A); + + if(eig_status) + { + eigval = pow(eigval, y); + + Mat eigvec_t = trans(eigvec); + Mat tmp = diagmat(conj(eigval)) * eigvec_t; + + const bool solve_status = auxlib::solve_square_fast(out, eigvec_t, tmp); + + if(solve_status) { out = trans(out); powmat_status = true; } + } + + return powmat_status; + } + + + +//! @} diff --git a/src/armadillo_bits/op_princomp_bones.hpp b/src/armadillo_bits/op_princomp_bones.hpp index d75a780c..4d6abaab 100644 --- a/src/armadillo_bits/op_princomp_bones.hpp +++ b/src/armadillo_bits/op_princomp_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_princomp_meat.hpp b/src/armadillo_bits/op_princomp_meat.hpp index f506ba21..db6f83fe 100644 --- a/src/armadillo_bits/op_princomp_meat.hpp +++ b/src/armadillo_bits/op_princomp_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -72,7 +74,7 @@ op_princomp::direct_princomp { score_out.cols(n_rows-1,n_cols-1).zeros(); - Col s_tmp(n_cols, fill::zeros); + Col s_tmp(n_cols, arma_zeros_indicator()); s_tmp.rows(0,n_rows-2) = s.rows(0,n_rows-2); s = s_tmp; @@ -164,7 +166,7 @@ op_princomp::direct_princomp { score_out.cols(n_rows-1,n_cols-1).zeros(); - Col s_tmp(n_cols, fill::zeros); + Col s_tmp(n_cols, arma_zeros_indicator()); s_tmp.rows(0,n_rows-2) = s.rows(0,n_rows-2); s = s_tmp; @@ -264,6 +266,7 @@ op_princomp::direct_princomp arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; const unwrap Y( X.get_ref() ); const Mat& in = Y.M; @@ -274,7 +277,7 @@ op_princomp::direct_princomp // singular value decomposition Mat U; - Col s; + Col< T> s; const bool svd_ok = (in.n_rows >= in.n_cols) ? svd_econ(U, s, coeff_out, tmp) : svd(U, s, coeff_out, tmp); @@ -301,12 +304,7 @@ op_princomp::apply { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - - const unwrap_check tmp(in.m, out); - const Mat& A = tmp.M; - - const bool status = op_princomp::direct_princomp(out, A); + const bool status = op_princomp::direct_princomp(out, in.m); if(status == false) { diff --git a/src/armadillo_bits/op_prod_bones.hpp b/src/armadillo_bits/op_prod_bones.hpp index f57d30fa..790401cb 100644 --- a/src/armadillo_bits/op_prod_bones.hpp +++ b/src/armadillo_bits/op_prod_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_prod_meat.hpp b/src/armadillo_bits/op_prod_meat.hpp index 9555d0b9..c1b9577a 100644 --- a/src/armadillo_bits/op_prod_meat.hpp +++ b/src/armadillo_bits/op_prod_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,7 +30,7 @@ op_prod::apply_noalias(Mat& out, const Mat& X, const uword dim) const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; - if(dim == 0) // traverse across rows (i.e. find the product in each column) + if(dim == 0) // traverse across rows (ie. find the product in each column) { out.set_size(1, X_n_cols); @@ -39,7 +41,7 @@ op_prod::apply_noalias(Mat& out, const Mat& X, const uword dim) out_mem[col] = arrayops::product(X.colptr(col), X_n_rows); } } - else // traverse across columns (i.e. find the product in each row) + else // traverse across columns (ie. find the product in each row) { out.ones(X_n_rows, 1); diff --git a/src/armadillo_bits/op_range_bones.hpp b/src/armadillo_bits/op_range_bones.hpp index 99496ac9..5745624b 100644 --- a/src/armadillo_bits/op_range_bones.hpp +++ b/src/armadillo_bits/op_range_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_range_meat.hpp b/src/armadillo_bits/op_range_meat.hpp index 8f7e4514..a3e66c21 100644 --- a/src/armadillo_bits/op_range_meat.hpp +++ b/src/armadillo_bits/op_range_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,7 +31,7 @@ op_range::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; const uword dim = in.aux_uword_a; - arma_debug_check( (dim > 1), "range(): parameter 'dim' must be 0 or 1"); + arma_debug_check( (dim > 1), "range(): parameter 'dim' must be 0 or 1" ); const quasi_unwrap U(in.m); const Mat& X = U.M; diff --git a/src/armadillo_bits/op_rank_bones.hpp b/src/armadillo_bits/op_rank_bones.hpp new file mode 100644 index 00000000..f0c4a072 --- /dev/null +++ b/src/armadillo_bits/op_rank_bones.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_rank +//! @{ + + + +class op_rank + : public traits_op_default + { + public: + + template inline static bool apply(uword& out, const Base& expr, const typename T1::pod_type tol); + + template inline static bool apply_gen(uword& out, Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_sym(uword& out, Mat& A, typename get_pod_type::result tol); + + template inline static bool apply_diag(uword& out, Mat& A, typename get_pod_type::result tol); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_rank_meat.hpp b/src/armadillo_bits/op_rank_meat.hpp new file mode 100644 index 00000000..ef00dd39 --- /dev/null +++ b/src/armadillo_bits/op_rank_meat.hpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + + +//! \addtogroup op_rank +//! @{ + + + +template +inline +bool +op_rank::apply(uword& out, const Base& expr, const typename T1::pod_type tol) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + Mat A(expr.get_ref()); + + if(A.is_empty()) { out = uword(0); return true; } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_rank::apply(): detected diagonal matrix"); + + return op_rank::apply_diag(out, A, tol); + } + + bool do_sym = false; + + if((arma_config::optimise_sym) && (auxlib::crippled_lapack(A) == false) && (A.n_rows >= (is_cx::yes ? uword(64) : uword(128)))) + { + bool is_approx_sym = false; + bool is_approx_sympd = false; + + sym_helper::analyse_matrix(is_approx_sym, is_approx_sympd, A); + + do_sym = (is_cx::no) ? (is_approx_sym) : (is_approx_sym && is_approx_sympd); + } + + if(do_sym) + { + arma_extra_debug_print("op_rank::apply(): symmetric/hermitian optimisation"); + + return op_rank::apply_sym(out, A, tol); + } + + return op_rank::apply_gen(out, A, tol); + } + + + +template +inline +bool +op_rank::apply_diag(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + const uword N = (std::min)(A.n_rows, A.n_cols); + + podarray diag_abs_vals(N); + + T max_abs_Aii = T(0); + + for(uword i=0; i max_abs_Aii) ? abs_Aii : max_abs_Aii; + } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_Aii * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +template +inline +bool +op_rank::apply_sym(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + if(A.is_square() == false) { out = uword(0); return false; } + + Col v; + + const bool status = auxlib::eig_sym(v, A); + + if(status == false) { out = uword(0); return false; } + + const uword v_n_elem = v.n_elem; + T* v_mem = v.memptr(); + + if(v_n_elem == 0) { out = uword(0); return true; } + + T max_abs_v = T(0); + + for(uword i=0; i < v_n_elem; ++i) { const T val = std::abs(v_mem[i]); v_mem[i] = val; if(val > max_abs_v) { max_abs_v = val; } } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * max_abs_v * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < v_n_elem; ++i) { count += (v_mem[i] > tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +template +inline +bool +op_rank::apply_gen(uword& out, Mat& A, typename get_pod_type::result tol) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + Col s; + + const bool status = auxlib::svd_dc(s, A); + + if(status == false) { out = uword(0); return false; } + + const uword s_n_elem = s.n_elem; + const T* s_mem = s.memptr(); + + if(s_n_elem == 0) { out = uword(0); return true; } + + // set tolerance to default if it hasn't been specified + if(tol == T(0)) { tol = (std::max)(A.n_rows, A.n_cols) * s_mem[0] * std::numeric_limits::epsilon(); } + + uword count = 0; + + for(uword i=0; i < s_n_elem; ++i) { count += (s_mem[i] > tol) ? uword(1) : uword(0); } + + out = count; + + return true; + } + + + +//! @} diff --git a/src/armadillo_bits/include_atlas.hpp b/src/armadillo_bits/op_rcond_bones.hpp similarity index 54% rename from src/armadillo_bits/include_atlas.hpp rename to src/armadillo_bits/op_rcond_bones.hpp index fa5aa505..88697e8d 100644 --- a/src/armadillo_bits/include_atlas.hpp +++ b/src/armadillo_bits/op_rcond_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,29 +16,17 @@ // ------------------------------------------------------------------------ -#if defined(ARMA_USE_ATLAS) - #if !defined(ARMA_ATLAS_INCLUDE_DIR) - extern "C" - { - #include - #include - } - #else - #define ARMA_STR1(x) x - #define ARMA_STR2(x) ARMA_STR1(x) - - #define ARMA_CBLAS ARMA_STR2(ARMA_ATLAS_INCLUDE_DIR)ARMA_STR2(cblas.h) - #define ARMA_CLAPACK ARMA_STR2(ARMA_ATLAS_INCLUDE_DIR)ARMA_STR2(clapack.h) - - extern "C" - { - #include ARMA_INCFILE_WRAP(ARMA_CBLAS) - #include ARMA_INCFILE_WRAP(ARMA_CLAPACK) - } - - #undef ARMA_STR1 - #undef ARMA_STR2 - #undef ARMA_CBLAS - #undef ARMA_CLAPACK - #endif -#endif +//! \addtogroup op_rcond +//! @{ + + +class op_rcond + : public traits_op_default + { + public: + + template static inline typename T1::pod_type apply(const Base& X); + }; + + +//! @} diff --git a/src/armadillo_bits/op_rcond_meat.hpp b/src/armadillo_bits/op_rcond_meat.hpp new file mode 100644 index 00000000..48123bd0 --- /dev/null +++ b/src/armadillo_bits/op_rcond_meat.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_rcond +//! @{ + + + +template +inline +typename T1::pod_type +op_rcond::apply(const Base& X) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + if(strip_trimat::do_trimat) + { + const strip_trimat S(X.get_ref()); + + const quasi_unwrap::stored_type> U(S.M); + + arma_debug_check( (U.M.is_square() == false), "rcond(): matrix must be square sized" ); + + const uword layout = (S.do_triu) ? uword(0) : uword(1); + + return auxlib::rcond_trimat(U.M, layout); + } + + Mat A = X.get_ref(); + + arma_debug_check( (A.is_square() == false), "rcond(): matrix must be square sized" ); + + if(A.is_empty()) { return Datum::inf; } + + if(is_op_diagmat::value || A.is_diagmat()) + { + arma_extra_debug_print("op_rcond::apply(): detected diagonal matrix"); + + const eT* colmem = A.memptr(); + const uword N = A.n_rows; + + T abs_min = Datum::inf; + T abs_max = T(0); + + for(uword i=0; i abs_max) ? abs_val : abs_max; + + colmem += N; + } + + if((abs_min == T(0)) || (abs_max == T(0))) { return T(0); } + + return T(abs_min / abs_max); + } + + const bool is_triu = trimat_helper::is_triu(A); + const bool is_tril = (is_triu) ? false : trimat_helper::is_tril(A); + + if(is_triu || is_tril) + { + const uword layout = (is_triu) ? uword(0) : uword(1); + + return auxlib::rcond_trimat(A, layout); + } + + const bool try_sympd = arma_config::optimise_sym && (auxlib::crippled_lapack(A) ? false : sym_helper::guess_sympd(A)); + + if(try_sympd) + { + arma_extra_debug_print("op_rcond::apply(): attempting sympd optimisation"); + + bool calc_ok = false; + + const T out_val = auxlib::rcond_sympd(A, calc_ok); + + if(calc_ok) { return out_val; } + + arma_extra_debug_print("op_rcond::apply(): sympd optimisation failed"); + + // auxlib::rcond_sympd() may have failed because A isn't really sympd + // restore A, as auxlib::rcond_sympd() may have destroyed it + A = X.get_ref(); + // fallthrough to the next return statement + } + + return auxlib::rcond(A); + } + + + +//! @} diff --git a/src/armadillo_bits/op_relational_bones.hpp b/src/armadillo_bits/op_relational_bones.hpp index 2b1f6dcc..0e8ecabc 100644 --- a/src/armadillo_bits/op_relational_bones.hpp +++ b/src/armadillo_bits/op_relational_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_relational_meat.hpp b/src/armadillo_bits/op_relational_meat.hpp index 3a651718..6c7344bf 100644 --- a/src/armadillo_bits/op_relational_meat.hpp +++ b/src/armadillo_bits/op_relational_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_repelem_bones.hpp b/src/armadillo_bits/op_repelem_bones.hpp index 149fc57d..52d20a41 100644 --- a/src/armadillo_bits/op_repelem_bones.hpp +++ b/src/armadillo_bits/op_repelem_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_repelem_meat.hpp b/src/armadillo_bits/op_repelem_meat.hpp index 4d9d08db..38bfe36b 100644 --- a/src/armadillo_bits/op_repelem_meat.hpp +++ b/src/armadillo_bits/op_repelem_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_repmat_bones.hpp b/src/armadillo_bits/op_repmat_bones.hpp index 6b5cc76c..100179ab 100644 --- a/src/armadillo_bits/op_repmat_bones.hpp +++ b/src/armadillo_bits/op_repmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_repmat_meat.hpp b/src/armadillo_bits/op_repmat_meat.hpp index 6733210b..1460350d 100644 --- a/src/armadillo_bits/op_repmat_meat.hpp +++ b/src/armadillo_bits/op_repmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_reshape_bones.hpp b/src/armadillo_bits/op_reshape_bones.hpp index d53b31f0..b27f22b8 100644 --- a/src/armadillo_bits/op_reshape_bones.hpp +++ b/src/armadillo_bits/op_reshape_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,22 +27,21 @@ class op_reshape { public: - template inline static void apply_unwrap(Mat& out, const Mat& A, const uword in_n_rows, const uword in_n_cols, const uword in_dim); + template inline static void apply(Mat& out, const Op& in); - template inline static void apply_proxy (Mat& out, const Proxy& P, const uword in_n_rows, const uword in_n_cols); + template inline static void apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols); - template inline static void apply (Mat& out, const Op& in); - }; - - - -class op_reshape_ext - : public traits_op_default - { - public: + template inline static void apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols); + + template inline static void apply_proxy_noalias(Mat& out, const Proxy& P, const uword new_n_rows, const uword new_n_cols); + + // + + template inline static void apply(Cube& out, const OpCube& in); + + template inline static void apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); - template inline static void apply( Mat& out, const Op& in); - template inline static void apply(Cube& out, const OpCube& in); + template inline static void apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); }; diff --git a/src/armadillo_bits/op_reshape_meat.hpp b/src/armadillo_bits/op_reshape_meat.hpp index 5daaf342..1846dc8b 100644 --- a/src/armadillo_bits/op_reshape_meat.hpp +++ b/src/armadillo_bits/op_reshape_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,104 +22,98 @@ -template +template inline void -op_reshape::apply_unwrap(Mat& out, const Mat& A, const uword in_n_rows, const uword in_n_cols, const uword in_dim) +op_reshape::apply(Mat& actual_out, const Op& in) { arma_extra_debug_sigprint(); - const bool is_alias = (&out == &A); + typedef typename T1::elem_type eT; - const uword in_n_elem = in_n_rows * in_n_cols; + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; - if(A.n_elem == in_n_elem) + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) { - if(in_dim == 0) + const unwrap U(in.m); + const Mat& A = U.M; + + if(&actual_out == &A) { - if(is_alias == false) - { - out.set_size(in_n_rows, in_n_cols); - arrayops::copy( out.memptr(), A.memptr(), out.n_elem ); - } - else // &out == &A, i.e. inplace resize - { - out.set_size(in_n_rows, in_n_cols); - // set_size() doesn't destroy data as long as the number of elements in the matrix remains the same - } + op_reshape::apply_mat_inplace(actual_out, new_n_rows, new_n_cols); } else { - unwrap_check< Mat > B_tmp(A, is_alias); - const Mat& B = B_tmp.M; - - out.set_size(in_n_rows, in_n_cols); - - eT* out_mem = out.memptr(); - - const uword B_n_rows = B.n_rows; - const uword B_n_cols = B.n_cols; - - for(uword row=0; row > B_tmp(A, is_alias); - const Mat& B = B_tmp.M; - - const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem); + const Proxy P(in.m); - out.set_size(in_n_rows, in_n_cols); + const bool is_alias = P.is_alias(actual_out); - eT* out_mem = out.memptr(); + Mat tmp; + Mat& out = (is_alias) ? tmp : actual_out; - if(in_dim == 0) + if(is_Mat::stored_type>::value) { - arrayops::copy( out_mem, B.memptr(), n_elem_to_copy ); + const quasi_unwrap::stored_type> U(P.Q); + + op_reshape::apply_mat_noalias(out, U.M, new_n_rows, new_n_cols); } else { - uword row = 0; - uword col = 0; - - const uword B_n_cols = B.n_cols; - - for(uword i=0; i= B_n_cols) - { - col = 0; - ++row; - } - } + op_reshape::apply_proxy_noalias(out, P, new_n_rows, new_n_cols); } - for(uword i=n_elem_to_copy; i +inline +void +op_reshape::apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + const uword new_n_elem = new_n_rows * new_n_cols; + + if(A.n_elem == new_n_elem) { A.set_size(new_n_rows, new_n_cols); return; } + + Mat B; + + op_reshape::apply_mat_noalias(B, A, new_n_rows, new_n_cols); + + A.steal_mem(B); + } + + + +template +inline +void +op_reshape::apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols); + + const uword n_elem_to_copy = (std::min)(A.n_elem, out.n_elem); + + eT* out_mem = out.memptr(); + + arrayops::copy( out_mem, A.memptr(), n_elem_to_copy ); + + if(n_elem_to_copy < out.n_elem) + { + const uword n_elem_leftover = out.n_elem - n_elem_to_copy; + arrayops::fill_zeros(&(out_mem[n_elem_to_copy]), n_elem_leftover); } } @@ -126,91 +122,49 @@ op_reshape::apply_unwrap(Mat& out, const Mat& A, const uword in_n_rows, template inline void -op_reshape::apply_proxy(Mat& out, const Proxy& P, const uword in_n_rows, const uword in_n_cols) +op_reshape::apply_proxy_noalias(Mat& out, const Proxy& P, const uword new_n_rows, const uword new_n_cols) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - out.set_size(in_n_rows, in_n_cols); + out.set_size(new_n_rows, new_n_cols); - eT* out_mem = out.memptr(); + const uword n_elem_to_copy = (std::min)(P.get_n_elem(), out.n_elem); - const uword in_n_elem = in_n_rows * in_n_cols; + eT* out_mem = out.memptr(); - if(P.get_n_elem() == in_n_elem) + if(Proxy::use_at == false) { - if(Proxy::use_at == false) - { - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i::ea_type Pea = P.get_ea(); + + for(uword i=0; i < n_elem_to_copy; ++i) { out_mem[i] = Pea[i]; } } else { - const uword n_elem_to_copy = (std::min)(P.get_n_elem(), in_n_elem); + uword i = 0; - if(Proxy::use_at == false) - { - typename Proxy::ea_type Pea = P.get_ea(); - - for(uword i=0; i= n_elem_to_copy) { goto nested_loop_end; } - const uword P_n_rows = P.get_n_rows(); - const uword P_n_cols = P.get_n_cols(); + out_mem[i] = P.at(row,col); - for(uword col=0; col < P_n_cols; ++col) - for(uword row=0; row < P_n_rows; ++row) - { - if(i >= n_elem_to_copy) { goto nested_loop_end; } - - out_mem[i] = P.at(row,col); - - ++i; - } - - nested_loop_end: ; + ++i; } - for(uword i=n_elem_to_copy; i& out, const Proxy& P, co template inline void -op_reshape::apply(Mat& out, const Op& in) +op_reshape::apply(Cube& out, const OpCube& in) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(in.m); + const unwrap_cube U(in.m); + const Cube& A = U.M; - const uword in_n_rows = in.aux_uword_a; - const uword in_n_cols = in.aux_uword_b; + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + const uword new_n_slices = in.aux_uword_c; - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) + if(&out == &A) { - // not checking for aliasing here, as this might be an inplace reshape - - const unwrap::stored_type> tmp(P.Q); - - op_reshape::apply_unwrap(out, tmp.M, in_n_rows, in_n_cols, uword(0)); + op_reshape::apply_cube_inplace(out, new_n_rows, new_n_cols, new_n_slices); } else { - if(P.is_alias(out)) - { - Mat tmp; - - op_reshape::apply_proxy(tmp, P, in_n_rows, in_n_cols); - - out.steal_mem(tmp); - } - else - { - op_reshape::apply_proxy(out, P, in_n_rows, in_n_cols); - } + op_reshape::apply_cube_noalias(out, A, new_n_rows, new_n_cols, new_n_slices); } } + - -template +template inline void -op_reshape_ext::apply(Mat& out, const Op& in) +op_reshape::apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - const unwrap tmp(in.m); + const uword new_n_elem = new_n_rows * new_n_cols * new_n_slices; + + if(A.n_elem == new_n_elem) { A.set_size(new_n_rows, new_n_cols, new_n_slices); return; } - const uword in_n_rows = in.aux_uword_a; - const uword in_n_cols = in.aux_uword_b; - const uword in_dim = in.aux_uword_c; + Cube B; - op_reshape::apply_unwrap(out, tmp.M, in_n_rows, in_n_cols, in_dim); + op_reshape::apply_cube_noalias(B, A, new_n_rows, new_n_cols, new_n_slices); + + A.steal_mem(B); } -template +template inline void -op_reshape_ext::apply(Cube& out, const OpCube& in) +op_reshape::apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; + out.set_size(new_n_rows, new_n_cols, new_n_slices); - const unwrap_cube A_tmp(in.m); - const Cube& A = A_tmp.M; + const uword n_elem_to_copy = (std::min)(A.n_elem, out.n_elem); - const uword in_n_rows = in.aux_uword_a; - const uword in_n_cols = in.aux_uword_b; - const uword in_n_slices = in.aux_uword_c; - const uword in_dim = in.aux_uword_d; + eT* out_mem = out.memptr(); - const uword in_n_elem = in_n_rows * in_n_cols * in_n_slices; + arrayops::copy( out_mem, A.memptr(), n_elem_to_copy ); - if(A.n_elem == in_n_elem) - { - if(in_dim == 0) - { - if(&out != &A) - { - out.set_size(in_n_rows, in_n_cols, in_n_slices); - arrayops::copy( out.memptr(), A.memptr(), out.n_elem ); - } - else // &out == &A, i.e. inplace resize - { - out.set_size(in_n_rows, in_n_cols, in_n_slices); - // set_size() doesn't destroy data as long as the number of elements in the cube remains the same - } - } - else - { - unwrap_cube_check< Cube > B_tmp(A, out); - const Cube& B = B_tmp.M; - - out.set_size(in_n_rows, in_n_cols, in_n_slices); - - eT* out_mem = out.memptr(); - - const uword B_n_rows = B.n_rows; - const uword B_n_cols = B.n_cols; - const uword B_n_slices = B.n_slices; - - for(uword slice = 0; slice < B_n_slices; ++slice) - for(uword row = 0; row < B_n_rows; ++row ) - for(uword col = 0; col < B_n_cols; ++col ) - { - *out_mem = B.at(row,col,slice); - out_mem++; - } - } - } - else + if(n_elem_to_copy < out.n_elem) { - const unwrap_cube_check< Cube > B_tmp(A, out); - const Cube& B = B_tmp.M; - - const uword n_elem_to_copy = (std::min)(B.n_elem, in_n_elem); - - out.set_size(in_n_rows, in_n_cols, in_n_slices); - - eT* out_mem = out.memptr(); - - if(in_dim == 0) - { - arrayops::copy( out_mem, B.memptr(), n_elem_to_copy ); - } - else - { - uword row = 0; - uword col = 0; - uword slice = 0; - - const uword B_n_rows = B.n_rows; - const uword B_n_cols = B.n_cols; - - for(uword i=0; i= B_n_cols) - { - col = 0; - ++row; - - if(row >= B_n_rows) - { - row = 0; - ++slice; - } - } - } - } - - for(uword i=n_elem_to_copy; i inline static void apply( Mat& out, const Op& in); + template inline static void apply(Mat& out, const Op& in); + + template inline static void apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols); + + template inline static void apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols); + + // + template inline static void apply(Cube& out, const OpCube& in); + + template inline static void apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); + + template inline static void apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices); }; diff --git a/src/armadillo_bits/op_resize_meat.hpp b/src/armadillo_bits/op_resize_meat.hpp index bb7834b6..b18a163a 100644 --- a/src/armadillo_bits/op_resize_meat.hpp +++ b/src/armadillo_bits/op_resize_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,122 +25,143 @@ template inline void -op_resize::apply(Mat& actual_out, const Op& in) +op_resize::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const uword out_n_rows = in.aux_uword_a; - const uword out_n_cols = in.aux_uword_b; + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; const unwrap tmp(in.m); const Mat& A = tmp.M; - const uword A_n_rows = A.n_rows; - const uword A_n_cols = A.n_cols; - - const bool alias = (&actual_out == &A); - - if(alias) + if(&out == &A) { - if( (A_n_rows == out_n_rows) && (A_n_cols == out_n_cols) ) - { - return; - } - - if(actual_out.is_empty()) - { - actual_out.zeros(out_n_rows, out_n_cols); - return; - } + op_resize::apply_mat_inplace(out, new_n_rows, new_n_cols); } + else + { + op_resize::apply_mat_noalias(out, A, new_n_rows, new_n_cols); + } + } + + + +template +inline +void +op_resize::apply_mat_inplace(Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); - Mat B; - Mat& out = alias ? B : actual_out; + if( (A.n_rows == new_n_rows) && (A.n_cols == new_n_cols) ) { return; } - out.set_size(out_n_rows, out_n_cols); + if(A.is_empty()) { A.zeros(new_n_rows, new_n_cols); return; } - if( (out_n_rows > A_n_rows) || (out_n_cols > A_n_cols) ) - { - out.zeros(); - } + Mat B; + + op_resize::apply_mat_noalias(B, A, new_n_rows, new_n_cols); + + A.steal_mem(B); + } + + + +template +inline +void +op_resize::apply_mat_noalias(Mat& out, const Mat& A, const uword new_n_rows, const uword new_n_cols) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols); + + if( (new_n_rows > A.n_rows) || (new_n_cols > A.n_cols) ) { out.zeros(); } if( (out.n_elem > 0) && (A.n_elem > 0) ) { - const uword end_row = (std::min)(out_n_rows, A_n_rows) - 1; - const uword end_col = (std::min)(out_n_cols, A_n_cols) - 1; + const uword end_row = (std::min)(new_n_rows, A.n_rows) - 1; + const uword end_col = (std::min)(new_n_cols, A.n_cols) - 1; out.submat(0, 0, end_row, end_col) = A.submat(0, 0, end_row, end_col); } - - if(alias) - { - actual_out.steal_mem(B); - } } +// + + + template inline void -op_resize::apply(Cube& actual_out, const OpCube& in) +op_resize::apply(Cube& out, const OpCube& in) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const uword out_n_rows = in.aux_uword_a; - const uword out_n_cols = in.aux_uword_b; - const uword out_n_slices = in.aux_uword_c; + const uword new_n_rows = in.aux_uword_a; + const uword new_n_cols = in.aux_uword_b; + const uword new_n_slices = in.aux_uword_c; const unwrap_cube tmp(in.m); const Cube& A = tmp.M; - const uword A_n_rows = A.n_rows; - const uword A_n_cols = A.n_cols; - const uword A_n_slices = A.n_slices; - - const bool alias = (&actual_out == &A); - - if(alias) + if(&out == &A) { - if( (A_n_rows == out_n_rows) && (A_n_cols == out_n_cols) && (A_n_slices == out_n_slices) ) - { - return; - } - - if(actual_out.is_empty()) - { - actual_out.zeros(out_n_rows, out_n_cols, out_n_slices); - return; - } + op_resize::apply_cube_inplace(out, new_n_rows, new_n_cols, new_n_slices); + } + else + { + op_resize::apply_cube_noalias(out, A, new_n_rows, new_n_cols, new_n_slices); } + } + + + +template +inline +void +op_resize::apply_cube_inplace(Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); - Cube B; - Cube& out = alias ? B : actual_out; + if( (A.n_rows == new_n_rows) && (A.n_cols == new_n_cols) && (A.n_slices == new_n_slices) ) { return; } - out.set_size(out_n_rows, out_n_cols, out_n_slices); + if(A.is_empty()) { A.zeros(new_n_rows, new_n_cols, new_n_slices); return; } - if( (out_n_rows > A_n_rows) || (out_n_cols > A_n_cols) || (out_n_slices > A_n_slices) ) - { - out.zeros(); - } + Cube B; + + op_resize::apply_cube_noalias(B, A, new_n_rows, new_n_cols, new_n_slices); + + A.steal_mem(B); + } + + + +template +inline +void +op_resize::apply_cube_noalias(Cube& out, const Cube& A, const uword new_n_rows, const uword new_n_cols, const uword new_n_slices) + { + arma_extra_debug_sigprint(); + + out.set_size(new_n_rows, new_n_cols, new_n_slices); + + if( (new_n_rows > A.n_rows) || (new_n_cols > A.n_cols) || (new_n_slices > A.n_slices) ) { out.zeros(); } if( (out.n_elem > 0) && (A.n_elem > 0) ) { - const uword end_row = (std::min)(out_n_rows, A_n_rows) - 1; - const uword end_col = (std::min)(out_n_cols, A_n_cols) - 1; - const uword end_slice = (std::min)(out_n_slices, A_n_slices) - 1; + const uword end_row = (std::min)(new_n_rows, A.n_rows) - 1; + const uword end_col = (std::min)(new_n_cols, A.n_cols) - 1; + const uword end_slice = (std::min)(new_n_slices, A.n_slices) - 1; out.subcube(0, 0, 0, end_row, end_col, end_slice) = A.subcube(0, 0, 0, end_row, end_col, end_slice); } - - if(alias) - { - actual_out.steal_mem(B); - } } diff --git a/src/armadillo_bits/op_reverse_bones.hpp b/src/armadillo_bits/op_reverse_bones.hpp index c782c0f0..8ec62176 100644 --- a/src/armadillo_bits/op_reverse_bones.hpp +++ b/src/armadillo_bits/op_reverse_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_reverse_meat.hpp b/src/armadillo_bits/op_reverse_meat.hpp index 12f7b0ed..4edafa86 100644 --- a/src/armadillo_bits/op_reverse_meat.hpp +++ b/src/armadillo_bits/op_reverse_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,36 +28,38 @@ op_reverse::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; + const uword dim = in.aux_uword_a; arma_debug_check( (dim > 1), "reverse(): parameter 'dim' must be 0 or 1" ); - const Proxy P(in.m); - - if(is_Mat::stored_type>::value || P.is_alias(out)) + if(is_Mat::value) { - const unwrap::stored_type> U(P.Q); + // allow detection of in-place operation - if(dim == 0) - { - op_flipud::apply_direct(out, U.M); - } - else - if(dim == 1) - { - op_fliplr::apply_direct(out, U.M); - } + const unwrap U(in.m); + + if(dim == 0) { op_flipud::apply_direct(out, U.M); } + if(dim == 1) { op_fliplr::apply_direct(out, U.M); } } else { - if(dim == 0) + const Proxy P(in.m); + + if(P.is_alias(out)) { - op_flipud::apply_proxy_noalias(out, P); + Mat tmp; + + if(dim == 0) { op_flipud::apply_proxy_noalias(tmp, P); } + if(dim == 1) { op_fliplr::apply_proxy_noalias(tmp, P); } + + out.steal_mem(tmp); } else - if(dim == 1) { - op_fliplr::apply_proxy_noalias(out, P); + if(dim == 0) { op_flipud::apply_proxy_noalias(out, P); } + if(dim == 1) { op_fliplr::apply_proxy_noalias(out, P); } } } } @@ -69,11 +73,13 @@ op_reverse_vec::apply(Mat& out, const Op P(in.m); + typedef typename T1::elem_type eT; - if(is_Mat::stored_type>::value || P.is_alias(out)) + if(is_Mat::value) { - const unwrap::stored_type> U(P.Q); + // allow detection of in-place operation + + const unwrap U(in.m); if((T1::is_xvec) ? bool(U.M.is_rowvec()) : bool(T1::is_row)) { @@ -86,13 +92,33 @@ op_reverse_vec::apply(Mat& out, const Op P(in.m); + + if(P.is_alias(out)) { - op_fliplr::apply_proxy_noalias(out, P); + Mat tmp; + + if((T1::is_xvec) ? bool(P.get_n_rows() == 1) : bool(T1::is_row)) + { + op_fliplr::apply_proxy_noalias(tmp, P); + } + else + { + op_flipud::apply_proxy_noalias(tmp, P); + } + + out.steal_mem(tmp); } else { - op_flipud::apply_proxy_noalias(out, P); + if((T1::is_xvec) ? bool(P.get_n_rows() == 1) : bool(T1::is_row)) + { + op_fliplr::apply_proxy_noalias(out, P); + } + else + { + op_flipud::apply_proxy_noalias(out, P); + } } } } diff --git a/src/armadillo_bits/op_roots_bones.hpp b/src/armadillo_bits/op_roots_bones.hpp index 8cce3069..6007d198 100644 --- a/src/armadillo_bits/op_roots_bones.hpp +++ b/src/armadillo_bits/op_roots_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_roots_meat.hpp b/src/armadillo_bits/op_roots_meat.hpp index b2d8ca27..1e091205 100644 --- a/src/armadillo_bits/op_roots_meat.hpp +++ b/src/armadillo_bits/op_roots_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,7 +30,11 @@ op_roots::apply(Mat< std::complex >& out, const mtOp >& out, const Ba status = op_roots::apply_noalias(out, U.M); } - if(status == false) { out.soft_reset(); } - return status; } @@ -79,7 +83,7 @@ op_roots::apply_noalias(Mat< std::complex::result> >& arma_debug_check( (X.is_vec() == false), "roots(): given object must be a vector" ); - if(X.is_finite() == false) { return false; } + if(X.internal_has_nonfinite()) { return false; } // treat X as a column vector diff --git a/src/armadillo_bits/op_row_as_mat_bones.hpp b/src/armadillo_bits/op_row_as_mat_bones.hpp new file mode 100644 index 00000000..a8430927 --- /dev/null +++ b/src/armadillo_bits/op_row_as_mat_bones.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_row_as_mat +//! @{ + + +class op_row_as_mat + : public traits_op_default + { + public: + + template inline static void apply(Mat& out, const CubeToMatOp& expr); + }; + + + +//! @} diff --git a/src/armadillo_bits/op_row_as_mat_meat.hpp b/src/armadillo_bits/op_row_as_mat_meat.hpp new file mode 100644 index 00000000..751d8d80 --- /dev/null +++ b/src/armadillo_bits/op_row_as_mat_meat.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_row_as_mat +//! @{ + + + +template +inline +void +op_row_as_mat::apply(Mat& out, const CubeToMatOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_cube U(expr.m); + const Cube& A = U.M; + + const uword in_row = expr.aux_uword; + + arma_debug_check_bounds( (in_row >= A.n_rows), "Cube::row_as_mat(): index out of bounds" ); + + const uword A_n_cols = A.n_cols; + const uword A_n_rows = A.n_rows; + const uword A_n_slices = A.n_slices; + + out.set_size(A_n_slices, A_n_cols); + + for(uword s=0; s < A_n_slices; ++s) + { + const eT* A_mem = &(A.at(in_row, 0, s)); + eT* out_mem = &(out.at(s,0)); + + for(uword c=0; c < A_n_cols; ++c) + { + (*out_mem) = (*A_mem); + + A_mem += A_n_rows; + out_mem += A_n_slices; + } + } + } + + + +//! @} diff --git a/src/armadillo_bits/op_shift_bones.hpp b/src/armadillo_bits/op_shift_bones.hpp index 45c152ea..74e49025 100644 --- a/src/armadillo_bits/op_shift_bones.hpp +++ b/src/armadillo_bits/op_shift_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -35,13 +37,7 @@ class op_shift { public: - template inline static void apply(Mat& out, const Op& in); - - template inline static void apply_direct(Mat& out, const Mat& X, const uword len, const uword neg, const uword dim); - template inline static void apply_noalias(Mat& out, const Mat& X, const uword len, const uword neg, const uword dim); - - template inline static void apply_alias(Mat& out, const uword len, const uword neg, const uword dim); }; diff --git a/src/armadillo_bits/op_shift_meat.hpp b/src/armadillo_bits/op_shift_meat.hpp index 4e65dd03..b369b5d3 100644 --- a/src/armadillo_bits/op_shift_meat.hpp +++ b/src/armadillo_bits/op_shift_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,55 +29,26 @@ op_shift_vec::apply(Mat& out, const Op& { arma_extra_debug_sigprint(); - const unwrap U(in.m); - - const uword len = in.aux_uword_a; - const uword neg = in.aux_uword_b; + typedef typename T1::elem_type eT; - const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); - - op_shift::apply_direct(out, U.M, len, neg, dim); - } - - - -template -inline -void -op_shift::apply(Mat& out, const Op& in) - { - arma_extra_debug_sigprint(); - - const unwrap U(in.m); + const quasi_unwrap U(in.m); const uword len = in.aux_uword_a; const uword neg = in.aux_uword_b; - const uword dim = in.aux_uword_c; - - arma_debug_check( (dim > 1), "shift(): parameter 'dim' must be 0 or 1" ); - - op_shift::apply_direct(out, U.M, len, neg, dim); - } - - - -template -inline -void -op_shift::apply_direct(Mat& out, const Mat& X, const uword len, const uword neg, const uword dim) - { - arma_extra_debug_sigprint(); - arma_debug_check( ((dim == 0) && (len >= X.n_rows)), "shift(): shift amount out of bounds" ); - arma_debug_check( ((dim == 1) && (len >= X.n_cols)), "shift(): shift amount out of bounds" ); + const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); - if(&out == &X) + if(U.is_alias(out)) { - op_shift::apply_alias(out, len, neg, dim); + Mat tmp; + + op_shift::apply_noalias(tmp, U.M, len, neg, dim); + + out.steal_mem(tmp); } else { - op_shift::apply_noalias(out, X, len, neg, dim); + op_shift::apply_noalias(out, U.M, len, neg, dim); } } @@ -88,6 +61,9 @@ op_shift::apply_noalias(Mat& out, const Mat& X, const uword len, const u { arma_extra_debug_sigprint(); + arma_debug_check_bounds( ((dim == 0) && (len >= X.n_rows)), "shift(): shift amount out of bounds" ); + arma_debug_check_bounds( ((dim == 1) && (len >= X.n_cols)), "shift(): shift amount out of bounds" ); + out.copy_size(X); const uword X_n_rows = X.n_rows; @@ -202,22 +178,4 @@ op_shift::apply_noalias(Mat& out, const Mat& X, const uword len, const u -template -inline -void -op_shift::apply_alias(Mat& X, const uword len, const uword neg, const uword dim) - { - arma_extra_debug_sigprint(); - - // TODO: replace with better implementation - - Mat tmp; - - op_shift::apply_noalias(tmp, X, len, neg, dim); - - X.steal_mem(tmp); - } - - - //! @} diff --git a/src/armadillo_bits/op_shuffle_bones.hpp b/src/armadillo_bits/op_shuffle_bones.hpp index 2c997bd5..8150d130 100644 --- a/src/armadillo_bits/op_shuffle_bones.hpp +++ b/src/armadillo_bits/op_shuffle_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_shuffle_meat.hpp b/src/armadillo_bits/op_shuffle_meat.hpp index d01398b4..ecfc2f6e 100644 --- a/src/armadillo_bits/op_shuffle_meat.hpp +++ b/src/armadillo_bits/op_shuffle_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -31,10 +33,12 @@ op_shuffle::apply_direct(Mat& out, const Mat& X, const uword dim) const uword N = (dim == 0) ? X.n_rows : X.n_cols; - // see op_sort_index_bones.hpp for the definition of arma_sort_index_packet // and the associated comparison functor - std::vector< arma_sort_index_packet > packet_vec(N); + + typedef arma_sort_index_packet packet; + + std::vector packet_vec(N); for(uword i=0; i& out, const Mat& X, const uword dim) if(dim == 0) { - if(X.n_rows > 1) // i.e. column vector + if(X.n_rows > 1) // ie. column vector { for(uword i=0; i& out, const Mat& X, const uword dim) } else { - if(X.n_cols > 1) // i.e. row vector + if(X.n_cols > 1) // ie. row vector { for(uword i=0; i& out, const Mat& X, const uword dim) if(dim == 0) { - if(X.n_rows > 1) // i.e. column vector + if(X.n_rows > 1) // ie. column vector { for(uword i=0; i& out, const Mat& X, const uword dim) } else { - if(X.n_cols > 1) // i.e. row vector + if(X.n_cols > 1) // ie. row vector { for(uword i=0; i& out, const Mat& X, const uword sort_type, co { arma_extra_debug_sigprint(); - if( (X.n_rows * X.n_cols) <= 1 ) - { - out = X; - return; - } - - arma_debug_check( (sort_type > 1), "sort(): parameter 'sort_type' must be 0 or 1" ); - arma_debug_check( (X.has_nan()), "sort(): detected NaN" ); + if((X.n_rows * X.n_cols) <= 1) { out = X; return; } if(dim == 0) // sort the contents of each column { @@ -177,12 +172,15 @@ op_sort::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; const quasi_unwrap U(in.m); - - const Mat& X = U.M; + const Mat& X = U.M; const uword sort_type = in.aux_uword_a; const uword dim = in.aux_uword_b; + arma_debug_check( (sort_type > 1), "sort(): parameter 'sort_type' must be 0 or 1" ); + arma_debug_check( (dim > 1), "sort(): parameter 'dim' must be 0 or 1" ); + arma_debug_check( (X.internal_has_nan()), "sort(): detected NaN" ); + if(U.is_alias(out)) { Mat tmp; @@ -208,24 +206,34 @@ op_sort_vec::apply(Mat& out, const Op& i typedef typename T1::elem_type eT; - const quasi_unwrap U(in.m); - + const unwrap U(in.m); // not using quasi_unwrap, to ensure there is no aliasing with subviews const Mat& X = U.M; const uword sort_type = in.aux_uword_a; - const uword dim = (T1::is_xvec) ? uword(U.M.is_rowvec() ? 1 : 0) : uword((T1::is_row) ? 1 : 0); - if(U.is_alias(out)) + arma_debug_check( (sort_type > 1), "sort(): parameter 'sort_type' must be 0 or 1" ); + arma_debug_check( (X.internal_has_nan()), "sort(): detected NaN" ); + + out = X; // not checking for aliasing, to allow inplace sorting of vectors + + if(out.n_elem <= 1) { return; } + + eT* out_mem = out.memptr(); + + eT* start_ptr = out_mem; + eT* endp1_ptr = &out_mem[out.n_elem]; + + if(sort_type == 0) { - Mat tmp; - - op_sort::apply_noalias(tmp, X, sort_type, dim); + arma_lt_comparator comparator; - out.steal_mem(tmp); + std::sort(start_ptr, endp1_ptr, comparator); } else { - op_sort::apply_noalias(out, X, sort_type, dim); + arma_gt_comparator comparator; + + std::sort(start_ptr, endp1_ptr, comparator); } } diff --git a/src/armadillo_bits/op_sp_minus_bones.hpp b/src/armadillo_bits/op_sp_minus_bones.hpp index 957eb00f..c0134ed0 100644 --- a/src/armadillo_bits/op_sp_minus_bones.hpp +++ b/src/armadillo_bits/op_sp_minus_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_sp_minus_meat.hpp b/src/armadillo_bits/op_sp_minus_meat.hpp index a9b3b32a..f8151a8a 100644 --- a/src/armadillo_bits/op_sp_minus_meat.hpp +++ b/src/armadillo_bits/op_sp_minus_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -179,8 +181,8 @@ op_sp_minus_post::apply(SpMat& out, const SpToDOp >& out, const mtOp >& out, const typedef typename T1::elem_type in_T; typedef typename std::complex out_T; - const Proxy P(expr.get_ref()); + const quasi_unwrap expr_unwrap(expr.get_ref()); + const Mat& A = expr_unwrap.M; - arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "sqrtmat(): given matrix must be square sized" ); + arma_debug_check( (A.is_square() == false), "sqrtmat(): given matrix must be square sized" ); - if(P.get_n_elem() == 0) + if(A.n_elem == 0) { out.reset(); return true; } + else + if(A.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt( std::complex( A[0] ) ); + return true; + } + + if(A.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat: detected diagonal matrix"); + + const uword N = A.n_rows; + + out.zeros(N,N); // aliasing can't happen as op_sqrtmat is defined as cx_mat = op(mat) + + for(uword i=0; i= in_T(0)) + { + out.at(i,i) = std::sqrt(val); + } + else + { + out.at(i,i) = std::sqrt( out_T(val) ); + } + } + + return true; + } + + const bool try_sympd = arma_config::optimise_sym && sym_helper::guess_sympd(A); + + if(try_sympd) + { + arma_extra_debug_print("op_sqrtmat: attempting sympd optimisation"); + + // if matrix A is sympd, all its eigenvalues are positive + + Col eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, A, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const in_T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i >::from( eigvec * diagmat(eigval) * eigvec.t() ); + + return true; + } + } + + arma_extra_debug_print("op_sqrtmat: sympd optimisation failed"); + + // fallthrough if eigen decomposition failed or an eigenvalue is <= 0 + } - typename Proxy::ea_type Pea = P.get_ea(); Mat U; - Mat S(P.get_n_rows(), P.get_n_cols()); + Mat S(A.n_rows, A.n_cols, arma_nozeros_indicator()); - out_T* Smem = S.memptr(); + const in_T* Amem = A.memptr(); + out_T* Smem = S.memptr(); - const uword N = P.get_n_elem(); + const uword n_elem = A.n_elem; - for(uword i=0; i( Pea[i] ); + Smem[i] = std::complex( Amem[i] ); } const bool schur_ok = auxlib::schur(U,S); @@ -151,7 +225,7 @@ op_sqrtmat_cx::apply(Mat& out, const Op& out, const Base U; Mat S = expr.get_ref(); - if(S.is_empty()) + arma_debug_check( (S.n_rows != S.n_cols), "sqrtmat(): given matrix must be square sized" ); + + if(S.n_elem == 0) { out.reset(); return true; } + else + if(S.n_elem == 1) + { + out.set_size(1,1); + out[0] = std::sqrt(S[0]); + return true; + } - arma_debug_check( (S.n_rows != S.n_cols), "sqrtmat(): given matrix must be square sized" ); + if(S.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_cx: detected diagonal matrix"); + + const uword N = S.n_rows; + + out.zeros(N,N); // aliasing can't happen as S is generated + + for(uword i=0; i eigval; + Mat eigvec; + + const bool eig_status = eig_sym_helper(eigval, eigvec, S, 'd', "sqrtmat()"); + + if(eig_status) + { + // ensure each eigenvalue is > 0 + + const uword N = eigval.n_elem; + const T* eigval_mem = eigval.memptr(); + + bool all_pos = true; + + for(uword i=0; i& out, const Base U(expr.get_ref()); const Mat& X = U.M; arma_debug_check( (X.is_square() == false), "sqrtmat_sympd(): given matrix must be square sized" ); + if((arma_config::debug) && (is_cx::yes) && (sym_helper::check_diag_imag(X) == false)) + { + arma_debug_warn_level(1, "sqrtmat_sympd(): imaginary components on the diagonal are non-zero"); + } + + if(is_op_diagmat::value || X.is_diagmat()) + { + arma_extra_debug_print("op_sqrtmat_sympd: detected diagonal matrix"); + + out = X; + + eT* colmem = out.memptr(); + + const uword N = X.n_rows; + + for(uword i=0; i eigval; Mat eigvec; diff --git a/src/armadillo_bits/op_stddev_bones.hpp b/src/armadillo_bits/op_stddev_bones.hpp index bdc1e090..dad6273c 100644 --- a/src/armadillo_bits/op_stddev_bones.hpp +++ b/src/armadillo_bits/op_stddev_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -17,7 +19,8 @@ //! \addtogroup op_stddev //! @{ -//! Class for finding the standard deviation + + class op_stddev : public traits_op_xvec { @@ -25,6 +28,11 @@ class op_stddev template inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim); }; + + //! @} diff --git a/src/armadillo_bits/op_stddev_meat.hpp b/src/armadillo_bits/op_stddev_meat.hpp index 743c102a..83724bcd 100644 --- a/src/armadillo_bits/op_stddev_meat.hpp +++ b/src/armadillo_bits/op_stddev_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,10 +20,7 @@ //! @{ -//! \brief -//! For each row or for each column, find the standard deviation. -//! The result is stored in a dense matrix that has either one column or one row. -//! The dimension for which the standard deviations are found is set via the stddev() function. + template inline void @@ -29,11 +28,7 @@ op_stddev::apply(Mat& out, const mtOp tmp(in.m, out); - const Mat& X = tmp.M; + typedef typename T1::pod_type out_eT; const uword norm_type = in.aux_uword_a; const uword dim = in.aux_uword_b; @@ -41,12 +36,39 @@ op_stddev::apply(Mat& out, const mtOp 1), "stddev(): parameter 'norm_type' must be 0 or 1" ); arma_debug_check( (dim > 1), "stddev(): parameter 'dim' must be 0 or 1" ); + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_stddev::apply_noalias(tmp, U.M, norm_type, dim); + + out.steal_mem(tmp); + } + else + { + op_stddev::apply_noalias(out, U.M, norm_type, dim); + } + } + + + +template +inline +void +op_stddev::apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; if(dim == 0) { - arma_extra_debug_print("op_stddev::apply(): dim = 0"); + arma_extra_debug_print("op_stddev::apply_noalias(): dim = 0"); out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); @@ -63,7 +85,7 @@ op_stddev::apply(Mat& out, const mtOp 0) ? 1 : 0); diff --git a/src/armadillo_bits/op_strans_bones.hpp b/src/armadillo_bits/op_strans_bones.hpp index 0f22ed0d..42534bfe 100644 --- a/src/armadillo_bits/op_strans_bones.hpp +++ b/src/armadillo_bits/op_strans_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,17 +29,17 @@ class op_strans template struct traits { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; }; template struct pos { - static const uword n2 = (do_flip == false) ? (row + col*2) : (col + row*2); - static const uword n3 = (do_flip == false) ? (row + col*3) : (col + row*3); - static const uword n4 = (do_flip == false) ? (row + col*4) : (col + row*4); + static constexpr uword n2 = (do_flip == false) ? (row + col*2) : (col + row*2); + static constexpr uword n3 = (do_flip == false) ? (row + col*3) : (col + row*3); + static constexpr uword n4 = (do_flip == false) ? (row + col*4) : (col + row*4); }; template @@ -56,50 +58,16 @@ class op_strans arma_hot inline static void apply_mat_inplace(Mat& out); template - arma_hot inline static void apply_mat(Mat& out, const TA& A); - - template - arma_hot inline static void apply_proxy(Mat& out, const T1& X); + inline static void apply_mat(Mat& out, const TA& A); template - arma_hot inline static void apply(Mat& out, const Op& in); - }; - - - -class op_strans2 - { - public: + inline static void apply_proxy(Mat& out, const Proxy& P); template - struct traits - { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; - }; - - template - struct pos - { - static const uword n2 = (do_flip == false) ? (row + col*2) : (col + row*2); - static const uword n3 = (do_flip == false) ? (row + col*3) : (col + row*3); - static const uword n4 = (do_flip == false) ? (row + col*4) : (col + row*4); - }; - - template - arma_cold inline static void apply_noalias_tinysq(Mat& out, const TA& A, const eT val); - - template - arma_hot inline static void apply_noalias(Mat& out, const TA& A, const eT val); - - template - arma_hot inline static void apply(Mat& out, const TA& A, const eT val); + inline static void apply_direct(Mat& out, const T1& X); template - arma_hot inline static void apply_proxy(Mat& out, const T1& X, const typename T1::elem_type val); - - // NOTE: there is no direct handling of Op, as op_strans2::apply_proxy() is currently only called by op_htrans2 for non-complex numbers + inline static void apply(Mat& out, const Op& in); }; diff --git a/src/armadillo_bits/op_strans_meat.hpp b/src/armadillo_bits/op_strans_meat.hpp index 8846975c..ed02d3be 100644 --- a/src/armadillo_bits/op_strans_meat.hpp +++ b/src/armadillo_bits/op_strans_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,6 @@ //! for tiny square matrices (size <= 4x4) template -arma_cold inline void op_strans::apply_mat_noalias_tinysq(Mat& out, const TA& A) @@ -96,7 +97,6 @@ op_strans::apply_mat_noalias_tinysq(Mat& out, const TA& A) template -arma_hot inline void op_strans::block_worker(eT* Y, const eT* X, const uword X_n_rows, const uword Y_n_rows, const uword n_rows, const uword n_cols) @@ -117,7 +117,6 @@ op_strans::block_worker(eT* Y, const eT* X, const uword X_n_rows, const uword Y_ template -arma_hot inline void op_strans::apply_mat_noalias_large(Mat& out, const Mat& A) @@ -174,7 +173,6 @@ op_strans::apply_mat_noalias_large(Mat& out, const Mat& A) //! Immediate transpose of a dense matrix template -arma_hot inline void op_strans::apply_mat_noalias(Mat& out, const TA& A) @@ -231,7 +229,6 @@ op_strans::apply_mat_noalias(Mat& out, const TA& A) template -arma_hot inline void op_strans::apply_mat_inplace(Mat& out) @@ -282,7 +279,6 @@ op_strans::apply_mat_inplace(Mat& out) template -arma_hot inline void op_strans::apply_mat(Mat& out, const TA& A) @@ -302,373 +298,68 @@ op_strans::apply_mat(Mat& out, const TA& A) template -arma_hot inline void -op_strans::apply_proxy(Mat& out, const T1& X) +op_strans::apply_proxy(Mat& out, const Proxy& P) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(X); + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); - // allow detection of in-place transpose - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) + if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) { - const unwrap::stored_type> tmp(P.Q); + out.set_size(n_cols, n_rows); - op_strans::apply_mat(out, tmp.M); - } - else - { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + eT* out_mem = out.memptr(); - const bool is_alias = P.is_alias(out); + const uword n_elem = P.get_n_elem(); - if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) - { - if(is_alias == false) - { - out.set_size(n_cols, n_rows); - - eT* out_mem = out.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - uword i,j; - for(i=0, j=1; j < n_elem; i+=2, j+=2) - { - const eT tmp_i = Pea[i]; - const eT tmp_j = Pea[j]; - - out_mem[i] = tmp_i; - out_mem[j] = tmp_j; - } - - if(i < n_elem) - { - out_mem[i] = Pea[i]; - } - } - else // aliasing - { - Mat out2(n_cols, n_rows); - - eT* out_mem = out2.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - uword i,j; - for(i=0, j=1; j < n_elem; i+=2, j+=2) - { - const eT tmp_i = Pea[i]; - const eT tmp_j = Pea[j]; - - out_mem[i] = tmp_i; - out_mem[j] = tmp_j; - } - - if(i < n_elem) - { - out_mem[i] = Pea[i]; - } - - out.steal_mem(out2); - } - } - else // general matrix transpose - { - if(is_alias == false) - { - out.set_size(n_cols, n_rows); - - eT* outptr = out.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - uword j; - for(j=1; j < n_cols; j+=2) - { - const uword i = j-1; - - const eT tmp_i = P.at(k,i); - const eT tmp_j = P.at(k,j); - - (*outptr) = tmp_i; outptr++; - (*outptr) = tmp_j; outptr++; - } - - const uword i = j-1; - - if(i < n_cols) - { - (*outptr) = P.at(k,i); outptr++; - } - } - } - else // aliasing - { - Mat out2(n_cols, n_rows); - - eT* out2ptr = out2.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - uword j; - for(j=1; j < n_cols; j+=2) - { - const uword i = j-1; - - const eT tmp_i = P.at(k,i); - const eT tmp_j = P.at(k,j); - - (*out2ptr) = tmp_i; out2ptr++; - (*out2ptr) = tmp_j; out2ptr++; - } - - const uword i = j-1; - - if(i < n_cols) - { - (*out2ptr) = P.at(k,i); out2ptr++; - } - } - - out.steal_mem(out2); - } - } - } - } - - - -template -arma_hot -inline -void -op_strans::apply(Mat& out, const Op& in) - { - arma_extra_debug_sigprint(); - - op_strans::apply_proxy(out, in.m); - } - - - -// -// op_strans2 - - - -//! for tiny square matrices (size <= 4x4) -template -arma_cold -inline -void -op_strans2::apply_noalias_tinysq(Mat& out, const TA& A, const eT val) - { - const eT* Am = A.memptr(); - eT* outm = out.memptr(); - - switch(A.n_rows) - { - case 1: - { - outm[0] = val * Am[0]; - } - break; - - case 2: - { - outm[pos::n2] = val * Am[pos::n2]; - outm[pos::n2] = val * Am[pos::n2]; - - outm[pos::n2] = val * Am[pos::n2]; - outm[pos::n2] = val * Am[pos::n2]; - } - break; + typename Proxy::ea_type Pea = P.get_ea(); - case 3: + uword i,j; + for(i=0, j=1; j < n_elem; i+=2, j+=2) { - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; + const eT tmp_i = Pea[i]; + const eT tmp_j = Pea[j]; - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; - - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; - outm[pos::n3] = val * Am[pos::n3]; + out_mem[i] = tmp_i; + out_mem[j] = tmp_j; } - break; - case 4: + if(i < n_elem) { - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; - outm[pos::n4] = val * Am[pos::n4]; + out_mem[i] = Pea[i]; } - break; - - default: - ; } - - } - - - -template -arma_hot -inline -void -op_strans2::apply_noalias(Mat& out, const TA& A, const eT val) - { - arma_extra_debug_sigprint(); - - const uword A_n_cols = A.n_cols; - const uword A_n_rows = A.n_rows; - - out.set_size(A_n_cols, A_n_rows); - - if( (TA::is_col) || (TA::is_row) || (A_n_cols == 1) || (A_n_rows == 1) ) + else // general matrix transpose { - const uword N = A.n_elem; - - const eT* A_mem = A.memptr(); - eT* out_mem = out.memptr(); + out.set_size(n_cols, n_rows); - uword i,j; - for(i=0, j=1; j < N; i+=2, j+=2) - { - const eT tmp_i = A_mem[i]; - const eT tmp_j = A_mem[j]; - - out_mem[i] = val * tmp_i; - out_mem[j] = val * tmp_j; - } + eT* outptr = out.memptr(); - if(i < N) - { - out_mem[i] = val * A_mem[i]; - } - } - else - { - if( (A_n_rows <= 4) && (A_n_rows == A_n_cols) ) - { - op_strans2::apply_noalias_tinysq(out, A, val); - } - else - if( (A_n_rows >= 512) && (A_n_cols >= 512) ) - { - op_strans::apply_mat_noalias_large(out, A); - arrayops::inplace_mul( out.memptr(), val, out.n_elem ); - } - else + for(uword k=0; k < n_rows; ++k) { - eT* outptr = out.memptr(); - - for(uword k=0; k < A_n_rows; ++k) + uword j; + for(j=1; j < n_cols; j+=2) { - const eT* Aptr = &(A.at(k,0)); + const uword i = j-1; - uword j; - for(j=1; j < A_n_cols; j+=2) - { - const eT tmp_i = (*Aptr); Aptr += A_n_rows; - const eT tmp_j = (*Aptr); Aptr += A_n_rows; - - (*outptr) = val * tmp_i; outptr++; - (*outptr) = val * tmp_j; outptr++; - } + const eT tmp_i = P.at(k,i); + const eT tmp_j = P.at(k,j); - if((j-1) < A_n_cols) - { - (*outptr) = val * (*Aptr); outptr++;; - } + (*outptr) = tmp_i; outptr++; + (*outptr) = tmp_j; outptr++; } - } - } - } - - - -template -arma_hot -inline -void -op_strans2::apply(Mat& out, const TA& A, const eT val) - { - arma_extra_debug_sigprint(); - - if(&out != &A) - { - op_strans2::apply_noalias(out, A, val); - } - else - { - const uword n_rows = out.n_rows; - const uword n_cols = out.n_cols; - - if(n_rows == n_cols) - { - arma_extra_debug_print("op_strans2::apply(): doing in-place transpose of a square matrix"); - const uword N = n_rows; + const uword i = j-1; - // TODO: do multiplication while swapping - - for(uword k=0; k < N; ++k) + if(i < n_cols) { - eT* colptr = out.colptr(k); - - uword i,j; - - for(i=(k+1), j=(k+2); j < N; i+=2, j+=2) - { - std::swap(out.at(k,i), colptr[i]); - std::swap(out.at(k,j), colptr[j]); - } - - if(i < N) - { - std::swap(out.at(k,i), colptr[i]); - } + (*outptr) = P.at(k,i); outptr++; } - - arrayops::inplace_mul( out.memptr(), val, out.n_elem ); - } - else - { - Mat tmp; - op_strans2::apply_noalias(tmp, A, val); - - out.steal_mem(tmp); } } } @@ -676,145 +367,57 @@ op_strans2::apply(Mat& out, const TA& A, const eT val) template -arma_hot inline void -op_strans2::apply_proxy(Mat& out, const T1& X, const typename T1::elem_type val) +op_strans::apply_direct(Mat& out, const T1& X) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; - const Proxy P(X); - // allow detection of in-place transpose - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) { - const unwrap::stored_type> tmp(P.Q); + const unwrap U(X); - op_strans2::apply(out, tmp.M, val); + op_strans::apply_mat(out, U.M); } else { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); + const Proxy P(X); const bool is_alias = P.is_alias(out); - if( (resolves_to_vector::yes) && (Proxy::use_at == false) ) + if(is_Mat::stored_type>::value) { - if(is_alias == false) + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) { - out.set_size(n_cols, n_rows); - - eT* out_mem = out.memptr(); - - const uword n_elem = P.get_n_elem(); + Mat tmp; - typename Proxy::ea_type Pea = P.get_ea(); + op_strans::apply_mat_noalias(tmp, U.M); - uword i,j; - for(i=0, j=1; j < n_elem; i+=2, j+=2) - { - const eT tmp_i = Pea[i]; - const eT tmp_j = Pea[j]; - - out_mem[i] = val * tmp_i; - out_mem[j] = val * tmp_j; - } - - if(i < n_elem) - { - out_mem[i] = val * Pea[i]; - } + out.steal_mem(tmp); } - else // aliasing + else { - Mat out2(n_cols, n_rows); - - eT* out_mem = out2.memptr(); - - const uword n_elem = P.get_n_elem(); - - typename Proxy::ea_type Pea = P.get_ea(); - - uword i,j; - for(i=0, j=1; j < n_elem; i+=2, j+=2) - { - const eT tmp_i = Pea[i]; - const eT tmp_j = Pea[j]; - - out_mem[i] = val * tmp_i; - out_mem[j] = val * tmp_j; - } - - if(i < n_elem) - { - out_mem[i] = val * Pea[i]; - } - - out.steal_mem(out2); + op_strans::apply_mat_noalias(out, U.M); } } - else // general matrix transpose + else { - if(is_alias == false) + if(is_alias) { - out.set_size(n_cols, n_rows); + Mat tmp; - eT* outptr = out.memptr(); + op_strans::apply_proxy(tmp, P); - for(uword k=0; k < n_rows; ++k) - { - uword j; - for(j=1; j < n_cols; j+=2) - { - const uword i = j-1; - - const eT tmp_i = P.at(k,i); - const eT tmp_j = P.at(k,j); - - (*outptr) = val * tmp_i; outptr++; - (*outptr) = val * tmp_j; outptr++; - } - - const uword i = j-1; - - if(i < n_cols) - { - (*outptr) = val * P.at(k,i); outptr++; - } - } + out.steal_mem(tmp); } - else // aliasing + else { - Mat out2(n_cols, n_rows); - - eT* out2ptr = out2.memptr(); - - for(uword k=0; k < n_rows; ++k) - { - uword j; - for(j=1; j < n_cols; j+=2) - { - const uword i = j-1; - - const eT tmp_i = P.at(k,i); - const eT tmp_j = P.at(k,j); - - (*out2ptr) = val * tmp_i; out2ptr++; - (*out2ptr) = val * tmp_j; out2ptr++; - } - - const uword i = j-1; - - if(i < n_cols) - { - (*out2ptr) = val * P.at(k,i); out2ptr++; - } - } - - out.steal_mem(out2); + op_strans::apply_proxy(out, P); } } } @@ -822,6 +425,24 @@ op_strans2::apply_proxy(Mat& out, const T1& X, const typ +template +inline +void +op_strans::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + op_strans::apply_direct(out, in.m); + } + + + +// +// +// + + + template inline void diff --git a/src/armadillo_bits/op_sum_bones.hpp b/src/armadillo_bits/op_sum_bones.hpp index 93fe1255..a72c7869 100644 --- a/src/armadillo_bits/op_sum_bones.hpp +++ b/src/armadillo_bits/op_sum_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -37,9 +39,6 @@ class op_sum template arma_hot inline static void apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim); - template - arma_hot inline static void apply_noalias_proxy_mp(Mat& out, const Proxy& P, const uword dim); - // cubes @@ -54,9 +53,6 @@ class op_sum template arma_hot inline static void apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim); - - template - arma_hot inline static void apply_noalias_proxy_mp(Cube& out, const ProxyCube& P, const uword dim); }; diff --git a/src/armadillo_bits/op_sum_meat.hpp b/src/armadillo_bits/op_sum_meat.hpp index f9f564fb..f759daaf 100644 --- a/src/armadillo_bits/op_sum_meat.hpp +++ b/src/armadillo_bits/op_sum_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void op_sum::apply(Mat& out, const Op& in) @@ -51,14 +52,13 @@ op_sum::apply(Mat& out, const Op& in) template -arma_hot inline void op_sum::apply_noalias(Mat& out, const Proxy& P, const uword dim) { arma_extra_debug_sigprint(); - if(is_Mat::stored_type>::value) + if(is_Mat::stored_type>::value || (arma_config::openmp && Proxy::use_mp)) { op_sum::apply_noalias_unwrap(out, P, dim); } @@ -71,7 +71,6 @@ op_sum::apply_noalias(Mat& out, const Proxy& P, cons template -arma_hot inline void op_sum::apply_noalias_unwrap(Mat& out, const Proxy& P, const uword dim) @@ -89,26 +88,36 @@ op_sum::apply_noalias_unwrap(Mat& out, const Proxy& const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; + const uword out_n_rows = (dim == 0) ? uword(1) : X_n_rows; + const uword out_n_cols = (dim == 0) ? X_n_cols : uword(1); + + out.set_size(out_n_rows, out_n_cols); + + if(X.n_elem == 0) { out.zeros(); return; } + + const eT* X_colptr = X.memptr(); + eT* out_mem = out.memptr(); + if(dim == 0) { - out.set_size(1, X_n_cols); - - eT* out_mem = out.memptr(); - for(uword col=0; col < X_n_cols; ++col) { - out_mem[col] = arrayops::accumulate( X.colptr(col), X_n_rows ); + out_mem[col] = arrayops::accumulate( X_colptr, X_n_rows ); + + X_colptr += X_n_rows; } } else { - out.zeros(X_n_rows, 1); + arrayops::copy(out_mem, X_colptr, X_n_rows); - eT* out_mem = out.memptr(); + X_colptr += X_n_rows; - for(uword col=0; col < X_n_cols; ++col) + for(uword col=1; col < X_n_cols; ++col) { - arrayops::inplace_plus( out_mem, X.colptr(col), X_n_rows ); + arrayops::inplace_plus( out_mem, X_colptr, X_n_rows ); + + X_colptr += X_n_rows; } } } @@ -116,7 +125,6 @@ op_sum::apply_noalias_unwrap(Mat& out, const Proxy& template -arma_hot inline void op_sum::apply_noalias_proxy(Mat& out, const Proxy& P, const uword dim) @@ -125,82 +133,64 @@ op_sum::apply_noalias_proxy(Mat& out, const Proxy& P typedef typename T1::elem_type eT; - if( arma_config::openmp && Proxy::use_mp && mp_gate::eval(P.get_n_elem()) ) - { - op_sum::apply_noalias_proxy_mp(out, P, dim); - - return; - } - const uword P_n_rows = P.get_n_rows(); const uword P_n_cols = P.get_n_cols(); - if(dim == 0) + const uword out_n_rows = (dim == 0) ? uword(1) : P_n_rows; + const uword out_n_cols = (dim == 0) ? P_n_cols : uword(1); + + out.set_size(out_n_rows, out_n_cols); + + if(P.get_n_elem() == 0) { out.zeros(); return; } + + eT* out_mem = out.memptr(); + + if(Proxy::use_at == false) { - out.set_size(1, P_n_cols); - - eT* out_mem = out.memptr(); - - for(uword col=0; col < P_n_cols; ++col) + if(dim == 0) { - eT val1 = eT(0); - eT val2 = eT(0); + uword count = 0; - uword i,j; - for(i=0, j=1; j < P_n_rows; i+=2, j+=2) + for(uword col=0; col < P_n_cols; ++col) { - val1 += P.at(i,col); - val2 += P.at(j,col); + eT val1 = eT(0); + eT val2 = eT(0); + + uword j; + for(j=1; j < P_n_rows; j+=2) + { + val1 += P[count]; ++count; + val2 += P[count]; ++count; + } + + if((j-1) < P_n_rows) + { + val1 += P[count]; ++count; + } + + out_mem[col] = (val1 + val2); } + } + else + { + uword count = 0; - if(i < P_n_rows) + for(uword row=0; row < P_n_rows; ++row) { - val1 += P.at(i,col); + out_mem[row] = P[count]; ++count; } - out_mem[col] = (val1 + val2); + for(uword col=1; col < P_n_cols; ++col) + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] += P[count]; ++count; + } } } else { - out.zeros(P_n_rows, 1); - - eT* out_mem = out.memptr(); - - for(uword col=0; col < P_n_cols; ++col) - for(uword row=0; row < P_n_rows; ++row) - { - out_mem[row] += P.at(row,col); - } - } - } - - - -template -arma_hot -inline -void -op_sum::apply_noalias_proxy_mp(Mat& out, const Proxy& P, const uword dim) - { - arma_extra_debug_sigprint(); - - #if defined(ARMA_USE_OPENMP) - { - typedef typename T1::elem_type eT; - - const uword P_n_rows = P.get_n_rows(); - const uword P_n_cols = P.get_n_cols(); - - const int n_threads = mp_thread_limit::get(); - if(dim == 0) { - out.set_size(1, P_n_cols); - - eT* out_mem = out.memptr(); - - #pragma omp parallel for schedule(static) num_threads(n_threads) for(uword col=0; col < P_n_cols; ++col) { eT val1 = eT(0); @@ -223,30 +213,18 @@ op_sum::apply_noalias_proxy_mp(Mat& out, const Proxy } else { - out.set_size(P_n_rows, 1); - - eT* out_mem = out.memptr(); + for(uword row=0; row < P_n_rows; ++row) + { + out_mem[row] = P.at(row,0); + } - #pragma omp parallel for schedule(static) num_threads(n_threads) + for(uword col=1; col < P_n_cols; ++col) for(uword row=0; row < P_n_rows; ++row) { - eT acc = eT(0); - for(uword col=0; col < P_n_cols; ++col) - { - acc += P.at(row,col); - } - - out_mem[row] = acc; + out_mem[row] += P.at(row,col); } } } - #else - { - arma_ignore(out); - arma_ignore(P); - arma_ignore(dim); - } - #endif } @@ -257,7 +235,6 @@ op_sum::apply_noalias_proxy_mp(Mat& out, const Proxy template -arma_hot inline void op_sum::apply(Cube& out, const OpCube& in) @@ -288,14 +265,13 @@ op_sum::apply(Cube& out, const OpCube& in) template -arma_hot inline void op_sum::apply_noalias(Cube& out, const ProxyCube& P, const uword dim) { arma_extra_debug_sigprint(); - if(is_Cube::stored_type>::value) + if(is_Cube::stored_type>::value || (arma_config::openmp && ProxyCube::use_mp)) { op_sum::apply_noalias_unwrap(out, P, dim); } @@ -308,7 +284,6 @@ op_sum::apply_noalias(Cube& out, const ProxyCube& P, template -arma_hot inline void op_sum::apply_noalias_unwrap(Cube& out, const ProxyCube& P, const uword dim) @@ -373,7 +348,6 @@ op_sum::apply_noalias_unwrap(Cube& out, const ProxyCube< template -arma_hot inline void op_sum::apply_noalias_proxy(Cube& out, const ProxyCube& P, const uword dim) @@ -382,13 +356,6 @@ op_sum::apply_noalias_proxy(Cube& out, const ProxyCube::use_mp && mp_gate::eval(P.get_n_elem()) ) - { - op_sum::apply_noalias_proxy_mp(out, P, dim); - - return; - } - const uword P_n_rows = P.get_n_rows(); const uword P_n_cols = P.get_n_cols(); const uword P_n_slices = P.get_n_slices(); @@ -460,121 +427,4 @@ op_sum::apply_noalias_proxy(Cube& out, const ProxyCube -arma_hot -inline -void -op_sum::apply_noalias_proxy_mp(Cube& out, const ProxyCube& P, const uword dim) - { - arma_extra_debug_sigprint(); - - #if defined(ARMA_USE_OPENMP) - { - typedef typename T1::elem_type eT; - - const uword P_n_rows = P.get_n_rows(); - const uword P_n_cols = P.get_n_cols(); - const uword P_n_slices = P.get_n_slices(); - - const int n_threads = mp_thread_limit::get(); - - if(dim == 0) - { - out.set_size(1, P_n_cols, P_n_slices); - - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword slice=0; slice < P_n_slices; ++slice) - { - eT* out_mem = out.slice_memptr(slice); - - for(uword col=0; col < P_n_cols; ++col) - { - eT val1 = eT(0); - eT val2 = eT(0); - - uword i,j; - for(i=0, j=1; j < P_n_rows; i+=2, j+=2) - { - val1 += P.at(i,col,slice); - val2 += P.at(j,col,slice); - } - - if(i < P_n_rows) - { - val1 += P.at(i,col,slice); - } - - out_mem[col] = (val1 + val2); - } - } - } - else - if(dim == 1) - { - out.zeros(P_n_rows, 1, P_n_slices); - - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword slice=0; slice < P_n_slices; ++slice) - { - eT* out_mem = out.slice_memptr(slice); - - for(uword col=0; col < P_n_cols; ++col) - for(uword row=0; row < P_n_rows; ++row) - { - out_mem[row] += P.at(row,col,slice); - } - } - } - else - if(dim == 2) - { - out.zeros(P_n_rows, P_n_cols, 1); - - if(P_n_cols >= P_n_rows) - { - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword col=0; col < P_n_cols; ++col) - { - for(uword row=0; row < P_n_rows; ++row) - { - eT acc = eT(0); - for(uword slice=0; slice < P_n_slices; ++slice) - { - acc += P.at(row,col,slice); - } - - out.at(row,col,0) = acc; - } - } - } - else - { - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword row=0; row < P_n_rows; ++row) - { - for(uword col=0; col < P_n_cols; ++col) - { - eT acc = eT(0); - for(uword slice=0; slice < P_n_slices; ++slice) - { - acc += P.at(row,col,slice); - } - - out.at(row,col,0) = acc; - } - } - } - } - } - #else - { - arma_ignore(out); - arma_ignore(P); - arma_ignore(dim); - } - #endif - } - - - //! @} diff --git a/src/armadillo_bits/op_symmat_bones.hpp b/src/armadillo_bits/op_symmat_bones.hpp index e17c1ec3..1e07f0e2 100644 --- a/src/armadillo_bits/op_symmat_bones.hpp +++ b/src/armadillo_bits/op_symmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,24 +21,46 @@ -class op_symmat +class op_symmatu + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_symmatl + : public traits_op_default + { + public: + + template + inline static void apply(Mat& out, const Op& in); + }; + + + +class op_symmatu_cx : public traits_op_default { public: template - inline static void apply(Mat& out, const Op& in); + inline static void apply(Mat& out, const Op& in); }; -class op_symmat_cx +class op_symmatl_cx : public traits_op_default { public: template - inline static void apply(Mat& out, const Op& in); + inline static void apply(Mat& out, const Op& in); }; diff --git a/src/armadillo_bits/op_symmat_meat.hpp b/src/armadillo_bits/op_symmat_meat.hpp index b706b3c7..52731bd8 100644 --- a/src/armadillo_bits/op_symmat_meat.hpp +++ b/src/armadillo_bits/op_symmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,7 +24,7 @@ template inline void -op_symmat::apply(Mat& out, const Op& in) +op_symmatu::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); @@ -31,78 +33,104 @@ op_symmat::apply(Mat& out, const Op& in) const unwrap tmp(in.m); const Mat& A = tmp.M; - arma_debug_check( (A.is_square() == false), "symmatu()/symmatl(): given matrix must be square sized" ); + arma_debug_check( (A.is_square() == false), "symmatu(): given matrix must be square sized" ); - const uword N = A.n_rows; - const bool upper = (in.aux_uword_a == 0); + const uword N = A.n_rows; if(&out != &A) { out.copy_size(A); - if(upper) + // upper triangular: copy the diagonal and the elements above the diagonal + + for(uword i=0; i +inline +void +op_symmatl::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; - if(upper) + arma_debug_check( (A.is_square() == false), "symmatl(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + if(&out != &A) { - // reflect elements across the diagonal from upper triangle to lower triangle + out.copy_size(A); - for(uword col=1; col < N; ++col) + // lower triangular: copy the diagonal and the elements below the diagonal + + for(uword i=0; i inline void -op_symmat_cx::apply(Mat& out, const Op& in) +op_symmatu_cx::apply(Mat& out, const Op& in) { arma_extra_debug_sigprint(); @@ -111,103 +139,135 @@ op_symmat_cx::apply(Mat& out, const Op& const unwrap tmp(in.m); const Mat& A = tmp.M; - arma_debug_check( (A.is_square() == false), "symmatu()/symmatl(): given matrix must be square sized" ); + arma_debug_check( (A.is_square() == false), "symmatu(): given matrix must be square sized" ); const uword N = A.n_rows; - const bool upper = (in.aux_uword_a == 0); const bool do_conj = (in.aux_uword_b == 1); if(&out != &A) { out.copy_size(A); - if(upper) - { - // upper triangular: copy the diagonal and the elements above the diagonal - - for(uword i=0; i +inline +void +op_symmatl_cx::apply(Mat& out, const Op& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap tmp(in.m); + const Mat& A = tmp.M; + + arma_debug_check( (A.is_square() == false), "symmatl(): given matrix must be square sized" ); + + const uword N = A.n_rows; + + const bool do_conj = (in.aux_uword_b == 1); + + if(&out != &A) { - if(upper) + out.copy_size(A); + + // lower triangular: copy the diagonal and the elements below the diagonal + + for(uword i=0; i& out, const Op& i const unwrap_check tmp(in.m, out); const Mat& X = tmp.M; - arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "toeplitz(): given object is not a vector" ); + arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "toeplitz(): given object must be a vector" ); const uword N = X.n_elem; const eT* X_mem = X.memptr(); @@ -66,14 +68,14 @@ op_toeplitz_c::apply(Mat& out, const Op tmp(in.m, out); const Mat& X = tmp.M; - arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "circ_toeplitz(): given object is not a vector" ); + arma_debug_check( ((X.is_vec() == false) && (X.is_empty() == false)), "circ_toeplitz(): given object must be a vector" ); const uword N = X.n_elem; const eT* X_mem = X.memptr(); out.set_size(N,N); - if(X.is_rowvec() == true) + if(X.is_rowvec()) { for(uword row=0; row < N; ++row) { diff --git a/src/armadillo_bits/op_trimat_bones.hpp b/src/armadillo_bits/op_trimat_bones.hpp index 4d564fe5..f500cbd4 100644 --- a/src/armadillo_bits/op_trimat_bones.hpp +++ b/src/armadillo_bits/op_trimat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,6 +21,8 @@ +// NOTE: don't split op_trimat into seperate op_trimatu and op_trimatl classes, +// NOTE: as several instances elsewhere rely on trimatu() and trimatl() producing the same type class op_trimat : public traits_op_default { @@ -32,16 +36,11 @@ class op_trimat template inline static void apply(Mat& out, const Op& in); - template - inline static void apply(Mat& out, const Op, op_trimat>& in); - - // - template - inline static void apply_htrans(Mat& out, const Mat& A, const bool upper, const typename arma_not_cx::result* junk = 0); + inline static void apply_unwrap(Mat& out, const Mat& A, const bool upper); - template - inline static void apply_htrans(Mat& out, const Mat& A, const bool upper, const typename arma_cx_only::result* junk = 0); + template + inline static void apply_proxy(Mat& out, const Proxy& P, const bool upper); }; diff --git a/src/armadillo_bits/op_trimat_meat.hpp b/src/armadillo_bits/op_trimat_meat.hpp index fa63ea16..79225153 100644 --- a/src/armadillo_bits/op_trimat_meat.hpp +++ b/src/armadillo_bits/op_trimat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -36,7 +38,7 @@ op_trimat::fill_zeros(Mat& out, const bool upper) { eT* data = out.colptr(i); - arrayops::inplace_set( &data[i+1], eT(0), (N-(i+1)) ); + arrayops::fill_zeros( &data[i+1], (N-(i+1)) ); } } else @@ -47,7 +49,7 @@ op_trimat::fill_zeros(Mat& out, const bool upper) { eT* data = out.colptr(i); - arrayops::inplace_set( data, eT(0), i ); + arrayops::fill_zeros( data, i ); } } } @@ -63,62 +65,54 @@ op_trimat::apply(Mat& out, const Op& in) typedef typename T1::elem_type eT; - const unwrap tmp(in.m); - const Mat& A = tmp.M; - - arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square sized" ); - - const uword N = A.n_rows; - const bool upper = (in.aux_uword_a == 0); + const bool upper = (in.aux_uword_a == 0); - if(&out != &A) + // allow detection of in-place operation + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) { - out.copy_size(A); + const unwrap U(in.m); - if(upper) + op_trimat::apply_unwrap(out, U.M, upper); + } + else + { + const Proxy P(in.m); + + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) { - // upper triangular: copy the diagonal and the elements above the diagonal - for(uword i=0; i::stored_type> U(P.Q); + + if(is_alias) { - const eT* A_data = A.colptr(i); - eT* out_data = out.colptr(i); + Mat tmp; - arrayops::copy( out_data, A_data, i+1 ); + op_trimat::apply_unwrap(tmp, U.M, upper); + + out.steal_mem(tmp); + } + else + { + op_trimat::apply_unwrap(out, U.M, upper); } } else { - // lower triangular: copy the diagonal and the elements below the diagonal - for(uword i=0; i tmp; - arrayops::copy( &out_data[i], &A_data[i], N-i ); + op_trimat::apply_proxy(tmp, P, upper); + + out.steal_mem(tmp); + } + else + { + op_trimat::apply_proxy(out, P, upper); } } } - - op_trimat::fill_zeros(out, upper); - } - - - -template -inline -void -op_trimat::apply(Mat& out, const Op, op_trimat>& in) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT; - - const unwrap tmp(in.m.m); - const Mat& A = tmp.M; - - const bool upper = (in.aux_uword_a == 0); - - op_trimat::apply_htrans(out, A, upper); } @@ -126,62 +120,38 @@ op_trimat::apply(Mat& out, const Op, o template inline void -op_trimat::apply_htrans - ( - Mat& out, - const Mat& A, - const bool upper, - const typename arma_not_cx::result* junk - ) +op_trimat::apply_unwrap(Mat& out, const Mat& A, const bool upper) { arma_extra_debug_sigprint(); - arma_ignore(junk); - - // This specialisation is for trimatl(trans(X)) = trans(trimatu(X)) and also - // trimatu(trans(X)) = trans(trimatl(X)). We want to avoid the creation of an - // extra temporary. - - // It doesn't matter if the input and output matrices are the same; we will - // pull data from the upper or lower triangular to the lower or upper - // triangular (respectively) and then set the rest to 0, so overwriting issues - // aren't present. arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square sized" ); - const uword N = A.n_rows; - if(&out != &A) { out.copy_size(A); - } - - // We can't really get away with any array copy operations here, - // unfortunately... - - if(upper) - { - // Upper triangular: but since we're transposing, we're taking the lower - // triangular and putting it in the upper half. - for(uword row = 0; row < N; ++row) + + const uword N = A.n_rows; + + if(upper) { - eT* out_colptr = out.colptr(row); - - for(uword col = 0; col <= row; ++col) + // upper triangular: copy the diagonal and the elements above the diagonal + for(uword i=0; i +template inline void -op_trimat::apply_htrans - ( - Mat& out, - const Mat& A, - const bool upper, - const typename arma_cx_only::result* junk - ) +op_trimat::apply_proxy(Mat& out, const Proxy& P, const bool upper) { arma_extra_debug_sigprint(); - arma_ignore(junk); - arma_debug_check( (A.is_square() == false), "trimatu()/trimatl(): given matrix must be square sized" ); + arma_debug_check( (P.get_n_rows() != P.get_n_cols()), "trimatu()/trimatl(): given matrix must be square sized" ); - const uword N = A.n_rows; + const uword N = P.get_n_rows(); - if(&out != &A) - { - out.copy_size(A); - } + out.set_size(N,N); if(upper) { - // Upper triangular: but since we're transposing, we're taking the lower - // triangular and putting it in the upper half. - for(uword row = 0; row < N; ++row) + for(uword j=0; j < N; ++j) + for(uword i=0; i < (j+1); ++i) { - eT* out_colptr = out.colptr(row); - - for(uword col = 0; col <= row; ++col) - { - //out.at(col, row) = std::conj( A.at(row, col) ); - out_colptr[col] = std::conj( A.at(row, col) ); - } + out.at(i,j) = P.at(i,j); } } else { - // Lower triangular: but since we're transposing, we're taking the upper - // triangular and putting it in the lower half. - for(uword row = 0; row < N; ++row) + for(uword j=0; j& out, const Op 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" ); + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" ); if(&out != &A) { @@ -362,7 +311,7 @@ op_trimatl_ext::apply(Mat& out, const Op 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" ); + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" ); if(&out != &A) { diff --git a/src/armadillo_bits/op_unique_bones.hpp b/src/armadillo_bits/op_unique_bones.hpp index 05c05537..7e7bb692 100644 --- a/src/armadillo_bits/op_unique_bones.hpp +++ b/src/armadillo_bits/op_unique_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_unique_meat.hpp b/src/armadillo_bits/op_unique_meat.hpp index baa37eac..1605ea7c 100644 --- a/src/armadillo_bits/op_unique_meat.hpp +++ b/src/armadillo_bits/op_unique_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,7 +58,7 @@ op_unique::apply_helper(Mat& out, const Proxy& P, co return true; } - Mat X(n_elem,1); + Mat X(n_elem, 1, arma_nozeros_indicator()); eT* X_mem = X.memptr(); diff --git a/src/armadillo_bits/op_var_bones.hpp b/src/armadillo_bits/op_var_bones.hpp index 16418a9f..ee13bd48 100644 --- a/src/armadillo_bits/op_var_bones.hpp +++ b/src/armadillo_bits/op_var_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,6 @@ -//! Class for finding variance values of a matrix class op_var : public traits_op_xvec { @@ -28,12 +29,14 @@ class op_var template inline static void apply(Mat& out, const mtOp& in); + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim); // template inline static typename get_pod_type::result var_vec(const subview_col& X, const uword norm_type = 0); - + template inline static typename get_pod_type::result var_vec(const subview_row& X, const uword norm_type = 0); diff --git a/src/armadillo_bits/op_var_meat.hpp b/src/armadillo_bits/op_var_meat.hpp index 17de4dc9..49a252d4 100644 --- a/src/armadillo_bits/op_var_meat.hpp +++ b/src/armadillo_bits/op_var_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -18,10 +20,7 @@ //! @{ -//! \brief -//! For each row or for each column, find the variance. -//! The result is stored in a dense matrix that has either one column or one row. -//! The dimension, for which the variances are found, is set via the var() function. + template inline void @@ -29,11 +28,7 @@ op_var::apply(Mat& out, const mtOp tmp(in.m, out); - const Mat& X = tmp.M; + typedef typename T1::pod_type out_eT; const uword norm_type = in.aux_uword_a; const uword dim = in.aux_uword_b; @@ -41,12 +36,39 @@ op_var::apply(Mat& out, const mtOp 1), "var(): parameter 'norm_type' must be 0 or 1" ); arma_debug_check( (dim > 1), "var(): parameter 'dim' must be 0 or 1" ); + const quasi_unwrap U(in.m); + + if(U.is_alias(out)) + { + Mat tmp; + + op_var::apply_noalias(tmp, U.M, norm_type, dim); + + out.steal_mem(tmp); + } + else + { + op_var::apply_noalias(out, U.M, norm_type, dim); + } + } + + + +template +inline +void +op_var::apply_noalias(Mat::result>& out, const Mat& X, const uword norm_type, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + const uword X_n_rows = X.n_rows; const uword X_n_cols = X.n_cols; if(dim == 0) { - arma_extra_debug_print("op_var::apply(): dim = 0"); + arma_extra_debug_print("op_var::apply_noalias(): dim = 0"); out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); @@ -63,7 +85,7 @@ op_var::apply(Mat& out, const mtOp 0) ? 1 : 0); diff --git a/src/armadillo_bits/op_vecnorm_bones.hpp b/src/armadillo_bits/op_vecnorm_bones.hpp new file mode 100644 index 00000000..7640aa42 --- /dev/null +++ b/src/armadillo_bits/op_vecnorm_bones.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_vecnorm +//! @{ + + +class op_vecnorm + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword k, const uword dim); + + template + inline static void apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword k); + }; + + +class op_vecnorm_ext + : public traits_op_xvec + { + public: + + template + inline static void apply(Mat& out, const mtOp& in); + + template + inline static void apply_noalias(Mat::result>& out, const Mat& X, const uword method_id, const uword dim); + + template + inline static void apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword method_id); + }; + + +//! @} diff --git a/src/armadillo_bits/op_vecnorm_meat.hpp b/src/armadillo_bits/op_vecnorm_meat.hpp new file mode 100644 index 00000000..7f9f6642 --- /dev/null +++ b/src/armadillo_bits/op_vecnorm_meat.hpp @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_vecnorm +//! @{ + + + +template +inline +void +op_vecnorm::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type out_eT; + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword k = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_vecnorm::apply_noalias(tmp, X, k, dim); + + out.steal_mem(tmp); + } + else + { + op_vecnorm::apply_noalias(out, X, k, dim); + } + } + + + + +template +inline +void +op_vecnorm::apply_noalias(Mat::result>& out, const Mat& X, const uword k, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_vecnorm::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_vecnorm::apply_rawmem( out_mem[col], X.colptr(col), X_n_rows, k ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_vecnorm::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + dat.copy_row(X, row); + + op_vecnorm::apply_rawmem( out_mem[row], dat_mem, X_n_cols, k ); + } + } + } + } + + + +template +inline +void +op_vecnorm::apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword k) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const Col tmp(const_cast(mem), N, false, false); + + const Proxy< Col > P(tmp); + + if(P.get_n_elem() == 0) { out_val = out_eT(0); return; } + + if(k == uword(1)) { out_val = op_norm::vec_norm_1(P); return; } + if(k == uword(2)) { out_val = op_norm::vec_norm_2(P); return; } + + out_val = op_norm::vec_norm_k(P, int(k)); + } + + + +// + + + +template +inline +void +op_vecnorm_ext::apply(Mat& out, const mtOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type in_eT; + typedef typename T1::pod_type out_eT; + + const quasi_unwrap U(in.m); + const Mat& X = U.M; + + const uword method_id = in.aux_uword_a; + const uword dim = in.aux_uword_b; + + arma_debug_check( (method_id == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + if(U.is_alias(out)) + { + Mat tmp; + + op_vecnorm_ext::apply_noalias(tmp, X, method_id, dim); + + out.steal_mem(tmp); + } + else + { + op_vecnorm_ext::apply_noalias(out, X, method_id, dim); + } + } + + + + +template +inline +void +op_vecnorm_ext::apply_noalias(Mat::result>& out, const Mat& X, const uword method_id, const uword dim) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const uword X_n_rows = X.n_rows; + const uword X_n_cols = X.n_cols; + + if(dim == 0) + { + arma_extra_debug_print("op_vecnorm_ext::apply(): dim = 0"); + + out.set_size((X_n_rows > 0) ? 1 : 0, X_n_cols); + + if(X_n_rows > 0) + { + out_eT* out_mem = out.memptr(); + + for(uword col=0; col < X_n_cols; ++col) + { + op_vecnorm_ext::apply_rawmem( out_mem[col], X.colptr(col), X_n_rows, method_id ); + } + } + } + else + if(dim == 1) + { + arma_extra_debug_print("op_vecnorm_ext::apply(): dim = 1"); + + out.set_size(X_n_rows, (X_n_cols > 0) ? 1 : 0); + + if(X_n_cols > 0) + { + podarray dat(X_n_cols); + + in_eT* dat_mem = dat.memptr(); + out_eT* out_mem = out.memptr(); + + for(uword row=0; row < X_n_rows; ++row) + { + dat.copy_row(X, row); + + op_vecnorm_ext::apply_rawmem( out_mem[row], dat_mem, X_n_cols, method_id ); + } + } + } + } + + + +template +inline +void +op_vecnorm_ext::apply_rawmem(typename get_pod_type::result& out_val, const in_eT* mem, const uword N, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result out_eT; + + const Col tmp(const_cast(mem), N, false, false); + + const Proxy< Col > P(tmp); + + if(P.get_n_elem() == 0) { out_val = out_eT(0); return; } + + if(method_id == uword(1)) { out_val = op_norm::vec_norm_max(P); return; } + if(method_id == uword(2)) { out_val = op_norm::vec_norm_min(P); return; } + + out_val = out_eT(0); + } + + + +//! @} diff --git a/src/armadillo_bits/op_vectorise_bones.hpp b/src/armadillo_bits/op_vectorise_bones.hpp index df291709..91f7df7c 100644 --- a/src/armadillo_bits/op_vectorise_bones.hpp +++ b/src/armadillo_bits/op_vectorise_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -69,7 +71,9 @@ class op_vectorise_cube_col template inline static void apply_subview(Mat& out, const subview_cube& sv); - template inline static void apply_proxy(Mat& out, const ProxyCube& P); + template inline static void apply_unwrap(Mat& out, const T1& expr); + + template inline static void apply_proxy(Mat& out, const T1& expr); }; diff --git a/src/armadillo_bits/op_vectorise_meat.hpp b/src/armadillo_bits/op_vectorise_meat.hpp index 396075db..c0f278cb 100644 --- a/src/armadillo_bits/op_vectorise_meat.hpp +++ b/src/armadillo_bits/op_vectorise_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -41,15 +43,80 @@ op_vectorise_col::apply_direct(Mat& out, const T1& expr) typedef typename T1::elem_type eT; - if(is_same_type< T1, subview >::yes) + // allow detection of in-place operation + if(is_Mat::value || (arma_config::openmp && Proxy::use_mp)) { - op_vectorise_col::apply_subview(out, reinterpret_cast< const subview& >(expr)); + const unwrap U(expr); + + if(&out == &(U.M)) + { + // output matrix is the same as the input matrix + + out.set_size(out.n_elem, 1); // set_size() doesn't destroy data as long as the number of elements in the matrix remains the same + } + else + { + out.set_size(U.M.n_elem, 1); + + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + } + else + if(is_subview::value) + { + const subview& sv = reinterpret_cast< const subview& >(expr); + + if(&out == &(sv.m)) + { + Mat tmp; + + op_vectorise_col::apply_subview(tmp, sv); + + out.steal_mem(tmp); + } + else + { + op_vectorise_col::apply_subview(out, sv); + } } else { const Proxy P(expr); - op_vectorise_col::apply_proxy(out, P); + const bool is_alias = P.is_alias(out); + + if(is_Mat::stored_type>::value) + { + const quasi_unwrap::stored_type> U(P.Q); + + if(is_alias) + { + Mat tmp(U.M.memptr(), U.M.n_elem, 1); + + out.steal_mem(tmp); + } + else + { + out.set_size(U.M.n_elem, 1); + + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + } + else + { + if(is_alias) + { + Mat tmp; + + op_vectorise_col::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_vectorise_col::apply_proxy(out, P); + } + } } } @@ -62,31 +129,18 @@ op_vectorise_col::apply_subview(Mat& out, const subview& sv) { arma_extra_debug_sigprint(); - const bool is_alias = (&out == &(sv.m)); + const uword sv_n_rows = sv.n_rows; + const uword sv_n_cols = sv.n_cols; - if(is_alias == false) - { - const uword sv_n_rows = sv.n_rows; - const uword sv_n_cols = sv.n_cols; - - out.set_size(sv.n_elem, 1); - - eT* out_mem = out.memptr(); - - for(uword col=0; col < sv_n_cols; ++col) - { - arrayops::copy(out_mem, sv.colptr(col), sv_n_rows); - - out_mem += sv_n_rows; - } - } - else + out.set_size(sv.n_elem, 1); + + eT* out_mem = out.memptr(); + + for(uword col=0; col < sv_n_cols; ++col) { - Mat tmp; + arrayops::copy(out_mem, sv.colptr(col), sv_n_rows); - op_vectorise_col::apply_subview(tmp, sv); - - out.steal_mem(tmp); + out_mem += sv_n_rows; } } @@ -101,83 +155,54 @@ op_vectorise_col::apply_proxy(Mat& out, const Proxy& typedef typename T1::elem_type eT; - if(P.is_alias(out) == false) + const uword N = P.get_n_elem(); + + out.set_size(N, 1); + + eT* outmem = out.memptr(); + + if(Proxy::use_at == false) { - const uword N = P.get_n_elem(); + // TODO: add handling of aligned access ? - out.set_size(N, 1); - - if(is_Mat::stored_type>::value == true) + typename Proxy::ea_type A = P.get_ea(); + + uword i,j; + + for(i=0, j=1; j < N; i+=2, j+=2) { - const unwrap::stored_type> tmp(P.Q); + const eT tmp_i = A[i]; + const eT tmp_j = A[j]; - arrayops::copy(out.memptr(), tmp.M.memptr(), N); + outmem[i] = tmp_i; + outmem[j] = tmp_j; } - else + + if(i < N) { - eT* outmem = out.memptr(); - - if(Proxy::use_at == false) - { - // TODO: add handling of aligned access ? - - typename Proxy::ea_type A = P.get_ea(); - - uword i,j; - - for(i=0, j=1; j < N; i+=2, j+=2) - { - const eT tmp_i = A[i]; - const eT tmp_j = A[j]; - - outmem[i] = tmp_i; - outmem[j] = tmp_j; - } - - if(i < N) - { - outmem[i] = A[i]; - } - } - else - { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - - if(n_rows == 1) - { - for(uword i=0; i < n_cols; ++i) - { - outmem[i] = P.at(0,i); - } - } - else - { - for(uword col=0; col < n_cols; ++col) - for(uword row=0; row < n_rows; ++row) - { - *outmem = P.at(row,col); - outmem++; - } - } - } + outmem[i] = A[i]; } } - else // we have aliasing + else { - arma_extra_debug_print("op_vectorise_col::apply(): aliasing detected"); + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); - if( (is_Mat::stored_type>::value == true) && (Proxy::fake_mat == false) ) + if(n_rows == 1) { - out.set_size(out.n_elem, 1); // set_size() doesn't destroy data as long as the number of elements in the matrix remains the same + for(uword i=0; i < n_cols; ++i) + { + outmem[i] = P.at(0,i); + } } else { - Mat tmp; - - op_vectorise_col::apply_proxy(tmp, P); - - out.steal_mem(tmp); + for(uword col=0; col < n_cols; ++col) + for(uword row=0; row < n_rows; ++row) + { + *outmem = P.at(row,col); + outmem++; + } } } } @@ -203,9 +228,22 @@ op_vectorise_row::apply_direct(Mat& out, const T1& expr) { arma_extra_debug_sigprint(); + typedef typename T1::elem_type eT; + const Proxy P(expr); - op_vectorise_row::apply_proxy(out, P); + if(P.is_alias(out)) + { + Mat tmp; + + op_vectorise_row::apply_proxy(tmp, P); + + out.steal_mem(tmp); + } + else + { + op_vectorise_row::apply_proxy(out, P); + } } @@ -219,61 +257,48 @@ op_vectorise_row::apply_proxy(Mat& out, const Proxy& typedef typename T1::elem_type eT; - if(P.is_alias(out) == false) + const uword n_rows = P.get_n_rows(); + const uword n_cols = P.get_n_cols(); + const uword n_elem = P.get_n_elem(); + + out.set_size(1, n_elem); + + eT* outmem = out.memptr(); + + if(n_cols == 1) { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_elem = P.get_n_elem(); - - out.set_size(1, n_elem); - - eT* outmem = out.memptr(); - - if(n_cols == 1) + if(is_Mat::stored_type>::value) { - if(is_Mat::stored_type>::value == true) - { - const unwrap::stored_type> tmp(P.Q); - - arrayops::copy(out.memptr(), tmp.M.memptr(), n_elem); - } - else - { - for(uword i=0; i < n_elem; ++i) { outmem[i] = P.at(i,0); } - } + const unwrap::stored_type> tmp(P.Q); + + arrayops::copy(out.memptr(), tmp.M.memptr(), n_elem); } else { - for(uword row=0; row < n_rows; ++row) + for(uword i=0; i < n_elem; ++i) { outmem[i] = P.at(i,0); } + } + } + else + { + for(uword row=0; row < n_rows; ++row) + { + uword i,j; + + for(i=0, j=1; j < n_cols; i+=2, j+=2) { - uword i,j; - - for(i=0, j=1; j < n_cols; i+=2, j+=2) - { - const eT tmp_i = P.at(row,i); - const eT tmp_j = P.at(row,j); - - *outmem = tmp_i; outmem++; - *outmem = tmp_j; outmem++; - } + const eT tmp_i = P.at(row,i); + const eT tmp_j = P.at(row,j); - if(i < n_cols) - { - *outmem = P.at(row,i); outmem++; - } + *outmem = tmp_i; outmem++; + *outmem = tmp_j; outmem++; + } + + if(i < n_cols) + { + *outmem = P.at(row,i); outmem++; } } } - else // we have aliasing - { - arma_extra_debug_print("op_vectorise_row::apply(): aliasing detected"); - - Mat tmp; - - op_vectorise_row::apply_proxy(tmp, P); - - out.steal_mem(tmp); - } } @@ -318,9 +343,14 @@ op_vectorise_cube_col::apply(Mat& out, const CubeToMatOp } else { - const ProxyCube P(in.m); - - op_vectorise_cube_col::apply_proxy(out, P); + if(is_Cube::value || (arma_config::openmp && ProxyCube::use_mp)) + { + op_vectorise_cube_col::apply_unwrap(out, in.m); + } + else + { + op_vectorise_cube_col::apply_proxy(out, in.m); + } } } @@ -333,86 +363,101 @@ op_vectorise_cube_col::apply_subview(Mat& out, const subview_cube& sv) { arma_extra_debug_sigprint(); - const uword sv_n_rows = sv.n_rows; - const uword sv_n_cols = sv.n_cols; - const uword sv_n_slices = sv.n_slices; + const uword sv_nr = sv.n_rows; + const uword sv_nc = sv.n_cols; + const uword sv_ns = sv.n_slices; out.set_size(sv.n_elem, 1); eT* out_mem = out.memptr(); - for(uword slice=0; slice < sv_n_slices; ++slice) - for(uword col=0; col < sv_n_cols; ++col ) + for(uword s=0; s < sv_ns; ++s) + for(uword c=0; c < sv_nc; ++c) { - arrayops::copy(out_mem, sv.slice_colptr(slice,col), sv_n_rows); + arrayops::copy(out_mem, sv.slice_colptr(s,c), sv_nr); - out_mem += sv_n_rows; + out_mem += sv_nr; } } + + + +template +inline +void +op_vectorise_cube_col::apply_unwrap(Mat& out, const T1& expr) + { + arma_extra_debug_sigprint(); + const unwrap_cube U(expr); + out.set_size(U.M.n_elem, 1); + arrayops::copy(out.memptr(), U.M.memptr(), U.M.n_elem); + } + + + template inline void -op_vectorise_cube_col::apply_proxy(Mat& out, const ProxyCube& P) +op_vectorise_cube_col::apply_proxy(Mat& out, const T1& expr) { arma_extra_debug_sigprint(); typedef typename T1::elem_type eT; + const ProxyCube P(expr); + + if(is_Cube::stored_type>::value) + { + op_vectorise_cube_col::apply_unwrap(out, P.Q); + + return; + } + const uword N = P.get_n_elem(); out.set_size(N, 1); - if(is_Cube::stored_type>::value == true) + eT* outmem = out.memptr(); + + if(ProxyCube::use_at == false) { - const unwrap_cube::stored_type> tmp(P.Q); + typename ProxyCube::ea_type A = P.get_ea(); - arrayops::copy(out.memptr(), tmp.M.memptr(), N); - } - else - { - eT* outmem = out.memptr(); + uword i,j; - if(ProxyCube::use_at == false) + for(i=0, j=1; j < N; i+=2, j+=2) { - typename ProxyCube::ea_type A = P.get_ea(); - - uword i,j; - - for(i=0, j=1; j < N; i+=2, j+=2) - { - const eT tmp_i = A[i]; - const eT tmp_j = A[j]; - - outmem[i] = tmp_i; - outmem[j] = tmp_j; - } + const eT tmp_i = A[i]; + const eT tmp_j = A[j]; - if(i < N) - { - outmem[i] = A[i]; - } + outmem[i] = tmp_i; + outmem[j] = tmp_j; } - else + + if(i < N) { - const uword n_rows = P.get_n_rows(); - const uword n_cols = P.get_n_cols(); - const uword n_slices = P.get_n_slices(); - - for(uword slice=0; slice < n_slices; ++slice) - for(uword col=0; col < n_cols; ++col ) - for(uword row=0; row < n_rows; ++row ) - { - *outmem = P.at(row,col,slice); - outmem++; - } + outmem[i] = A[i]; + } + } + else + { + const uword nr = P.get_n_rows(); + const uword nc = P.get_n_cols(); + const uword ns = P.get_n_slices(); + + for(uword s=0; s < ns; ++s) + for(uword c=0; c < nc; ++c) + for(uword r=0; r < nr; ++r) + { + *outmem = P.at(r,c,s); + outmem++; } } } - //! @} diff --git a/src/armadillo_bits/op_wishrnd_bones.hpp b/src/armadillo_bits/op_wishrnd_bones.hpp index 53ad94f0..b85a72d4 100644 --- a/src/armadillo_bits/op_wishrnd_bones.hpp +++ b/src/armadillo_bits/op_wishrnd_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/op_wishrnd_meat.hpp b/src/armadillo_bits/op_wishrnd_meat.hpp index bf1acb38..44fa77d0 100644 --- a/src/armadillo_bits/op_wishrnd_meat.hpp +++ b/src/armadillo_bits/op_wishrnd_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -40,6 +42,7 @@ op_wishrnd::apply(Mat& out, const Op& exp if(status == false) { + out.soft_reset(); arma_stop_runtime_error("wishrnd(): given matrix is not symmetric positive definite"); } } @@ -74,8 +77,6 @@ op_wishrnd::apply_direct(Mat& out, const Base& out, const Mat& S, const eT df) if(S.is_empty()) { out.reset(); return true; } + if(auxlib::rudimentary_sym_check(S) == false) { return false; } + Mat D; const bool status = op_chol::apply_direct(D, S, 0); @@ -110,62 +113,49 @@ op_wishrnd::apply_noalias_mode2(Mat& out, const Mat& D, const eT df) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) + arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); + arma_debug_check( (D.is_square() == false), "wishrnd(): given matrix must be square sized" ); + + if(D.is_empty()) { out.reset(); return true; } + + const uword N = D.n_rows; + + if(df < eT(N)) { - arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); - arma_debug_check( (D.is_square() == false), "wishrnd(): given matrix must be square sized" ); + arma_extra_debug_print("simple generator"); - if(D.is_empty()) { out.reset(); return true; } + const uword df_floor = uword(std::floor(df)); + + const Mat tmp = (randn< Mat >(df_floor, N)) * D; + + out = tmp.t() * tmp; + } + else + { + arma_extra_debug_print("standard generator"); - const uword N = D.n_rows; + op_chi2rnd_varying_df chi2rnd_generator; - if(df < eT(N)) + Mat A(N, N, arma_zeros_indicator()); + + for(uword i=0; i tmp = (randn< Mat >(df_floor, N)) * D; - - out = tmp.t() * tmp; + A.at(i,i) = std::sqrt( chi2rnd_generator(df - eT(i)) ); } - else + + for(uword i=1; i < N; ++i) { - arma_extra_debug_print("standard generator"); - - op_chi2rnd_varying_df chi2rnd_generator; - - Mat A(N, N, fill::zeros); - - for(uword i=0; i::fill( A.colptr(i), i ); - } - - const Mat tmp = A * D; - - A.reset(); - - out = tmp.t() * tmp; + arma_rng::randn::fill( A.colptr(i), i ); } - return true; - } - #else - { - arma_ignore(out); - arma_ignore(D); - arma_ignore(df); - arma_stop_logic_error("wishrnd(): C++11 compiler required"); + const Mat tmp = A * D; + + A.reset(); - return false; + out = tmp.t() * tmp; } - #endif + + return true; } @@ -190,6 +180,7 @@ op_iwishrnd::apply(Mat& out, const Op& e if(status == false) { + out.soft_reset(); arma_stop_runtime_error("iwishrnd(): given matrix is not symmetric positive definite and/or df is too low"); } } @@ -224,8 +215,6 @@ op_iwishrnd::apply_direct(Mat& out, const Base& out, const Mat& T, const eT df) if(T.is_empty()) { out.reset(); return true; } + if(auxlib::rudimentary_sym_check(T) == false) { return false; } + Mat Tinv; Mat Dinv; @@ -265,37 +256,24 @@ op_iwishrnd::apply_noalias_mode2(Mat& out, const Mat& Dinv, const eT df) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) - { - arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); - arma_debug_check( (Dinv.is_square() == false), "iwishrnd(): given matrix must be square sized" ); - - if(Dinv.is_empty()) { out.reset(); return true; } - - Mat tmp; - - const bool wishrnd_status = op_wishrnd::apply_noalias_mode2(tmp, Dinv, df); - - if(wishrnd_status == false) { return false; } - - const bool inv_status1 = auxlib::inv_sympd(out, tmp); - - const bool inv_status2 = (inv_status1) ? bool(true) : bool(auxlib::inv(out, tmp)); - - if(inv_status2 == false) { return false; } - - return true; - } - #else - { - arma_ignore(out); - arma_ignore(Dinv); - arma_ignore(df); - arma_stop_logic_error("iwishrnd(): C++11 compiler required"); - - return false; - } - #endif + arma_debug_check( (df <= eT(0)), "df must be greater than zero" ); + arma_debug_check( (Dinv.is_square() == false), "iwishrnd(): given matrix must be square sized" ); + + if(Dinv.is_empty()) { out.reset(); return true; } + + Mat tmp; + + const bool wishrnd_status = op_wishrnd::apply_noalias_mode2(tmp, Dinv, df); + + if(wishrnd_status == false) { return false; } + + const bool inv_status1 = auxlib::inv_sympd(out, tmp); + + const bool inv_status2 = (inv_status1) ? bool(true) : bool(auxlib::inv(out, tmp)); + + if(inv_status2 == false) { return false; } + + return true; } diff --git a/src/armadillo_bits/operator_cube_div.hpp b/src/armadillo_bits/operator_cube_div.hpp index 8a1beaa4..58ff3a0a 100644 --- a/src/armadillo_bits/operator_cube_div.hpp +++ b/src/armadillo_bits/operator_cube_div.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_cube_minus.hpp b/src/armadillo_bits/operator_cube_minus.hpp index 72322c61..53cb4149 100644 --- a/src/armadillo_bits/operator_cube_minus.hpp +++ b/src/armadillo_bits/operator_cube_minus.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_cube_plus.hpp b/src/armadillo_bits/operator_cube_plus.hpp index a44dd7cd..fb360fea 100644 --- a/src/armadillo_bits/operator_cube_plus.hpp +++ b/src/armadillo_bits/operator_cube_plus.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_cube_relational.hpp b/src/armadillo_bits/operator_cube_relational.hpp index 3f0091ac..8270d0f3 100644 --- a/src/armadillo_bits/operator_cube_relational.hpp +++ b/src/armadillo_bits/operator_cube_relational.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_cube_schur.hpp b/src/armadillo_bits/operator_cube_schur.hpp index 451b943a..21b7ee1e 100644 --- a/src/armadillo_bits/operator_cube_schur.hpp +++ b/src/armadillo_bits/operator_cube_schur.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_cube_times.hpp b/src/armadillo_bits/operator_cube_times.hpp index 6fd07f63..0b9cf766 100644 --- a/src/armadillo_bits/operator_cube_times.hpp +++ b/src/armadillo_bits/operator_cube_times.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_div.hpp b/src/armadillo_bits/operator_div.hpp index c2af0597..4f17fdf1 100644 --- a/src/armadillo_bits/operator_div.hpp +++ b/src/armadillo_bits/operator_div.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -23,7 +25,7 @@ template arma_inline typename -enable_if2< is_arma_type::value, const eOp >::result +enable_if2< is_arma_type::value, const eOp< T1, eop_scalar_div_post> >::result operator/ ( const T1& X, @@ -41,7 +43,7 @@ operator/ template arma_inline typename -enable_if2< is_arma_type::value, const eOp >::result +enable_if2< is_arma_type::value, const eOp< T1, eop_scalar_div_pre> >::result operator/ ( const typename T1::elem_type k, @@ -154,7 +156,7 @@ operator/ template inline typename -enable_if2::value, SpMat >::result +enable_if2< is_arma_sparse_type::value, SpMat >::result operator/ ( const T1& X, @@ -300,7 +302,7 @@ operator/ arma_debug_assert_same_size(n_rows, n_cols, pb.get_n_rows(), pb.get_n_cols(), "element-wise division"); - Mat result(n_rows, n_cols); + Mat result(n_rows, n_cols, arma_nozeros_indicator()); for(uword col=0; col < n_cols; ++col) for(uword row=0; row < n_rows; ++row) diff --git a/src/armadillo_bits/operator_minus.hpp b/src/armadillo_bits/operator_minus.hpp index 713be7b2..42047a78 100644 --- a/src/armadillo_bits/operator_minus.hpp +++ b/src/armadillo_bits/operator_minus.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -264,7 +266,7 @@ operator- Mat result(x); - const SpProxy pb(y.get_ref()); + const SpProxy pb(y); arma_debug_assert_same_size( result.n_rows, result.n_cols, pb.get_n_rows(), pb.get_n_cols(), "subtraction" ); diff --git a/src/armadillo_bits/operator_ostream.hpp b/src/armadillo_bits/operator_ostream.hpp index 5c076f46..ce9a9cb0 100644 --- a/src/armadillo_bits/operator_ostream.hpp +++ b/src/armadillo_bits/operator_ostream.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_plus.hpp b/src/armadillo_bits/operator_plus.hpp index 8b908754..3cae597c 100644 --- a/src/armadillo_bits/operator_plus.hpp +++ b/src/armadillo_bits/operator_plus.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -178,7 +180,7 @@ operator+ -//! addition of sparse and non-sparse object +//! addition of one dense and one sparse object template inline typename @@ -215,7 +217,7 @@ operator+ -//! addition of sparse and non-sparse object +//! addition of one sparse and one dense object template inline typename @@ -232,9 +234,22 @@ operator+ { arma_extra_debug_sigprint(); - // Just call the other order (these operations are commutative) - // TODO: if there is a matrix size mismatch, the debug assert will print the matrix sizes in wrong order - return (y + x); + const SpProxy pa(x); + + Mat result(y); + + arma_debug_assert_same_size( pa.get_n_rows(), pa.get_n_cols(), result.n_rows, result.n_cols, "addition" ); + + typename SpProxy::const_iterator_type it = pa.begin(); + typename SpProxy::const_iterator_type it_end = pa.end(); + + while(it != it_end) + { + result.at(it.row(), it.col()) += (*it); + ++it; + } + + return result; } diff --git a/src/armadillo_bits/operator_relational.hpp b/src/armadillo_bits/operator_relational.hpp index 84868d30..7313fdc5 100644 --- a/src/armadillo_bits/operator_relational.hpp +++ b/src/armadillo_bits/operator_relational.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -438,4 +440,44 @@ operator> +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator&& +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + +template +inline +typename +enable_if2 + < + (is_arma_sparse_type::value && is_arma_sparse_type::value && (is_cx::no) && (is_cx::no)), + const mtSpGlue + >::result +operator|| +(const T1& X, const T2& Y) + { + arma_extra_debug_sigprint(); + + // TODO: ensure T1::elem_type and T2::elem_type are the same + + return mtSpGlue( X, Y ); + } + + + //! @} diff --git a/src/armadillo_bits/operator_schur.hpp b/src/armadillo_bits/operator_schur.hpp index 5f34afa1..2acfdbc6 100644 --- a/src/armadillo_bits/operator_schur.hpp +++ b/src/armadillo_bits/operator_schur.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/operator_times.hpp b/src/armadillo_bits/operator_times.hpp index 53fcaef2..861166c0 100644 --- a/src/armadillo_bits/operator_times.hpp +++ b/src/armadillo_bits/operator_times.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -178,7 +180,7 @@ operator* arma_debug_assert_mul_size(A.n_rows, A.n_cols, B.n_rows, B.n_cols, "matrix multiplication"); - Mat out(A.n_rows, B.n_cols, fill::zeros); + Mat out(A.n_rows, B.n_cols, arma_zeros_indicator()); const uword A_length = (std::min)(A.n_rows, A.n_cols); const uword B_length = (std::min)(B.n_rows, B.n_cols); @@ -359,7 +361,7 @@ typename enable_if2 < (is_arma_sparse_type::value && is_arma_type::value && is_same_type::value), - Mat + const SpToDGlue >::result operator* ( @@ -369,13 +371,7 @@ operator* { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - - Mat result; - - spglue_times_misc::sparse_times_dense(result, x, y); - - return result; + return SpToDGlue(x, y); } @@ -387,7 +383,7 @@ typename enable_if2 < (is_arma_type::value && is_arma_sparse_type::value && is_same_type::value), - Mat + const SpToDGlue >::result operator* ( @@ -397,13 +393,7 @@ operator* { arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; - - Mat result; - - spglue_times_misc::dense_times_sparse(result, x, y); - - return result; + return SpToDGlue(x, y); } @@ -456,7 +446,7 @@ operator* Mat< typename promote_type::result > out; - spglue_times_mixed::sparse_times_dense(out, X, Y); + glue_times_sparse_dense::apply_mixed(out, X, Y); return out; } @@ -482,7 +472,7 @@ operator* Mat< typename promote_type::result > out; - spglue_times_mixed::dense_times_sparse(out, X, Y); + glue_times_dense_sparse::apply_mixed(out, X, Y); return out; } diff --git a/src/armadillo_bits/podarray_bones.hpp b/src/armadillo_bits/podarray_bones.hpp index b052a9ae..9aa2cf1e 100644 --- a/src/armadillo_bits/podarray_bones.hpp +++ b/src/armadillo_bits/podarray_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,7 +23,7 @@ struct podarray_prealloc_n_elem { - static const uword val = 16; + static constexpr uword val = 16; }; @@ -51,10 +53,8 @@ class podarray arma_inline explicit podarray(const uword new_N); - arma_inline explicit podarray(const eT* X, const uword new_N); - - // template - // inline explicit podarray(const Proxy& P); + template + inline explicit podarray(const uword new_N, const arma_initmode_indicator&); arma_inline eT& operator[] (const uword i); arma_inline eT operator[] (const uword i) const; @@ -76,7 +76,7 @@ class podarray arma_inline eT* memptr(); arma_inline const eT* memptr() const; - arma_hot inline void copy_row(const Mat& A, const uword row); + inline void copy_row(const Mat& A, const uword row); protected: diff --git a/src/armadillo_bits/podarray_meat.hpp b/src/armadillo_bits/podarray_meat.hpp index 7dc5ed32..2eb62f24 100644 --- a/src/armadillo_bits/podarray_meat.hpp +++ b/src/armadillo_bits/podarray_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,6 @@ template -arma_hot inline podarray::~podarray() { @@ -82,7 +83,6 @@ podarray::operator=(const podarray& x) template -arma_hot arma_inline podarray::podarray(const uword new_n_elem) : n_elem(new_n_elem) @@ -95,79 +95,24 @@ podarray::podarray(const uword new_n_elem) template -arma_inline -podarray::podarray(const eT* X, const uword new_n_elem) +template +inline +podarray::podarray(const uword new_n_elem, const arma_initmode_indicator&) : n_elem(new_n_elem) { arma_extra_debug_sigprint_this(this); init_cold(new_n_elem); - arrayops::copy( memptr(), X, new_n_elem ); + if(do_zeros) + { + arma_extra_debug_print("podarray::constructor: zeroing memory"); + arrayops::fill_zeros(memptr(), n_elem); + } } -// template -// template -// inline -// podarray::podarray(const Proxy& P) -// : n_elem(P.get_n_elem()) -// { -// arma_extra_debug_sigprint_this(this); -// -// const uword P_n_elem = P.get_n_elem(); -// -// init_cold(P_n_elem); -// -// eT* out_mem = (*this).memptr(); -// -// if(Proxy::use_at == false) -// { -// typename Proxy::ea_type A = P.get_ea(); -// -// uword i,j; -// for(i=0, j=1; j < P_n_elem; i+=2, j+=2) -// { -// const eT val_i = A[i]; -// const eT val_j = A[j]; -// -// out_mem[i] = val_i; -// out_mem[j] = val_j; -// } -// -// if(i < P_n_elem) -// { -// out_mem[i] = A[i]; -// } -// } -// else -// { -// const uword P_n_rows = P.get_n_rows(); -// const uword P_n_cols = P.get_n_cols(); -// -// if(P_n_rows != 1) -// { -// uword count = 0; -// -// for(uword col=0; col < P_n_cols; ++col) -// for(uword row=0; row < P_n_rows; ++row, ++count) -// { -// out_mem[count] = P.at(row,col); -// } -// } -// else -// { -// for(uword col=0; col < P_n_cols; ++col) -// { -// out_mem[col] = P.at(0,col); -// } -// } -// } -// } - - - template arma_inline eT @@ -193,7 +138,7 @@ arma_inline eT podarray::operator() (const uword i) const { - arma_debug_check( (i >= n_elem), "podarray::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "podarray::operator(): index out of bounds" ); return mem[i]; } @@ -205,7 +150,7 @@ arma_inline eT& podarray::operator() (const uword i) { - arma_debug_check( (i >= n_elem), "podarray::operator(): index out of bounds"); + arma_debug_check_bounds( (i >= n_elem), "podarray::operator(): index out of bounds" ); return access::rw(mem[i]); } @@ -219,10 +164,7 @@ podarray::set_min_size(const uword min_n_elem) { arma_extra_debug_sigprint(); - if(min_n_elem > n_elem) - { - init_warm(min_n_elem); - } + if(min_n_elem > n_elem) { init_warm(min_n_elem); } } @@ -310,75 +252,38 @@ podarray::memptr() const template -arma_hot inline void podarray::copy_row(const Mat& A, const uword row) { - const uword cols = A.n_cols; + arma_extra_debug_sigprint(); // note: this function assumes that the podarray has been set to the correct size beforehand - eT* out = memptr(); - switch(cols) + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + const eT* A_mem = &(A.at(row,0)); + eT* out_mem = memptr(); + + for(uword i=0; i < n_cols; ++i) { - default: - { - uword i,j; - for(i=0, j=1; j < cols; i+=2, j+=2) - { - const eT tmp_i = A.at(row, i); - const eT tmp_j = A.at(row, j); - - out[i] = tmp_i; - out[j] = tmp_j; - } - - if(i < cols) - { - out[i] = A.at(row, i); - } - } - break; + out_mem[i] = (*A_mem); - case 8: out[7] = A.at(row, 7); - // fallthrough - case 7: out[6] = A.at(row, 6); - // fallthrough - case 6: out[5] = A.at(row, 5); - // fallthrough - case 5: out[4] = A.at(row, 4); - // fallthrough - case 4: out[3] = A.at(row, 3); - // fallthrough - case 3: out[2] = A.at(row, 2); - // fallthrough - case 2: out[1] = A.at(row, 1); - // fallthrough - case 1: out[0] = A.at(row, 0); - // fallthrough - case 0: ; - // fallthrough + A_mem += n_rows; } } + template -arma_hot inline void podarray::init_cold(const uword new_n_elem) { arma_extra_debug_sigprint(); - if(new_n_elem <= podarray_prealloc_n_elem::val ) - { - mem = mem_local; - } - else - { - mem = memory::acquire(new_n_elem); - } + mem = (new_n_elem <= podarray_prealloc_n_elem::val) ? mem_local : memory::acquire(new_n_elem); } @@ -390,24 +295,11 @@ podarray::init_warm(const uword new_n_elem) { arma_extra_debug_sigprint(); - if(n_elem == new_n_elem) - { - return; - } + if(n_elem == new_n_elem) { return; } - if(n_elem > podarray_prealloc_n_elem::val ) - { - memory::release( mem ); - } + if(n_elem > podarray_prealloc_n_elem::val) { memory::release( mem ); } - if(new_n_elem <= podarray_prealloc_n_elem::val ) - { - mem = mem_local; - } - else - { - mem = memory::acquire(new_n_elem); - } + mem = (new_n_elem <= podarray_prealloc_n_elem::val) ? mem_local : memory::acquire(new_n_elem); access::rw(n_elem) = new_n_elem; } diff --git a/src/armadillo_bits/promote_type.hpp b/src/armadillo_bits/promote_type.hpp index 5428d870..d53eb322 100644 --- a/src/armadillo_bits/promote_type.hpp +++ b/src/armadillo_bits/promote_type.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -21,33 +23,29 @@ template struct is_promotable { - static const bool value = false; + static constexpr bool value = false; typedef T1 result; }; struct is_promotable_ok { - static const bool value = true; + static constexpr bool value = true; }; template struct is_promotable : public is_promotable_ok { typedef T result; }; template struct is_promotable, T> : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable, std::complex > : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable, float> : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable, double> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, std::complex> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, float> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, double> : public is_promotable_ok { typedef std::complex result; }; -#if defined(ARMA_USE_U64S64) template struct is_promotable, u64> : public is_promotable_ok { typedef std::complex result; }; template struct is_promotable, s64> : public is_promotable_ok { typedef std::complex result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template struct is_promotable, ulng_t> : public is_promotable_ok { typedef std::complex result; }; template struct is_promotable, slng_t> : public is_promotable_ok { typedef std::complex result; }; -#endif template struct is_promotable, s32> : public is_promotable_ok { typedef std::complex result; }; template struct is_promotable, u32> : public is_promotable_ok { typedef std::complex result; }; template struct is_promotable, s16> : public is_promotable_ok { typedef std::complex result; }; @@ -57,14 +55,10 @@ template struct is_promotable, u8> : public is_p template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; @@ -72,14 +66,10 @@ template<> struct is_promotable : public is_promotable_ok { type template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; @@ -87,13 +77,10 @@ template<> struct is_promotable : public is_promotable_ok { typed template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; -#endif -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; // float ? template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; @@ -101,7 +88,6 @@ template<> struct is_promotable : public is_promotable_ok { typedef s6 template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; @@ -129,37 +115,29 @@ template<> struct is_promotable : public is_promotable_ok { typedef s8 r // // mirrored versions -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable, std::complex > : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template<> struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable, std::complex> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template<> struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -#if defined(ARMA_USE_U64S64) -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -#endif -#if defined(ARMA_ALLOW_LONG) -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -#endif -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; -template struct is_promotable > : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; +template struct is_promotable> : public is_promotable_ok { typedef std::complex result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; @@ -167,14 +145,10 @@ template<> struct is_promotable : public is_promotable_ok { type template<> struct is_promotable : public is_promotable_ok { typedef double result; }; template<> struct is_promotable : public is_promotable_ok { typedef double result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; @@ -182,13 +156,10 @@ template<> struct is_promotable : public is_promotable_ok { typed template<> struct is_promotable : public is_promotable_ok { typedef float result; }; template<> struct is_promotable : public is_promotable_ok { typedef float result; }; -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef u64 result; }; -#endif -#if defined(ARMA_USE_U64S64) template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; // float ? template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; @@ -196,7 +167,6 @@ template<> struct is_promotable : public is_promotable_ok { typedef s6 template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; template<> struct is_promotable : public is_promotable_ok { typedef s64 result; }; -#endif template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; // float ? template<> struct is_promotable : public is_promotable_ok { typedef s32 result; }; diff --git a/src/armadillo_bits/rbfgs.h b/src/armadillo_bits/rbfgs.h new file mode 100644 index 00000000..399ef77a --- /dev/null +++ b/src/armadillo_bits/rbfgs.h @@ -0,0 +1,11 @@ +#ifndef RBFGS_H +#define RBFGS_H + +#include "armadillo" + +using namespace arma; + +using namespace std; + + +#endif // end of RBFGS_H diff --git a/src/armadillo_bits/restrictors.hpp b/src/armadillo_bits/restrictors.hpp index 908e3512..019a5f41 100644 --- a/src/armadillo_bits/restrictors.hpp +++ b/src/armadillo_bits/restrictors.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -32,14 +34,10 @@ template<> struct arma_scalar_only< u16 > { typedef u16 result; }; template<> struct arma_scalar_only< s16 > { typedef s16 result; }; template<> struct arma_scalar_only< u32 > { typedef u32 result; }; template<> struct arma_scalar_only< s32 > { typedef s32 result; }; -#if defined(ARMA_USE_U64S64) template<> struct arma_scalar_only< u64 > { typedef u64 result; }; template<> struct arma_scalar_only< s64 > { typedef s64 result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct arma_scalar_only< ulng_t > { typedef ulng_t result; }; template<> struct arma_scalar_only< slng_t > { typedef slng_t result; }; -#endif template<> struct arma_scalar_only< float > { typedef float result; }; template<> struct arma_scalar_only< double > { typedef double result; }; template<> struct arma_scalar_only< cx_float > { typedef cx_float result; }; @@ -55,14 +53,10 @@ template<> struct arma_integral_only< u16 > { typedef u16 result; }; template<> struct arma_integral_only< s16 > { typedef s16 result; }; template<> struct arma_integral_only< u32 > { typedef u32 result; }; template<> struct arma_integral_only< s32 > { typedef s32 result; }; -#if defined(ARMA_USE_U64S64) template<> struct arma_integral_only< u64 > { typedef u64 result; }; template<> struct arma_integral_only< s64 > { typedef s64 result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct arma_integral_only< ulng_t > { typedef ulng_t result; }; template<> struct arma_integral_only< slng_t > { typedef slng_t result; }; -#endif @@ -71,12 +65,8 @@ template struct arma_unsigned_integral_only { }; template<> struct arma_unsigned_integral_only< u8 > { typedef u8 result; }; template<> struct arma_unsigned_integral_only< u16 > { typedef u16 result; }; template<> struct arma_unsigned_integral_only< u32 > { typedef u32 result; }; -#if defined(ARMA_USE_U64S64) template<> struct arma_unsigned_integral_only< u64 > { typedef u64 result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct arma_unsigned_integral_only< ulng_t > { typedef ulng_t result; }; -#endif @@ -85,12 +75,8 @@ template struct arma_signed_integral_only { }; template<> struct arma_signed_integral_only< s8 > { typedef s8 result; }; template<> struct arma_signed_integral_only< s16 > { typedef s16 result; }; template<> struct arma_signed_integral_only< s32 > { typedef s32 result; }; -#if defined(ARMA_USE_U64S64) template<> struct arma_signed_integral_only< s64 > { typedef s64 result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct arma_signed_integral_only< slng_t > { typedef slng_t result; }; -#endif @@ -99,12 +85,8 @@ template struct arma_signed_only { }; template<> struct arma_signed_only< s8 > { typedef s8 result; }; template<> struct arma_signed_only< s16 > { typedef s16 result; }; template<> struct arma_signed_only< s32 > { typedef s32 result; }; -#if defined(ARMA_USE_U64S64) template<> struct arma_signed_only< s64 > { typedef s64 result; }; -#endif -#if defined(ARMA_ALLOW_LONG) template<> struct arma_signed_only< slng_t > { typedef slng_t result; }; -#endif template<> struct arma_signed_only< float > { typedef float result; }; template<> struct arma_signed_only< double > { typedef double result; }; template<> struct arma_signed_only< cx_float > { typedef cx_float result; }; diff --git a/src/armadillo_bits/running_stat_bones.hpp b/src/armadillo_bits/running_stat_bones.hpp index d07c637e..bd25e1d2 100644 --- a/src/armadillo_bits/running_stat_bones.hpp +++ b/src/armadillo_bits/running_stat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -102,16 +104,16 @@ class running_stat_aux public: template - inline static void update_stats(running_stat& x, const eT sample, const typename arma_not_cx::result* junk = 0); + inline static void update_stats(running_stat& x, const eT sample, const typename arma_not_cx::result* junk = nullptr); template - inline static void update_stats(running_stat& x, const std::complex& sample, const typename arma_not_cx::result* junk = 0); + inline static void update_stats(running_stat& x, const std::complex& sample, const typename arma_not_cx::result* junk = nullptr); template - inline static void update_stats(running_stat& x, const typename eT::value_type sample, const typename arma_cx_only::result* junk = 0); + inline static void update_stats(running_stat& x, const typename eT::value_type sample, const typename arma_cx_only::result* junk = nullptr); template - inline static void update_stats(running_stat& x, const eT& sample, const typename arma_cx_only::result* junk = 0); + inline static void update_stats(running_stat& x, const eT& sample, const typename arma_cx_only::result* junk = nullptr); }; diff --git a/src/armadillo_bits/running_stat_meat.hpp b/src/armadillo_bits/running_stat_meat.hpp index ed029463..35e6ba82 100644 --- a/src/armadillo_bits/running_stat_meat.hpp +++ b/src/armadillo_bits/running_stat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -162,7 +164,7 @@ running_stat::operator() (const typename running_stat::T sample) if( arma_isfinite(sample) == false ) { - arma_debug_warn("running_stat: sample ignored as it is non-finite" ); + arma_debug_warn_level(3, "running_stat: sample ignored as it is non-finite" ); return; } @@ -181,7 +183,7 @@ running_stat::operator() (const std::complex< typename running_stat::T > if( arma_isfinite(sample) == false ) { - arma_debug_warn("running_stat: sample ignored as it is non-finite" ); + arma_debug_warn_level(3, "running_stat: sample ignored as it is non-finite" ); return; } diff --git a/src/armadillo_bits/running_stat_vec_bones.hpp b/src/armadillo_bits/running_stat_vec_bones.hpp index 0ad96bed..13b076c2 100644 --- a/src/armadillo_bits/running_stat_vec_bones.hpp +++ b/src/armadillo_bits/running_stat_vec_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -62,8 +64,8 @@ class running_stat_vec inline running_stat_vec& operator=(const running_stat_vec& in_rsv); - template arma_hot inline void operator() (const Base< T, T1>& X); - template arma_hot inline void operator() (const Base, T1>& X); + template inline void operator() (const Base< T, T1>& X); + template inline void operator() (const Base, T1>& X); inline void reset(); @@ -119,7 +121,7 @@ class running_stat_vec_aux ( running_stat_vec& x, const Mat::eT>& sample, - const typename arma_not_cx::eT>::result* junk = 0 + const typename arma_not_cx::eT>::result* junk = nullptr ); template @@ -128,7 +130,7 @@ class running_stat_vec_aux ( running_stat_vec& x, const Mat::T > >& sample, - const typename arma_not_cx::eT>::result* junk = 0 + const typename arma_not_cx::eT>::result* junk = nullptr ); template @@ -137,7 +139,7 @@ class running_stat_vec_aux ( running_stat_vec& x, const Mat< typename running_stat_vec::T >& sample, - const typename arma_cx_only::eT>::result* junk = 0 + const typename arma_cx_only::eT>::result* junk = nullptr ); template @@ -146,7 +148,7 @@ class running_stat_vec_aux ( running_stat_vec& x, const Mat::eT>& sample, - const typename arma_cx_only::eT>::result* junk = 0 + const typename arma_cx_only::eT>::result* junk = nullptr ); }; diff --git a/src/armadillo_bits/running_stat_vec_meat.hpp b/src/armadillo_bits/running_stat_vec_meat.hpp index 8b2045c8..370fcf73 100644 --- a/src/armadillo_bits/running_stat_vec_meat.hpp +++ b/src/armadillo_bits/running_stat_vec_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -82,7 +84,6 @@ running_stat_vec::operator=(const running_stat_vec& in_rsv) //! update statistics to reflect new sample template template -arma_hot inline void running_stat_vec::operator() (const Base::T, T1>& X) @@ -97,9 +98,9 @@ running_stat_vec::operator() (const Base::operator() (const Base template -arma_hot inline void running_stat_vec::operator() (const Base< std::complex::T>, T1>& X) @@ -126,9 +126,9 @@ running_stat_vec::operator() (const Base< std::complex - inline static bool eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const char* form_str, const eT default_tol); + inline static bool eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); - template - inline static bool eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const eT default_tol); + template + inline static bool eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); template - inline static bool eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const eT default_tol); + inline static bool eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts); + + template + inline static bool eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts); // - // eigs_gen() + // eigs_gen() for real matrices template - inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const char* form_str, const T default_tol); + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); - template - inline static bool eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const T default_tol); + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts); template - inline static bool eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const T default_tol); + inline static bool eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); + + template + inline static bool eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts); + + // + // eigs_gen() for complex matrices + + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts); template - inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const char* form_str, const T default_tol); + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts); + template + inline static bool eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat< std::complex >& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts); // // spsolve() via SuperLU @@ -67,16 +86,34 @@ class sp_auxlib template inline static bool spsolve_refine(Mat& out, typename T1::pod_type& out_rcond, const SpBase& A, const Base& B, const superlu_opts& user_opts); + // + // support functions + #if defined(ARMA_USE_SUPERLU) + + template + inline static typename get_pod_type::result norm1(superlu::SuperMatrix* A); + + template + inline static typename get_pod_type::result lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result norm_val); + inline static void set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts); template inline static bool copy_to_supermatrix(superlu::SuperMatrix& out, const SpMat& A); + template + inline static bool copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat& A, const eT shift); + + // // for debugging only + // template + // inline static void copy_to_spmat(SpMat& out, const superlu::SuperMatrix& A); + template inline static bool wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat& A); inline static void destroy_supermatrix(superlu::SuperMatrix& out); + #endif @@ -88,10 +125,23 @@ class sp_auxlib // functions are very different and we can't combine their code template - inline static void run_aupd + inline static void run_aupd_plain ( - const uword n_eigvals, char* which, const SpMat& X, const bool sym, - blas_int& n, eT& tol, + const uword n_eigvals, char* which, + const SpMat& X, const SpMat& Xst, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ); + + template + inline static void run_aupd_shiftinvert + ( + const uword n_eigvals, const T sigma, + const SpMat& X, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, podarray& iparam, podarray& ipntr, podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, @@ -105,3 +155,129 @@ class sp_auxlib template inline static bool rudimentary_sym_check(const SpMat< std::complex >& X); }; + + + +template +struct eigs_randu_filler + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline eigs_randu_filler(); + + inline void fill(podarray& X, const uword N); + }; + + +template +struct eigs_randu_filler< std::complex > + { + std::mt19937_64 local_engine; + std::uniform_real_distribution local_u_distr; + + inline eigs_randu_filler(); + + inline void fill(podarray< std::complex >& X, const uword N); + }; + + + +#if defined(ARMA_USE_SUPERLU) + +class superlu_supermatrix_wrangler + { + private: + + bool used = false; + + arma_aligned superlu::SuperMatrix m; + + public: + + inline ~superlu_supermatrix_wrangler(); + inline superlu_supermatrix_wrangler(); + + inline superlu_supermatrix_wrangler(const superlu_supermatrix_wrangler&) = delete; + inline void operator= (const superlu_supermatrix_wrangler&) = delete; + + inline superlu::SuperMatrix& get_ref(); + inline superlu::SuperMatrix* get_ptr(); + }; + + +class superlu_stat_wrangler + { + private: + + arma_aligned superlu::SuperLUStat_t stat; + + public: + + inline ~superlu_stat_wrangler(); + inline superlu_stat_wrangler(); + + inline superlu_stat_wrangler(const superlu_stat_wrangler&) = delete; + inline void operator= (const superlu_stat_wrangler&) = delete; + + inline superlu::SuperLUStat_t* get_ptr(); + }; + + +template +class superlu_array_wrangler + { + private: + + arma_aligned eT* mem = nullptr; + + public: + + inline ~superlu_array_wrangler(); + inline superlu_array_wrangler(); + inline superlu_array_wrangler(const uword n_elem); + + inline void set_size(const uword n_elem); + inline void reset(); + + inline superlu_array_wrangler(const superlu_array_wrangler&) = delete; + inline void operator= (const superlu_array_wrangler&) = delete; + + inline eT* get_ptr(); + }; + + +template +class superlu_worker + { + private: + + bool factorisation_valid = false; + + superlu_supermatrix_wrangler* l = nullptr; + superlu_supermatrix_wrangler* u = nullptr; + + superlu_array_wrangler perm_c; + superlu_array_wrangler perm_r; + + superlu_stat_wrangler stat; + + public: + + inline ~superlu_worker(); + inline superlu_worker(); + + inline bool factorise(typename get_pod_type::result& out_rcond, const SpMat& A, const superlu_opts& user_opts); + + inline bool solve(Mat& X, const Mat& B); + + inline superlu_worker(const superlu_worker&) = delete; + inline void operator= (const superlu_worker&) = delete; + }; + +#endif + + + +//! @} + diff --git a/src/armadillo_bits/sp_auxlib_meat.hpp b/src/armadillo_bits/sp_auxlib_meat.hpp index 5b77e0de..dbfdf2d7 100644 --- a/src/armadillo_bits/sp_auxlib_meat.hpp +++ b/src/armadillo_bits/sp_auxlib_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,7 +27,7 @@ sp_auxlib::interpret_form_str(const char* form_str) arma_extra_debug_sigprint(); // the order of the 3 if statements below is important - if( form_str == NULL ) { return form_none; } + if( form_str == nullptr ) { return form_none; } if( form_str[0] == char(0) ) { return form_none; } if( form_str[1] == char(0) ) { return form_none; } @@ -57,33 +59,46 @@ sp_auxlib::interpret_form_str(const char* form_str) template inline bool -sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const char* form_str, const eT default_tol) +sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) { arma_extra_debug_sigprint(); const unwrap_spmat U(X.get_ref()); + arma_debug_check( (U.M.is_square() == false), "eigs_sym(): given matrix must be square sized" ); + if((arma_config::debug) && (sp_auxlib::rudimentary_sym_check(U.M) == false)) { - if(is_cx::no ) { arma_debug_warn("eigs_sym(): given matrix is not symmetric"); } - if(is_cx::yes) { arma_debug_warn("eigs_sym(): given matrix is not hermitian"); } + if(is_cx::no ) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not hermitian"); } + } + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_sym(): detected non-finite elements"); + return false; } + // TODO: investigate optional redirection of "sm" to ARPACK as it's capable of shift-invert; + // TODO: in shift-invert mode, "sm" maps to "lm" of the shift-inverted matrix (with sigma = 0) + #if defined(ARMA_USE_NEWARP) { - return sp_auxlib::eigs_sym_newarp(eigval, eigvec, U.M, n_eigvals, form_str, default_tol); + return sp_auxlib::eigs_sym_newarp(eigval, eigvec, U.M, n_eigvals, form_val, opts); } #elif defined(ARMA_USE_ARPACK) { - return sp_auxlib::eigs_sym_arpack(eigval, eigvec, U.M, n_eigvals, form_str, default_tol); + constexpr eT sigma = eT(0); + + return sp_auxlib::eigs_sym_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); } #else { arma_ignore(eigval); arma_ignore(eigvec); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(opts); arma_stop_logic_error("eigs_sym(): use of NEWARP or ARPACK must be enabled"); return false; @@ -93,22 +108,70 @@ sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, c +//! immediate eigendecomposition of symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_sym(Col& eigval, Mat& eigvec, const SpBase& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_sym(): given matrix must be square sized" ); + + if((arma_config::debug) && (sp_auxlib::rudimentary_sym_check(U.M) == false)) + { + if(is_cx::no ) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not symmetric"); } + if(is_cx::yes) { arma_debug_warn_level(1, "eigs_sym(): given matrix is not hermitian"); } + } + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_sym(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_NEWARP) && defined(ARMA_USE_SUPERLU)) + { + return sp_auxlib::eigs_sym_newarp(eigval, eigvec, U.M, n_eigvals, sigma, opts); + } + #elif (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_sym_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_sym(): use of NEWARP or ARPACK as well as SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + template inline bool -sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const eT default_tol) +sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_NEWARP) { - const form_type form_val = sp_auxlib::interpret_form_str(form_str); - arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_la) && (form_val != form_sa), "eigs_sym(): unknown form specified" ); - const newarp::SparseGenMatProd op(X); + if(X.is_square() == false) { return false; } - arma_debug_check( (op.n_rows != op.n_cols), "eigs_sym(): given matrix must be square sized" ); + const newarp::SparseGenMatProd op(X); arma_debug_check( (n_eigvals >= op.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); @@ -121,12 +184,39 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, } uword n = op.n_rows; - uword ncv = n_eigvals + 2 + 1; - if(ncv < (2 * n_eigvals + 1)) { ncv = 2 * n_eigvals + 1; } - if(ncv > n) { ncv = n; } + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits, otherwise cap it. + uword ncv = ncv_default; - eT tol = (std::max)(default_tol, std::numeric_limits::epsilon()); + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = uword(n_eigvals + 1); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 1)) { ncv = (n_eigvals + 1); } + if(ncv > n ) { ncv = n; } + + eT tol = (std::max)(eT(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); // eigval.set_size(n_eigvals); // eigvec.set_size(n, n_eigvals); @@ -141,7 +231,7 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, { newarp::SymEigsSolver< eT, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -150,7 +240,7 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, { newarp::SymEigsSolver< eT, newarp::EigsSelect::SMALLEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -159,7 +249,7 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, { newarp::SymEigsSolver< eT, newarp::EigsSelect::LARGEST_ALGE, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -168,7 +258,7 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, { newarp::SymEigsSolver< eT, newarp::EigsSelect::SMALLEST_ALGE, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -191,8 +281,8 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, arma_ignore(eigvec); arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(opts); return false; } @@ -204,34 +294,135 @@ sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, template inline bool -sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const eT default_tol) +sp_auxlib::eigs_sym_newarp(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const eT sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_NEWARP) + { + if(X.is_square() == false) { return false; } + + const newarp::SparseGenRealShiftSolve op(X, sigma); + + if(op.valid == false) { return false; } + + arma_debug_check( (n_eigvals >= op.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); + + // If the matrix is empty, the case is trivial. + if( (op.n_cols == 0) || (n_eigvals == 0) ) // We already know n_cols == n_rows. + { + eigval.reset(); + eigvec.reset(); + return true; + } + + uword n = op.n_rows; + + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits, otherwise cap it. + uword ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = uword(n_eigvals + 1); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 1)) { ncv = (n_eigvals + 1); } + if(ncv > n ) { ncv = n; } + + eT tol = (std::max)(eT(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); + + // eigval.set_size(n_eigvals); + // eigvec.set_size(n, n_eigvals); + + bool status = true; + + uword nconv = 0; + + try + { + newarp::SymEigsShiftSolver< eT, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenRealShiftSolve > eigs(op, n_eigvals, ncv, sigma); + eigs.init(); + nconv = eigs.compute(maxiter, tol); + eigval = eigs.eigenvalues(); + eigvec = eigs.eigenvectors(); + } + catch(const std::runtime_error&) + { + status = false; + } + + if(status == true) + { + if(nconv == 0) { status = false; } + } + + return status; + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(X); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eT sigma, const eigs_opts& opts) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_ARPACK) { - const form_type form_val = sp_auxlib::interpret_form_str(form_str); + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_la) && (form_val != form_sa) && (form_val != form_sigma), "eigs_sym(): unknown form specified" ); - arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_la) && (form_val != form_sa), "eigs_sym(): unknown form specified" ); + if(X.is_square() == false) { return false; } char which_sm[3] = "SM"; char which_lm[3] = "LM"; char which_sa[3] = "SA"; char which_la[3] = "LA"; char* which; - switch (form_val) + + switch(form_val) { case form_sm: which = which_sm; break; case form_lm: which = which_lm; break; case form_sa: which = which_sa; break; case form_la: which = which_la; break; - + default: which = which_lm; break; } - // Make sure it's square. - arma_debug_check( (X.n_rows != X.n_cols), "eigs_sym(): given matrix must be square sized" ); - // Make sure we aren't asking for every eigenvalue. // The _saupd() functions allow asking for one more eigenvalue than the _naupd() functions. arma_debug_check( (n_eigvals >= X.n_rows), "eigs_sym(): n_eigvals must be less than the number of rows in the matrix" ); @@ -245,41 +436,74 @@ sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, } // Set up variables that get used for neupd(). - blas_int n, ncv, ldv, lworkl, info; - eT tol = default_tol; + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + eT tol = eT(opts.tol); + maxiter = blas_int(opts.maxiter); + podarray resid, v, workd, workl; podarray iparam, ipntr; podarray rwork; // Not used in this case. - run_aupd(n_eigvals, which, X, true /* sym, not gen */, n, tol, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + n = blas_int(X.n_rows); // The size of the matrix. + + // Use max(2*k+1, 20) as default subspace dimension for the sym case; MATLAB uses max(2*k, 20), but we need to be backward-compatible. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); - if(info != 0) + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) { - return false; + if(opts.subdim < (n_eigvals + 1)) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim must be greater than k; using k+1 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 1); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_sym(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } + } + + if(use_sigma) + //if(form_val == form_sigma) + { + run_aupd_shiftinvert(n_eigvals, sigma, X, true /* sym, not gen */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, true /* sym, not gen */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); } + if(info != 0) { return false; } + // The process has converged, and now we need to recover the actual eigenvectors using seupd() blas_int rvec = 1; // .TRUE - blas_int nev = n_eigvals; + blas_int nev = blas_int(n_eigvals); char howmny = 'A'; char bmat = 'I'; // We are considering the standard eigenvalue problem. - podarray select(ncv); // Logical array of dimension NCV. + podarray select(ncv, arma_zeros_indicator()); // Logical array of dimension NCV. blas_int ldz = n; // seupd() will output directly into the eigval and eigvec objects. - eigval.zeros(n_eigvals); + eigval.zeros( n_eigvals); eigvec.zeros(n, n_eigvals); - arpack::seupd(&rvec, &howmny, select.memptr(), eigval.memptr(), eigvec.memptr(), &ldz, (eT*) NULL, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + arpack::seupd(&rvec, &howmny, select.memptr(), eigval.memptr(), eigvec.memptr(), &ldz, (eT*) &sigma, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); // Check for errors. - if(info != 0) - { - arma_debug_warn("eigs_sym(): ARPACK error ", info, " in seupd()"); - return false; - } + if(info != 0) { arma_debug_warn_level(1, "eigs_sym(): ARPACK error ", info, " in seupd()"); return false; } return (info == 0); } @@ -289,8 +513,9 @@ sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, arma_ignore(eigvec); arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); return false; } @@ -303,30 +528,40 @@ sp_auxlib::eigs_sym_arpack(Col& eigval, Mat& eigvec, const SpMat& X, template inline bool -sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const char* form_str, const T default_tol) +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_NEWARP) + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) { - const unwrap_spmat U(X.get_ref()); + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } - return sp_auxlib::eigs_gen_newarp(eigval, eigvec, U.M, n_eigvals, form_str, default_tol); + // TODO: investigate optional redirection of "sm" to ARPACK as it's capable of shift-invert; + // TODO: in shift-invert mode, "sm" maps to "lm" of the shift-inverted matrix (with sigma = 0) + + #if defined(ARMA_USE_NEWARP) + { + return sp_auxlib::eigs_gen_newarp(eigval, eigvec, U.M, n_eigvals, form_val, opts); } #elif defined(ARMA_USE_ARPACK) { - const unwrap_spmat U(X.get_ref()); - - return sp_auxlib::eigs_gen_arpack(eigval, eigvec, U.M, n_eigvals, form_str, default_tol); + constexpr std::complex sigma = T(0); + + return sp_auxlib::eigs_gen_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); } #else { arma_ignore(eigval); arma_ignore(eigvec); - arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(opts); arma_stop_logic_error("eigs_gen(): use of NEWARP or ARPACK must be enabled"); return false; @@ -336,23 +571,61 @@ sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigv +//! immediate eigendecomposition of non-symmetric real sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_gen_arpack(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of ARPACK and SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + template inline bool -sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const T default_tol) +sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_NEWARP) { - const form_type form_val = sp_auxlib::interpret_form_str(form_str); + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si), "eigs_gen(): unknown form specified" ); - arma_debug_check( (form_val == form_none), "eigs_gen(): unknown form specified" ); + if(X.is_square() == false) { return false; } const newarp::SparseGenMatProd op(X); - arma_debug_check( (op.n_rows != op.n_cols), "eigs_sym(): given matrix must be square sized" ); - arma_debug_check( (n_eigvals + 1 >= op.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); // If the matrix is empty, the case is trivial. @@ -364,12 +637,39 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex } uword n = op.n_rows; - uword ncv = n_eigvals + 2 + 1; - if(ncv < (2 * n_eigvals + 1)) { ncv = 2 * n_eigvals + 1; } - if(ncv > n) { ncv = n; } + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + uword ncv_default = uword( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + uword ncv = ncv_default; + + if(opts.subdim != 0) + { + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = uword(n_eigvals + 3); + } + else + if(opts.subdim > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = uword(opts.subdim); + } + } + + // Re-check that we are within the limits + if(ncv < (n_eigvals + 3)) { ncv = (n_eigvals + 3); } + if(ncv > n ) { ncv = n; } - T tol = (std::max)(default_tol, std::numeric_limits::epsilon()); + T tol = (std::max)(T(opts.tol), std::numeric_limits::epsilon()); + + uword maxiter = uword(opts.maxiter); // eigval.set_size(n_eigvals); // eigvec.set_size(n, n_eigvals); @@ -384,7 +684,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -393,7 +693,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_MAGN, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -402,7 +702,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_REAL, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -411,7 +711,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_REAL, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -420,7 +720,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::LARGEST_IMAG, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -429,7 +729,7 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex { newarp::GenEigsSolver< T, newarp::EigsSelect::SMALLEST_IMAG, newarp::SparseGenMatProd > eigs(op, n_eigvals, ncv); eigs.init(); - nconv = eigs.compute(1000, tol); + nconv = eigs.compute(maxiter, tol); eigval = eigs.eigenvalues(); eigvec = eigs.eigenvectors(); } @@ -452,8 +752,8 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex arma_ignore(eigvec); arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(opts); return false; } @@ -463,18 +763,18 @@ sp_auxlib::eigs_gen_newarp(Col< std::complex >& eigval, Mat< std::complex -template +template inline bool -sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const char* form_str, const T default_tol) +sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_ARPACK) { - const form_type form_val = sp_auxlib::interpret_form_str(form_str); + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si) && (form_val != form_sigma), "eigs_gen(): unknown form specified" ); - arma_debug_check( (form_val == form_none), "eigs_gen(): unknown form specified" ); + if(X.is_square() == false) { return false; } char which_lm[3] = "LM"; char which_sm[3] = "SM"; @@ -497,10 +797,6 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex default: which = which_lm; } - - // Make sure it's square. - arma_debug_check( (X.n_rows != X.n_cols), "eigs_gen(): given matrix must be square sized" ); - // Make sure we aren't asking for every eigenvalue. arma_debug_check( (n_eigvals + 1 >= X.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); @@ -513,63 +809,106 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex } // Set up variables that get used for neupd(). - blas_int n, ncv, ldv, lworkl, info; - T tol = default_tol; + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + T tol = T(opts.tol); + maxiter = blas_int(opts.maxiter); + podarray resid, v, workd, workl; podarray iparam, ipntr; podarray rwork; // Not used in the real case. - run_aupd(n_eigvals, which, X, false /* gen, not sym */, n, tol, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + n = blas_int(X.n_rows); // The size of the matrix. - if(info != 0) + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); + + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) { - return false; + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 3); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } } - + + // WARNING!!! + // We are still not able to apply truly complex shifts to real matrices, + // in which case the OP that ARPACK wants is different (see [s/d]naupd). + // Also, if sigma contains a non-zero imaginary part, retrieving the eigenvalues + // becomes utterly messy (see [s/d]eupd, remark #3). + // We should never get to the point in which the imaginary part of sigma is non-zero; + // the user-facing functions currently convert X from real to complex if a complex sigma is detected. + // The check here is just for extra safety, and as a reminder of what's missing. + T sigmar = real(sigma); + T sigmai = imag(sigma); + + if(use_sigma) + //if(form_val == form_sigma) + { + if(sigmai != T(0)) { arma_stop_logic_error("eigs_gen(): complex 'sigma' not applicable to real matrix"); return false; } + + run_aupd_shiftinvert(n_eigvals, sigmar, X, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + + if(info != 0) { return false; } + // The process has converged, and now we need to recover the actual eigenvectors using neupd(). blas_int rvec = 1; // .TRUE - blas_int nev = n_eigvals; + blas_int nev = blas_int(n_eigvals); char howmny = 'A'; char bmat = 'I'; // We are considering the standard eigenvalue problem. - podarray select(ncv); // Logical array of dimension NCV. - podarray dr(nev + 1); // Real array of dimension NEV + 1. - podarray di(nev + 1); // Real array of dimension NEV + 1. - podarray z(n * (nev + 1)); // Real N by NEV array if HOWMNY = 'A'. - blas_int ldz = n; - podarray workev(3 * ncv); + podarray select(ncv, arma_zeros_indicator()); // logical array of dimension NCV + podarray dr(nev + 1, arma_zeros_indicator()); // real array of dimension NEV + 1 + podarray di(nev + 1, arma_zeros_indicator()); // real array of dimension NEV + 1 + podarray z(n * (nev + 1), arma_zeros_indicator()); // real N by NEV array if HOWMNY = 'A' + podarray workev(3 * ncv, arma_zeros_indicator()); - dr.zeros(); - di.zeros(); - z.zeros(); + blas_int ldz = n; - arpack::neupd(&rvec, &howmny, select.memptr(), dr.memptr(), di.memptr(), z.memptr(), &ldz, (T*) NULL, (T*) NULL, workev.memptr(), &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + arpack::neupd(&rvec, &howmny, select.memptr(), dr.memptr(), di.memptr(), z.memptr(), &ldz, (T*) &sigmar, (T*) &sigmai, workev.memptr(), &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); // Check for errors. - if(info != 0) - { - arma_debug_warn("eigs_gen(): ARPACK error ", info, " in neupd()"); - return false; - } + if(info != 0) { arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in neupd()"); return false; } // Put it into the outputs. eigval.set_size(n_eigvals); eigvec.zeros(n, n_eigvals); - for (uword i = 0; i < n_eigvals; ++i) + for(uword i = 0; i < n_eigvals; ++i) { eigval[i] = std::complex(dr[i], di[i]); } // Now recover the eigenvectors. - for (uword i = 0; i < n_eigvals; ++i) + for(uword i = 0; i < n_eigvals; ++i) { // ARPACK ?neupd lays things out kinda odd in memory; // so does LAPACK ?geev -- see auxlib::eig_gen() if((i < n_eigvals - 1) && (eigval[i] == std::conj(eigval[i + 1]))) { - for (uword j = 0; j < uword(n); ++j) + for(uword j = 0; j < uword(n); ++j) { eigvec.at(j, i) = std::complex(z[n * i + j], z[n * (i + 1) + j]); eigvec.at(j, i + 1) = std::complex(z[n * i + j], -z[n * (i + 1) + j]); @@ -580,7 +919,7 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex if((i == n_eigvals - 1) && (std::complex(eigval[i]).imag() != 0.0)) { // We don't have the matched conjugate eigenvalue. - for (uword j = 0; j < uword(n); ++j) + for(uword j = 0; j < uword(n); ++j) { eigvec.at(j, i) = std::complex(z[n * i + j], z[n * (i + 1) + j]); } @@ -588,7 +927,7 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex else { // The eigenvector is entirely real. - for (uword j = 0; j < uword(n); ++j) + for(uword j = 0; j < uword(n); ++j) { eigvec.at(j, i) = std::complex(z[n * i + j], T(0)); } @@ -603,8 +942,9 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex arma_ignore(eigvec); arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); return false; } @@ -617,17 +957,81 @@ sp_auxlib::eigs_gen_arpack(Col< std::complex >& eigval, Mat< std::complex template inline bool -sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X_expr, const uword n_eigvals, const char* form_str, const T default_tol) +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X_expr, const uword n_eigvals, const form_type form_val, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X_expr.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + constexpr std::complex sigma = T(0); + + return sp_auxlib::eigs_gen(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + + + +//! immediate eigendecomposition of non-symmetric complex sparse object +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpBase< std::complex, T1>& X, const uword n_eigvals, const std::complex sigma, const eigs_opts& opts) + { + arma_extra_debug_sigprint(); + + const unwrap_spmat U(X.get_ref()); + + arma_debug_check( (U.M.is_square() == false), "eigs_gen(): given matrix must be square sized" ); + + if(arma_config::check_nonfinite && U.M.internal_has_nonfinite()) + { + arma_debug_warn_level(3, "eigs_gen(): detected non-finite elements"); + return false; + } + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + constexpr form_type form_val = form_sigma; + + return sp_auxlib::eigs_gen(eigval, eigvec, U.M, n_eigvals, form_val, sigma, opts); + } + #else + { + arma_ignore(eigval); + arma_ignore(eigvec); + arma_ignore(n_eigvals); + arma_ignore(sigma); + arma_ignore(opts); + + arma_stop_logic_error("eigs_gen(): use of ARPACK and SuperLU must be enabled to use 'sigma'"); + return false; + } + #endif + } + + + +template +inline +bool +sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigvec, const SpMat< std::complex >& X, const uword n_eigvals, const form_type form_val, const std::complex sigma, const eigs_opts& opts) { arma_extra_debug_sigprint(); #if defined(ARMA_USE_ARPACK) { - typedef typename std::complex eT; + // typedef typename std::complex eT; - const form_type form_val = sp_auxlib::interpret_form_str(form_str); + arma_debug_check( (form_val != form_lm) && (form_val != form_sm) && (form_val != form_lr) && (form_val != form_sr) && (form_val != form_li) && (form_val != form_si) && (form_val != form_sigma), "eigs_gen(): unknown form specified" ); - arma_debug_check( (form_val == form_none), "eigs_gen(): unknown form specified" ); + if(X.is_square() == false) { return false; } char which_lm[3] = "LM"; char which_sm[3] = "SM"; @@ -650,13 +1054,6 @@ sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigv default: which = which_lm; } - const unwrap_spmat U(X_expr.get_ref()); - - const SpMat& X = U.M; - - // Make sure it's square. - arma_debug_check( (X.n_rows != X.n_cols), "eigs_gen(): given matrix must be square sized" ); - // Make sure we aren't asking for every eigenvalue. arma_debug_check( (n_eigvals + 1 >= X.n_rows), "eigs_gen(): n_eigvals + 1 must be less than the number of rows in the matrix" ); @@ -669,46 +1066,79 @@ sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigv } // Set up variables that get used for neupd(). - blas_int n, ncv, ldv, lworkl, info; - T tol = default_tol; + blas_int n, ncv, ncv_default, ldv, lworkl, info, maxiter; + + T tol = T(opts.tol); + maxiter = blas_int(opts.maxiter); + podarray< std::complex > resid, v, workd, workl; podarray iparam, ipntr; podarray rwork; - run_aupd(n_eigvals, which, X, false /* gen, not sym */, n, tol, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + n = blas_int(X.n_rows); // The size of the matrix. + + // Use max(2*k+1, 20) as default subspace dimension for the gen case; same as MATLAB. + ncv_default = blas_int( ((2*n_eigvals+1)>(20)) ? (2*n_eigvals+1) : (20) ); - if(info != 0) + // Use opts.subdim only if it's within the limits + ncv = ncv_default; + + if(opts.subdim != 0) { - return false; + if(opts.subdim < (n_eigvals + 3)) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim must be greater than k+2; using k+3 instead of ", opts.subdim); + ncv = blas_int(n_eigvals + 3); + } + else + if(blas_int(opts.subdim) > n) + { + arma_debug_warn_level(1, "eigs_gen(): opts.subdim cannot be greater than n_rows; using n_rows instead of ", opts.subdim); + ncv = n; + } + else + { + ncv = blas_int(opts.subdim); + } + } + + if(use_sigma) + //if(form_val == form_sigma) + { + run_aupd_shiftinvert(n_eigvals, sigma, X, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); + } + else + { + const SpMat< std::complex > Xst = X.st(); + + run_aupd_plain(n_eigvals, which, X, Xst, false /* gen, not sym */, n, tol, maxiter, resid, ncv, v, ldv, iparam, ipntr, workd, workl, lworkl, rwork, info); } + if(info != 0) { return false; } + // The process has converged, and now we need to recover the actual eigenvectors using neupd(). blas_int rvec = 1; // .TRUE - blas_int nev = n_eigvals; + blas_int nev = blas_int(n_eigvals); char howmny = 'A'; char bmat = 'I'; // We are considering the standard eigenvalue problem. - podarray select(ncv); // Logical array of dimension NCV. - podarray > d(nev + 1); // Real array of dimension NEV + 1. - podarray > z(n * nev); // Real N by NEV array if HOWMNY = 'A'. + podarray select(ncv, arma_zeros_indicator()); // logical array of dimension NCV + podarray> d(nev + 1, arma_zeros_indicator()); // complex array of dimension NEV + 1 + podarray> z(n * nev, arma_zeros_indicator()); // complex N by NEV array if HOWMNY = 'A' + podarray> workev(2 * ncv, arma_zeros_indicator()); + blas_int ldz = n; - podarray > workev(2 * ncv); // Prepare the outputs; neupd() will write directly to them. eigval.zeros(n_eigvals); eigvec.zeros(n, n_eigvals); - std::complex sigma; arpack::neupd(&rvec, &howmny, select.memptr(), eigval.memptr(), (std::complex*) NULL, eigvec.memptr(), &ldz, (std::complex*) &sigma, (std::complex*) NULL, workev.memptr(), &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); // Check for errors. - if(info != 0) - { - arma_debug_warn("eigs_gen(): ARPACK error ", info, " in neupd()"); - return false; - } + if(info != 0) { arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in neupd()"); return false; } return (info == 0); } @@ -716,10 +1146,11 @@ sp_auxlib::eigs_gen(Col< std::complex >& eigval, Mat< std::complex >& eigv { arma_ignore(eigval); arma_ignore(eigvec); - arma_ignore(X_expr); + arma_ignore(X); arma_ignore(n_eigvals); - arma_ignore(form_str); - arma_ignore(default_tol); + arma_ignore(form_val); + arma_ignore(sigma); + arma_ignore(opts); arma_stop_logic_error("eigs_gen(): use of ARPACK must be enabled for decomposition of complex matrices"); return false; @@ -748,20 +1179,14 @@ sp_auxlib::spsolve_simple(Mat& X, const SpBase A.n_cols) - { - arma_stop_logic_error("spsolve(): solving over-determined systems currently not supported"); - X.soft_reset(); - return false; - } - else if(A.n_rows < A.n_cols) + if(A.is_square() == false) { - arma_stop_logic_error("spsolve(): solving under-determined systems currently not supported"); X.soft_reset(); + arma_stop_logic_error("spsolve(): solving under-determined / over-determined systems is currently not supported"); return false; } - arma_debug_check( (A.n_rows != X.n_rows), "spsolve(): number of rows in the given objects must be the same" ); + arma_debug_check( (A.n_rows != X.n_rows), "spsolve(): number of rows in the given objects must be the same", [&](){ X.soft_reset(); } ); if(A.is_empty() || X.is_empty()) { @@ -771,9 +1196,15 @@ sp_auxlib::spsolve_simple(Mat& X, const SpBase INT_MAX); overflow = (A.n_rows > INT_MAX) || overflow; @@ -788,41 +1219,28 @@ sp_auxlib::spsolve_simple(Mat& X, const SpBase(&x), char(0), sizeof(superlu::SuperMatrix)); - superlu::SuperMatrix a; arrayops::inplace_set(reinterpret_cast(&a), char(0), sizeof(superlu::SuperMatrix)); + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler a; - const bool status_x = wrap_to_supermatrix(x, X); - const bool status_a = copy_to_supermatrix(a, A); + const bool status_x = wrap_to_supermatrix(x.get_ref(), X); + const bool status_a = copy_to_supermatrix(a.get_ref(), A); - if( (status_x == false) || (status_a == false) ) - { - destroy_supermatrix(a); - destroy_supermatrix(x); - X.soft_reset(); - return false; - } + if( (status_x == false) || (status_a == false) ) { X.soft_reset(); return false; } - superlu::SuperMatrix l; arrayops::inplace_set(reinterpret_cast(&l), char(0), sizeof(superlu::SuperMatrix)); - superlu::SuperMatrix u; arrayops::inplace_set(reinterpret_cast(&u), char(0), sizeof(superlu::SuperMatrix)); + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; // paranoia: use SuperLU's memory allocation, in case it reallocs - int* perm_c = (int*) superlu::malloc( (A.n_cols+1) * sizeof(int)); // extra paranoia: increase array length by 1 - int* perm_r = (int*) superlu::malloc( (A.n_rows+1) * sizeof(int)); + superlu_array_wrangler perm_c(A.n_cols+1); // extra paranoia: increase array length by 1 + superlu_array_wrangler perm_r(A.n_rows+1); - arma_check_bad_alloc( (perm_c == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (perm_r == 0), "spsolve(): out of memory" ); - - arrayops::inplace_set(perm_c, 0, A.n_cols+1); - arrayops::inplace_set(perm_r, 0, A.n_rows+1); - - superlu::SuperLUStat_t stat; - superlu::init_stat(&stat); + superlu_stat_wrangler stat; int info = 0; // Return code. arma_extra_debug_print("superlu::gssv()"); - superlu::gssv(&options, &a, perm_c, perm_r, &l, &u, &x, &stat, &info); + superlu::gssv(&options, a.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), x.get_ptr(), stat.get_ptr(), &info); // Process the return code. @@ -830,29 +1248,20 @@ sp_auxlib::spsolve_simple(Mat& X, const SpBase int(A.n_cols)) { - arma_debug_warn("spsolve(): memory allocation failure: could not allocate ", (info - int(A.n_cols)), " bytes"); + arma_debug_warn_level(1, "spsolve(): memory allocation failure"); } else if(info < 0) { - arma_debug_warn("spsolve(): unknown SuperLU error code from gssv(): ", info); + arma_debug_warn_level(1, "spsolve(): unknown SuperLU error code from gssv(): ", info); } - - superlu::free_stat(&stat); - - superlu::free(perm_c); - superlu::free(perm_r); - - destroy_supermatrix(u); - destroy_supermatrix(l); - destroy_supermatrix(a); - destroy_supermatrix(x); // No need to extract the data from x, since it's using the same memory as X + // No need to extract the data from x, since it's using the same memory as X return (info == 0); } @@ -897,30 +1306,27 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& const Mat& B = (B_is_modified) ? B_copy : B_unwrap; - if(A.n_rows > A.n_cols) - { - arma_stop_logic_error("spsolve(): solving over-determined systems currently not supported"); - X.soft_reset(); - return false; - } - else if(A.n_rows < A.n_cols) + if(A.is_square() == false) { - arma_stop_logic_error("spsolve(): solving under-determined systems currently not supported"); X.soft_reset(); + arma_stop_logic_error("spsolve(): solving under-determined / over-determined systems is currently not supported"); return false; } - arma_debug_check( (A.n_rows != B.n_rows), "spsolve(): number of rows in the given objects must be the same" ); + arma_debug_check( (A.n_rows != B.n_rows), "spsolve(): number of rows in the given objects must be the same", [&](){ X.soft_reset(); } ); X.zeros(A.n_cols, B.n_cols); // set the elements to zero, as we don't trust the SuperLU spaghetti code - if(A.is_empty() || B.is_empty()) - { - return true; - } + if(A.is_empty() || B.is_empty()) { return true; } if(A.n_nonzero == uword(0)) { X.soft_reset(); return false; } + if(arma_config::check_nonfinite && (A.internal_has_nonfinite() || B.internal_has_nonfinite())) + { + arma_debug_warn_level(3, "spsolve(): detected non-finite elements"); + return false; + } + if(arma_config::debug) { bool overflow; @@ -940,74 +1346,48 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& } } - superlu::SuperMatrix x; arrayops::inplace_set(reinterpret_cast(&x), char(0), sizeof(superlu::SuperMatrix)); - superlu::SuperMatrix a; arrayops::inplace_set(reinterpret_cast(&a), char(0), sizeof(superlu::SuperMatrix)); - superlu::SuperMatrix b; arrayops::inplace_set(reinterpret_cast(&b), char(0), sizeof(superlu::SuperMatrix)); + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler a; + superlu_supermatrix_wrangler b; - const bool status_x = wrap_to_supermatrix(x, X); - const bool status_a = copy_to_supermatrix(a, A); // NOTE: superlu::gssvx() modifies 'a' if equilibration is enabled - const bool status_b = wrap_to_supermatrix(b, B); // NOTE: superlu::gssvx() modifies 'b' if equilibration is enabled + const bool status_x = wrap_to_supermatrix(x.get_ref(), X); + const bool status_a = copy_to_supermatrix(a.get_ref(), A); // NOTE: superlu::gssvx() modifies 'a' if equilibration is enabled + const bool status_b = wrap_to_supermatrix(b.get_ref(), B); // NOTE: superlu::gssvx() modifies 'b' if equilibration is enabled - if( (status_x == false) || (status_a == false) || (status_b == false) ) - { - destroy_supermatrix(x); - destroy_supermatrix(a); - destroy_supermatrix(b); - X.soft_reset(); - return false; - } + if( (status_x == false) || (status_a == false) || (status_b == false) ) { X.soft_reset(); return false; } - superlu::SuperMatrix l; arrayops::inplace_set(reinterpret_cast(&l), char(0), sizeof(superlu::SuperMatrix)); - superlu::SuperMatrix u; arrayops::inplace_set(reinterpret_cast(&u), char(0), sizeof(superlu::SuperMatrix)); + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; // paranoia: use SuperLU's memory allocation, in case it reallocs - int* perm_c = (int*) superlu::malloc( (A.n_cols+1) * sizeof(int) ); // extra paranoia: increase array length by 1 - int* perm_r = (int*) superlu::malloc( (A.n_rows+1) * sizeof(int) ); - int* etree = (int*) superlu::malloc( (A.n_cols+1) * sizeof(int) ); - - T* R = (T*) superlu::malloc( (A.n_rows+1) * sizeof(T) ); - T* C = (T*) superlu::malloc( (A.n_cols+1) * sizeof(T) ); - T* ferr = (T*) superlu::malloc( (B.n_cols+1) * sizeof(T) ); - T* berr = (T*) superlu::malloc( (B.n_cols+1) * sizeof(T) ); - - arma_check_bad_alloc( (perm_c == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (perm_r == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (etree == 0), "spsolve(): out of memory" ); - - arma_check_bad_alloc( (R == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (C == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (ferr == 0), "spsolve(): out of memory" ); - arma_check_bad_alloc( (berr == 0), "spsolve(): out of memory" ); - - arrayops::inplace_set(perm_c, int(0), A.n_cols+1); - arrayops::inplace_set(perm_r, int(0), A.n_rows+1); - arrayops::inplace_set(etree, int(0), A.n_cols+1); + superlu_array_wrangler perm_c(A.n_cols+1); // extra paranoia: increase array length by 1 + superlu_array_wrangler perm_r(A.n_rows+1); + superlu_array_wrangler etree(A.n_cols+1); - arrayops::inplace_set(R, T(0), A.n_rows+1); - arrayops::inplace_set(C, T(0), A.n_cols+1); - arrayops::inplace_set(ferr, T(0), B.n_cols+1); - arrayops::inplace_set(berr, T(0), B.n_cols+1); + superlu_array_wrangler R(A.n_rows+1); + superlu_array_wrangler C(A.n_cols+1); + superlu_array_wrangler ferr(B.n_cols+1); + superlu_array_wrangler berr(B.n_cols+1); superlu::GlobalLU_t glu; - arrayops::inplace_set(reinterpret_cast(&glu), char(0), sizeof(superlu::GlobalLU_t)); + arrayops::fill_zeros(reinterpret_cast(&glu), sizeof(superlu::GlobalLU_t)); superlu::mem_usage_t mu; - arrayops::inplace_set(reinterpret_cast(&mu), char(0), sizeof(superlu::mem_usage_t)); + arrayops::fill_zeros(reinterpret_cast(&mu), sizeof(superlu::mem_usage_t)); - superlu::SuperLUStat_t stat; - superlu::init_stat(&stat); + superlu_stat_wrangler stat; - char equed[8]; // extra characters for paranoia - T rpg = T(0); - T rcond = T(0); - int info = int(0); // Return code. + char equed[8] = {}; // extra characters for paranoia + T rpg = T(0); + T rcond = T(0); + int info = int(0); // Return code. - char work[8]; - int lwork = int(0); // 0 means superlu will allocate memory + char work[8] = {}; + int lwork = int(0); // 0 means superlu will allocate memory arma_extra_debug_print("superlu::gssvx()"); - superlu::gssvx(&options, &a, perm_c, perm_r, etree, equed, R, C, &l, &u, &work[0], lwork, &b, &x, &rpg, &rcond, ferr, berr, &glu, &mu, &stat, &info); + superlu::gssvx(&options, a.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), etree.get_ptr(), equed, R.get_ptr(), C.get_ptr(), l.get_ptr(), u.get_ptr(), &work[0], lwork, b.get_ptr(), x.get_ptr(), &rpg, &rcond, ferr.get_ptr(), berr.get_ptr(), &glu, &mu, stat.get_ptr(), &info); bool status = false; @@ -1020,40 +1400,26 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& { // std::ostringstream tmp; // tmp << "spsolve(): could not solve system; LU factorisation completed, but detected zero in U(" << (info-1) << ',' << (info-1) << ')'; - // arma_debug_warn(tmp.str()); + // arma_debug_warn_level(1, tmp.str()); } else if( (info == int(A.n_cols+1)) && (user_opts.allow_ugly) ) { - arma_debug_warn("spsolve(): system is singular to working precision (rcond: ", rcond, ")"); + arma_debug_warn_level(2, "spsolve(): system is singular to working precision (rcond: ", rcond, ")"); status = true; } else if(info > int(A.n_cols+1)) { - arma_debug_warn("spsolve(): memory allocation failure: could not allocate ", (info - int(A.n_cols)), " bytes"); + arma_debug_warn_level(1, "spsolve(): memory allocation failure"); } else if(info < 0) { - arma_debug_warn("spsolve(): unknown SuperLU error code from gssvx(): ", info); + arma_debug_warn_level(1, "spsolve(): unknown SuperLU error code from gssvx(): ", info); } - superlu::free_stat(&stat); - - superlu::free(berr); - superlu::free(ferr); - superlu::free(C); - superlu::free(R); - superlu::free(etree); - superlu::free(perm_r); - superlu::free(perm_c); - - destroy_supermatrix(u); - destroy_supermatrix(l); - destroy_supermatrix(b); - destroy_supermatrix(a); - destroy_supermatrix(x); // No need to extract the data from x, since it's using the same memory as X + // No need to extract the data from x, since it's using the same memory as X out_rcond = rcond; @@ -1076,6 +1442,44 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& #if defined(ARMA_USE_SUPERLU) + template + inline + typename get_pod_type::result + sp_auxlib::norm1(superlu::SuperMatrix* A) + { + arma_extra_debug_sigprint(); + + char norm_id = '1'; + + arma_extra_debug_print("superlu::langs()"); + return superlu::langs(&norm_id, A); + } + + + + template + inline + typename get_pod_type::result + sp_auxlib::lu_rcond(superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result norm_val) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + char norm_id = '1'; + T rcond_out = T(0); + int info = int(0); + + superlu_stat_wrangler stat; + + arma_extra_debug_print("superlu::gscon()"); + superlu::gscon(&norm_id, L, U, norm_val, &rcond_out, stat.get_ptr(), &info); + + return (info == 0) ? T(rcond_out) : T(0); + } + + + inline void sp_auxlib::set_superlu_opts(superlu::superlu_options_t& options, const superlu_opts& user_opts) @@ -1122,55 +1526,42 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& // We store in column-major CSC. out.Stype = superlu::SLU_NC; - if(is_float::value) - { - out.Dtype = superlu::SLU_S; - } - else - if(is_double::value) - { - out.Dtype = superlu::SLU_D; - } - else - if(is_cx_float::value) - { - out.Dtype = superlu::SLU_C; - } - else - if(is_cx_double::value) - { - out.Dtype = superlu::SLU_Z; - } + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } out.Mtype = superlu::SLU_GE; // Just a general matrix. We don't know more now. // We have to actually create the object which stores the data. // This gets cleaned by destroy_supermatrix(). - // We have to use SuperLU's stupid memory allocation routines since they are + // We have to use SuperLU's problematic memory allocation routines since they are // not guaranteed to be new and delete. See the comments in def_superlu.hpp superlu::NCformat* nc = (superlu::NCformat*)superlu::malloc(sizeof(superlu::NCformat)); - if(nc == NULL) { return false; } + if(nc == nullptr) { return false; } + + A.sync(); nc->nnz = A.n_nonzero; nc->nzval = (void*) superlu::malloc(sizeof(eT) * A.n_nonzero ); nc->colptr = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * (A.n_cols + 1)); nc->rowind = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * A.n_nonzero ); - if( (nc->nzval == NULL) || (nc->colptr == NULL) || (nc->rowind == NULL) ) { return false; } + if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { return false; } // Fill the matrix. arrayops::copy((eT*) nc->nzval, A.values, A.n_nonzero); // // These have to be copied by hand, because the types may differ. - // for (uword i = 0; i <= A.n_cols; ++i) { nc->colptr[i] = (int_t) A.col_ptrs[i]; } - // for (uword i = 0; i < A.n_nonzero; ++i) { nc->rowind[i] = (int_t) A.row_indices[i]; } + // for(uword i = 0; i <= A.n_cols; ++i) { nc->colptr[i] = (int_t) A.col_ptrs[i]; } + // for(uword i = 0; i < A.n_nonzero; ++i) { nc->rowind[i] = (int_t) A.row_indices[i]; } arrayops::convert(nc->colptr, A.col_ptrs, A.n_cols+1 ); arrayops::convert(nc->rowind, A.row_indices, A.n_nonzero); - out.nrow = A.n_rows; - out.ncol = A.n_cols; + out.nrow = superlu::int_t(A.n_rows); + out.ncol = superlu::int_t(A.n_cols); out.Store = (void*) nc; return true; @@ -1178,44 +1569,255 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& + // memory efficient implementation of out = A - shift*I, where A is a square matrix template inline bool - sp_auxlib::wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat& A) + sp_auxlib::copy_to_supermatrix_with_shift(superlu::SuperMatrix& out, const SpMat& A, const eT shift) { arma_extra_debug_sigprint(); - // NOTE: this function re-uses memory from matrix A - - // This is being stored as a dense matrix. - out.Stype = superlu::SLU_DN; + arma_debug_check( (A.is_square() == false), "sp_auxlib::copy_to_supermatrix_with_shift(): given matrix must be square sized" ); - if(is_float::value) + if(shift == eT(0)) { - out.Dtype = superlu::SLU_S; + arma_extra_debug_print("sp_auxlib::copy_to_supermatrix_with_shift(): shift is zero; redirecting to sp_auxlib::copy_to_supermatrix()"); + return sp_auxlib::copy_to_supermatrix(out, A); } - else - if(is_double::value) - { - out.Dtype = superlu::SLU_D; - } - else - if(is_cx_float::value) + + // We store in column-major CSC. + out.Stype = superlu::SLU_NC; + + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } + + out.Mtype = superlu::SLU_GE; // Just a general matrix. We don't know more now. + + // We have to actually create the object which stores the data. + // This gets cleaned by destroy_supermatrix(). + superlu::NCformat* nc = (superlu::NCformat*)superlu::malloc(sizeof(superlu::NCformat)); + + if(nc == nullptr) { return false; } + + A.sync(); + + uword n_nonzero_diag_old = 0; + uword n_nonzero_diag_new = 0; + + const uword n_search_cols = (std::min)(A.n_rows, A.n_cols); + + for(uword j=0; j < n_search_cols; ++j) { - out.Dtype = superlu::SLU_C; + const uword col_offset = A.col_ptrs[j ]; + const uword next_col_offset = A.col_ptrs[j + 1]; + + const uword* start_ptr = &(A.row_indices[ col_offset]); + const uword* end_ptr = &(A.row_indices[next_col_offset]); + + const uword wanted_row = j; + + const uword* pos_ptr = std::lower_bound(start_ptr, end_ptr, wanted_row); // binary search + + if( (pos_ptr != end_ptr) && ((*pos_ptr) == wanted_row) ) + { + // element on the main diagonal is non-zero + ++n_nonzero_diag_old; + + const uword offset = uword(pos_ptr - start_ptr); + const uword index = offset + col_offset; + + const eT new_val = A.values[index] - shift; + + if(new_val != eT(0)) { ++n_nonzero_diag_new; } + } + else + { + // element on the main diagonal is zero, but sigma is non-zero, + // so the number of new non-zero elments on the diagonal is increased + ++n_nonzero_diag_new; + } } - else - if(is_cx_double::value) + + const uword out_n_nonzero = A.n_nonzero - n_nonzero_diag_old + n_nonzero_diag_new; + + arma_extra_debug_print( arma_str::format("A.n_nonzero: %u") % A.n_nonzero ); + arma_extra_debug_print( arma_str::format("n_nonzero_diag_old: %u") % n_nonzero_diag_old ); + arma_extra_debug_print( arma_str::format("n_nonzero_diag_new: %u") % n_nonzero_diag_new ); + arma_extra_debug_print( arma_str::format("out_n_nonzero: %u") % out_n_nonzero ); + + nc->nnz = out_n_nonzero; + nc->nzval = (void*) superlu::malloc(sizeof(eT) * out_n_nonzero ); + nc->colptr = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * (A.n_cols + 1)); + nc->rowind = (superlu::int_t*)superlu::malloc(sizeof(superlu::int_t) * out_n_nonzero ); + + if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { return false; } + + // fill the matrix column by column, and insert diagonal elements when necessary + + nc->colptr[0] = 0; + + eT* values_current = (eT*) nc->nzval; + superlu::int_t* rowind_current = nc->rowind; + + uword count = 0; + + for(uword j=0; j < A.n_cols; ++j) { - out.Dtype = superlu::SLU_Z; + const uword idx_start = A.col_ptrs[j ]; + const uword idx_end = A.col_ptrs[j + 1]; + + const eT* values_start = values_current; + + uword i = idx_start; + + // elements in the upper triangular part, excluding the main diagonal + for(; (i < idx_end) && (A.row_indices[i] < j); ++i) + { + (*values_current) = A.values[i]; + (*rowind_current) = superlu::int_t(A.row_indices[i]); + + ++values_current; + ++rowind_current; + + ++count; + } + + // elements on the main diagonal + if( (i < idx_end) && (A.row_indices[i] == j) ) + { + // A(j,j) is non-zero + + const eT new_diag_val = A.values[i] - shift; + + if(new_diag_val != eT(0)) + { + (*values_current) = new_diag_val; + (*rowind_current) = superlu::int_t(j); + + ++values_current; + ++rowind_current; + + ++count; + } + + ++i; + } + else + { + // A(j,j) is zero, so insert a new element + + if(j < n_search_cols) + { + (*values_current) = -shift; + (*rowind_current) = superlu::int_t(j); + + ++values_current; + ++rowind_current; + + ++count; + } + } + + // elements in the lower triangular part, excluding the main diagonal + for(; i < idx_end; ++i) + { + (*values_current) = A.values[i]; + (*rowind_current) = superlu::int_t(A.row_indices[i]); + + ++values_current; + ++rowind_current; + + ++count; + } + + // number of non-zero elements in the j-th column of out + const uword nnz_col = values_current - values_start; + nc->colptr[j + 1] = superlu::int_t(nc->colptr[j] + nnz_col); } + arma_extra_debug_print( arma_str::format("count: %u") % count ); + + arma_check( (count != out_n_nonzero), "internal error: sp_auxlib::copy_to_supermatrix_with_shift(): count != out_n_nonzero" ); + + out.nrow = superlu::int_t(A.n_rows); + out.ncol = superlu::int_t(A.n_cols); + out.Store = (void*) nc; + + return true; + } + + + +// // for debugging only +// template +// inline +// void +// sp_auxlib::copy_to_spmat(SpMat& out, const superlu::SuperMatrix& A) +// { +// arma_extra_debug_sigprint(); +// +// bool type_matched = false; +// +// if( is_float::value) { type_matched = (A.Dtype == superlu::SLU_S); } +// else if( is_double::value) { type_matched = (A.Dtype == superlu::SLU_D); } +// else if( is_cx_float::value) { type_matched = (A.Dtype == superlu::SLU_C); } +// else if(is_cx_double::value) { type_matched = (A.Dtype == superlu::SLU_Z); } +// +// arma_debug_check( (type_matched == false), "copy_to_spmat(): type mismatch" ); +// arma_debug_check( (A.Mtype != superlu::SLU_GE), "copy_to_spmat(): unknown layout" ); +// +// // NOTE: the l and u instances of SuperMatrix resulting from superlu::gstrf() +// // NOTE: do not have the superlu::SLU_GE layout +// +// const superlu::NCformat* nc = (const superlu::NCformat*)(A.Store); +// +// if(nc == nullptr) { out.reset(); return; } +// +// if( (nc->nzval == nullptr) || (nc->colptr == nullptr) || (nc->rowind == nullptr) ) { out.reset(); return; } +// +// const uword A_n_rows = uword(A.nrow ); +// const uword A_n_cols = uword(A.ncol ); +// const uword A_n_nonzero = uword(nc->nnz); +// +// if(A_n_nonzero == 0) { out.zeros(A_n_rows, A_n_cols); return; } +// +// out.reserve(A_n_rows, A_n_cols, A_n_nonzero); +// +// arrayops::copy(access::rwp(out.values), (const eT*)(nc->nzval), A_n_nonzero); +// +// arrayops::convert(access::rwp(out.col_ptrs), nc->colptr, A_n_cols+1 ); +// arrayops::convert(access::rwp(out.row_indices), nc->rowind, A_n_nonzero); +// +// out.remove_zeros(); // in case SuperLU has bugs and stores zeros in sparse matrices +// } + + + + template + inline + bool + sp_auxlib::wrap_to_supermatrix(superlu::SuperMatrix& out, const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: this function re-uses memory from matrix A + + // This is being stored as a dense matrix. + out.Stype = superlu::SLU_DN; + + if( is_float::value) { out.Dtype = superlu::SLU_S; } + else if( is_double::value) { out.Dtype = superlu::SLU_D; } + else if( is_cx_float::value) { out.Dtype = superlu::SLU_C; } + else if(is_cx_double::value) { out.Dtype = superlu::SLU_Z; } + out.Mtype = superlu::SLU_GE; // We have to create the object that stores the data. superlu::DNformat* dn = (superlu::DNformat*)superlu::malloc(sizeof(superlu::DNformat)); - if(dn == NULL) { return false; } + if(dn == nullptr) { return false; } dn->lda = A.n_rows; dn->nzval = (void*) A.memptr(); // re-use memory instead of copying @@ -1241,6 +1843,11 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& superlu::destroy_compcol_mat(&out); } else + if(out.Stype == superlu::SLU_NCP) + { + superlu::destroy_compcolperm_mat(&out); + } + else if(out.Stype == superlu::SLU_DN) { // superlu::destroy_dense_mat(&out); @@ -1251,7 +1858,7 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& superlu::DNformat* dn = (superlu::DNformat*) out.Store; - if(dn != NULL) { superlu::free(dn); } + if(dn != nullptr) { superlu::free(dn); } } else if(out.Stype == superlu::SLU_SC) @@ -1277,8 +1884,8 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& if(out.Stype == superlu::SLU_DN) { tmp << "SLU_DN"; } if(out.Stype == superlu::SLU_NR_loc) { tmp << "SLU_NR_loc"; } - arma_debug_warn(tmp.str()); - arma_stop_runtime_error("sp_auxlib::destroy_supermatrix(): internal error"); + arma_debug_warn_level(1, tmp.str()); + arma_stop_runtime_error("internal error: sp_auxlib::destroy_supermatrix()"); } } @@ -1289,10 +1896,11 @@ sp_auxlib::spsolve_refine(Mat& X, typename T1::pod_type& template inline void -sp_auxlib::run_aupd +sp_auxlib::run_aupd_plain ( - const uword n_eigvals, char* which, const SpMat& X, const bool sym, - blas_int& n, eT& tol, + const uword n_eigvals, char* which, + const SpMat& X, const SpMat& Xst, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, podarray& iparam, podarray& ipntr, podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, @@ -1307,57 +1915,66 @@ sp_auxlib::run_aupd // return code what we need to do next (usually a matrix-vector product) and // then call it again. So this results in some type of iterative process // where we call saupd()/naupd() many times. + blas_int ido = 0; // This must be 0 for the first call. char bmat = 'I'; // We are considering the standard eigenvalue problem. - n = X.n_rows; // The size of the matrix. + n = X.n_rows; // The size of the matrix (should already be set outside). blas_int nev = n_eigvals; - resid.set_size(n); + // resid.zeros(n); + eigs_randu_filler randu_filler; + randu_filler.fill(resid, n); // use deterministic starting point - // Two contraints on NCV: (NCV > NEV + 2) and (NCV <= N) + // Two contraints on NCV: (NCV > NEV) for sym problems or + // (NCV > NEV + 2) for gen problems and (NCV <= N) // // We're calling either arpack::saupd() or arpack::naupd(), // which have slighly different minimum constraint and recommended value for NCV: // http://www.caam.rice.edu/software/ARPACK/UG/node136.html // http://www.caam.rice.edu/software/ARPACK/UG/node138.html - ncv = nev + 2 + 1; - - if (ncv < (2 * nev + 1)) { ncv = 2 * nev + 1; } - if (ncv > n ) { ncv = n; } + if(ncv < (nev + (sym ? 1 : 3))) { ncv = (nev + (sym ? 1 : 3)); } + if(ncv > n ) { ncv = n; } - v.set_size(n * ncv); // Array N by NCV (output). - rwork.set_size(ncv); // Work array of size NCV for complex calls. + v.zeros(n * ncv); // Array N by NCV (output). + rwork.zeros(ncv); // Work array of size NCV for complex calls. ldv = n; // "Leading dimension of V exactly as declared in the calling program." // IPARAM: integer array of length 11. iparam.zeros(11); iparam(0) = 1; // Exact shifts (not provided by us). - iparam(2) = 1000; // Maximum iterations; all the examples use 300, but they were written in the ancient times. + iparam(2) = maxiter; // Maximum iterations; all the examples use 300, but they were written in the ancient times. iparam(6) = 1; // Mode 1: A * x = lambda * x. // IPNTR: integer array of length 14 (output). - ipntr.set_size(14); + ipntr.zeros(14); // Real work array used in the basic Arnoldi iteration for reverse communication. - workd.set_size(3 * n); + workd.zeros(3 * n); // lworkl must be at least 3 * NCV^2 + 6 * NCV. lworkl = 3 * (ncv * ncv) + 6 * ncv; // Real work array of length lworkl. - workl.set_size(lworkl); + workl.zeros(lworkl); - info = 0; // Set to 0 initially to use random initial vector. + // info = 0; // resid to be filled with random values by ARPACK (non-deterministic) + info = 1; // resid is already filled with random values (deterministic) // All the parameters have been set or created. Time to loop a lot. - while (ido != 99) + while(ido != 99) { // Call saupd() or naupd() with the current parameters. if(sym) + { + arma_extra_debug_print("arpack::saupd()"); arpack::saupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + } else + { + arma_extra_debug_print("arpack::naupd()"); arpack::naupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + } // What do we do now? switch (ido) @@ -1370,24 +1987,300 @@ sp_auxlib::run_aupd // where x is of length n and starts at workd(ipntr(0)), and y is of // length n and starts at workd(ipntr(1)). - // operator*(sp_mat, vec) doesn't properly put the result into the - // right place so we'll just reimplement it here for now... + // // OLD METHOD + // + // // operator*(sp_mat, vec) doesn't properly put the result into the + // // right place so we'll just reimplement it here for now... + // + // // Set the output to point at the right memory. We have to subtract + // // one from FORTRAN pointers... + // Col out(workd.memptr() + ipntr(1) - 1, n, false /* don't copy */); + // // Set the input to point at the right memory. + // Col in(workd.memptr() + ipntr(0) - 1, n, false /* don't copy */); + // + // out.zeros(); + // + // T* out_mem = out.memptr(); + // const T* in_mem = in.memptr(); + // + // typename SpMat::const_iterator X_it = X.begin(); + // + // const uword X_nnz = X.n_nonzero; + // + // for(uword count=0; count < X_nnz; ++count, ++X_it) + // { + // const eT X_it_val = (*X_it); + // const uword X_it_row = X_it.row(); + // const uword X_it_col = X_it.col(); + // + // out_mem[X_it_row] += X_it_val * in_mem[X_it_col]; + // } + // + // // No need to modify memory further since it was all done in-place. - // Set the output to point at the right memory. We have to subtract - // one from FORTRAN pointers... - Col out(workd.memptr() + ipntr(1) - 1, n, false /* don't copy */); - // Set the input to point at the right memory. + + // NEW METHOD + // + // both operator*(rowvec, sp_mat) and operator*(sp_mat, colvec) can now write to an existing object + + Row out(workd.memptr() + ipntr(1) - 1, n, false, true); + Row in(workd.memptr() + ipntr(0) - 1, n, false, true); + + out = in * Xst; + + break; + } + case 99: + // Nothing to do here, things have converged. + break; + default: + { + return; // Parent frame can look at the value of info. + } + } + } + + // The process has ended; check the return code. + if( (info != 0) && (info != 1) ) + { + // Print warnings if there was a failure. + + if(sym) + { + arma_debug_warn_level(1, "eigs_sym(): ARPACK error ", info, " in saupd()"); + } + else + { + arma_debug_warn_level(1, "eigs_gen(): ARPACK error ", info, " in naupd()"); + } + + return; // Parent frame can look at the value of info. + } + } + #else + { + arma_ignore(n_eigvals); + arma_ignore(which); + arma_ignore(X); + arma_ignore(sym); + arma_ignore(n); + arma_ignore(tol); + arma_ignore(maxiter); + arma_ignore(resid); + arma_ignore(ncv); + arma_ignore(v); + arma_ignore(ldv); + arma_ignore(iparam); + arma_ignore(ipntr); + arma_ignore(workd); + arma_ignore(workl); + arma_ignore(lworkl); + arma_ignore(rwork); + arma_ignore(info); + } + #endif + } + + + +// Here 'sigma' is 'T', but should be 'eT'. +// Applying complex shifts to real matrices is currently not directly implemented +template +inline +void +sp_auxlib::run_aupd_shiftinvert + ( + const uword n_eigvals, const T sigma, + const SpMat& X, const bool sym, + blas_int& n, eT& tol, blas_int& maxiter, + podarray& resid, blas_int& ncv, podarray& v, blas_int& ldv, + podarray& iparam, podarray& ipntr, + podarray& workd, podarray& workl, blas_int& lworkl, podarray& rwork, + blas_int& info + ) + { + // TODO: inconsistent use of type names: T can be complex while eT can be real + + #if (defined(ARMA_USE_ARPACK) && defined(ARMA_USE_SUPERLU)) + { + char which_lm[3] = "LM"; + + char* which = which_lm; // NOTE: which_lm is the assumed operation when using shift-invert + + blas_int ido = 0; // This must be 0 for the first call. + char bmat = 'I'; // We are considering the standard eigenvalue problem. + n = X.n_rows; // The size of the matrix (should already be set outside). + blas_int nev = n_eigvals; + + // resid.zeros(n); + eigs_randu_filler randu_filler; + randu_filler.fill(resid, n); // use deterministic starting point + + // Two contraints on NCV: (NCV > NEV) for sym problems or + // (NCV > NEV + 2) for gen problems and (NCV <= N) + // + // We're calling either arpack::saupd() or arpack::naupd(), + // which have slighly different minimum constraint and recommended value for NCV: + // http://www.caam.rice.edu/software/ARPACK/UG/node136.html + // http://www.caam.rice.edu/software/ARPACK/UG/node138.html + + if(ncv < (nev + (sym ? 1 : 3))) { ncv = (nev + (sym ? 1 : 3)); } + if(ncv > n ) { ncv = n; } + + v.zeros(n * ncv); // Array N by NCV (output). + rwork.zeros(ncv); // Work array of size NCV for complex calls. + ldv = n; // "Leading dimension of V exactly as declared in the calling program." + + // IPARAM: integer array of length 11. + iparam.zeros(11); + iparam(0) = 1; // Exact shifts (not provided by us). + iparam(2) = maxiter; // Maximum iterations; all the examples use 300, but they were written in the ancient times. + // iparam(6) = 1; // Mode 1: A * x = lambda * x. + + // Change IPARAM for shift-invert + iparam(6) = 3; // Mode 3: A * x = lambda * M * x, M symmetric semi-definite. OP = inv[A - sigma*M]*M (A complex) or Real_Part{ inv[A - sigma*M]*M } (A real) and B = M. + + // IPNTR: integer array of length 14 (output). + ipntr.zeros(14); + + // Real work array used in the basic Arnoldi iteration for reverse communication. + workd.zeros(3 * n); + + // lworkl must be at least 3 * NCV^2 + 6 * NCV. + lworkl = 3 * (ncv * ncv) + 6 * ncv; + + // Real work array of length lworkl. + workl.zeros(lworkl); + + // info = 0; // resid to be filled with random values by ARPACK (non-deterministic) + info = 1; // resid is already filled with random values (deterministic) + + superlu_opts superlu_opts_default; + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, superlu_opts_default); + int lwork = 0; + superlu::trans_t trans = superlu::NOTRANS; + + superlu::GlobalLU_t Glu; /* Not needed on return. */ + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + superlu_supermatrix_wrangler x; + superlu_supermatrix_wrangler xC; + + const bool status_x = sp_auxlib::copy_to_supermatrix_with_shift(x.get_ref(), X, sigma); + + if(status_x == false) + { + arma_stop_runtime_error("run_aupd_shiftinvert(): could not construct SuperLU matrix"); + info = blas_int(-1); + return; + } + + // // for debugging only + // if(true) + // { + // cout << "*** testing output of copy_to_supermatrix_with_shift()" << endl; + // cout << "*** sigma: " << sigma << endl; + // + // SpMat Y(X); + // Y.diag() -= sigma; + // + // SpMat Z; + // + // sp_auxlib::copy_to_spmat(Z, x.get_ref()); + // + // cout << "*** size(Y): " << arma::size(Y) << endl; + // cout << "*** size(Z): " << arma::size(Z) << endl; + // cout << "*** accu(abs(Y)): " << accu(abs(Y)) << endl; + // cout << "*** accu(abs(Z)): " << accu(abs(Z)) << endl; + // + // if(arma::size(Y) == arma::size(Z)) + // { + // cout << "*** error: " << accu(abs(Y-Z)) << endl; + // } + // } + + superlu_supermatrix_wrangler l; + superlu_supermatrix_wrangler u; + + superlu_array_wrangler perm_c(X.n_cols+1); // paranoia: increase array length by 1 + superlu_array_wrangler perm_r(X.n_rows+1); + superlu_array_wrangler etree(X.n_cols+1); + + superlu_stat_wrangler stat; + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int slu_info = 0; // Return code. + + arma_extra_debug_print("superlu::gstrf()"); + superlu::get_permutation_c(options.ColPerm, x.get_ptr(), perm_c.get_ptr()); + superlu::sp_preorder_mat(&options, x.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), xC.get_ptr()); + superlu::gstrf(&options, xC.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l.get_ptr(), u.get_ptr(), &Glu, stat.get_ptr(), &slu_info); + + if(slu_info != 0) + { + arma_debug_warn_level(2, "matrix is singular to working precision"); + info = blas_int(-1); + return; + } + + // NOTE: potential problem with inconsistent/mismatched use of eT and T types + eT x_norm_val = sp_auxlib::norm1(x.get_ptr()); + eT x_rcond = sp_auxlib::lu_rcond(l.get_ptr(), u.get_ptr(), x_norm_val); + + if( (x_rcond < std::numeric_limits::epsilon()) || arma_isnan(x_rcond) ) + { + arma_debug_warn_level(2, "matrix is singular to working precision (rcond: ", x_rcond, ")"); + info = blas_int(-1); + return; + } + + // All the parameters have been set or created. Time to loop a lot. + while(ido != 99) + { + // Call saupd() or naupd() with the current parameters. + if(sym) + { + arma_extra_debug_print("arpack::saupd()"); + arpack::saupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, &info); + } + else + { + arma_extra_debug_print("arpack::naupd()"); + arpack::naupd(&ido, &bmat, &n, which, &nev, &tol, resid.memptr(), &ncv, v.memptr(), &ldv, iparam.memptr(), ipntr.memptr(), workd.memptr(), workl.memptr(), &lworkl, rwork.memptr(), &info); + } + + // What do we do now? + switch (ido) + { + case -1: + // fallthrough + case 1: + { + // We need to calculate the matrix-vector multiplication y = OP * x + // where x is of length n and starts at workd(ipntr(0)), and y is of + // length n and starts at workd(ipntr(1)). + + // Set the output to point at the right memory. We have to subtract + // one from FORTRAN pointers... + Col out(workd.memptr() + ipntr(1) - 1, n, false /* don't copy */); + // Set the input to point at the right memory. Col in(workd.memptr() + ipntr(0) - 1, n, false /* don't copy */); - out.zeros(); - typename SpMat::const_iterator x_it = X.begin(); - typename SpMat::const_iterator x_it_end = X.end(); + // Consider getting the LU factorization from ZGSTRF, and then + // solve the system L*U*out = in (possibly with permutation matrix?) + // Instead of "spsolve(out,X,in)" we call gstrf above and gstrs below + + out = in; + superlu_supermatrix_wrangler out_slu; + + const bool status_out_slu = sp_auxlib::wrap_to_supermatrix(out_slu.get_ref(), out); - while(x_it != x_it_end) - { - out[x_it.row()] += (*x_it) * in[x_it.col()]; - ++x_it; - } + if(status_out_slu == false) { arma_stop_runtime_error("run_aupd_shiftinvert(): could not construct SuperLU matrix"); return; } + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(trans, l.get_ptr(), u.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), out_slu.get_ptr(), stat.get_ptr(), &info); // No need to modify memory further since it was all done in-place. @@ -1410,23 +2303,25 @@ sp_auxlib::run_aupd if(sym) { - arma_debug_warn("eigs_sym(): ARPACK error ", info, " in saupd()"); + arma_debug_warn_level(2, "eigs_sym(): ARPACK error ", info, " in saupd()"); } else { - arma_debug_warn("eigs_gen(): ARPACK error ", info, " in naupd()"); + arma_debug_warn_level(2, "eigs_gen(): ARPACK error ", info, " in naupd()"); } return; // Parent frame can look at the value of info. } } #else + { arma_ignore(n_eigvals); - arma_ignore(which); + arma_ignore(sigma); arma_ignore(X); arma_ignore(sym); arma_ignore(n); arma_ignore(tol); + arma_ignore(maxiter); arma_ignore(resid); arma_ignore(ncv); arma_ignore(v); @@ -1438,6 +2333,7 @@ sp_auxlib::run_aupd arma_ignore(lworkl); arma_ignore(rwork); arma_ignore(info); + } #endif } @@ -1534,6 +2430,12 @@ sp_auxlib::rudimentary_sym_check(const SpMat< std::complex >& X) ++n_check; } + else + { + const eT A = (*it); + + if(std::abs(A.imag()) > tol) { return false; } + } ++it; } @@ -1543,4 +2445,370 @@ sp_auxlib::rudimentary_sym_check(const SpMat< std::complex >& X) +// + + + +template +inline +eigs_randu_filler::eigs_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +eigs_randu_filler::fill(podarray& X, const uword N) + { + arma_extra_debug_sigprint(); + + X.set_size(N); + + eT* X_mem = X.memptr(); + + for(uword i=0; i +inline +eigs_randu_filler< std::complex >::eigs_randu_filler() + { + arma_extra_debug_sigprint(); + + typedef typename std::mt19937_64::result_type local_seed_type; + + local_engine.seed(local_seed_type(123)); + + typedef typename std::uniform_real_distribution::param_type local_param_type; + + local_u_distr.param(local_param_type(-1.0, +1.0)); + } + + +template +inline +void +eigs_randu_filler< std::complex >::fill(podarray< std::complex >& X, const uword N) + { + arma_extra_debug_sigprint(); + + typedef typename std::complex eT; + + X.set_size(N); + + eT* X_mem = X.memptr(); + + for(uword i=0; i(&m); + bool all_zero = true; + + for(size_t i=0; i < sizeof(superlu::SuperMatrix); ++i) + { + if(m_char[i] != char(0)) { all_zero = false; break; } + } + + if(all_zero == false) { sp_auxlib::destroy_supermatrix(m); } + } + +inline +superlu_supermatrix_wrangler::superlu_supermatrix_wrangler() + { + arma_extra_debug_sigprint_this(this); + + arrayops::fill_zeros(reinterpret_cast(&m), sizeof(superlu::SuperMatrix)); + } + +inline +superlu::SuperMatrix& +superlu_supermatrix_wrangler::get_ref() + { + used = true; + + return m; + } + +inline +superlu::SuperMatrix* +superlu_supermatrix_wrangler::get_ptr() + { + used = true; + + return &m; + } + + +// + + +inline +superlu_stat_wrangler::~superlu_stat_wrangler() + { + arma_extra_debug_sigprint_this(this); + + superlu::free_stat(&stat); + } + +inline +superlu_stat_wrangler::superlu_stat_wrangler() + { + arma_extra_debug_sigprint_this(this); + + arrayops::fill_zeros(reinterpret_cast(&stat), sizeof(superlu::SuperLUStat_t)); + + superlu::init_stat(&stat); + } + +inline +superlu::SuperLUStat_t* +superlu_stat_wrangler::get_ptr() + { + return &stat; + } + + +// + + +template +inline +superlu_array_wrangler::~superlu_array_wrangler() + { + arma_extra_debug_sigprint_this(this); + + (*this).reset(); + } + +template +inline +superlu_array_wrangler::superlu_array_wrangler() + : mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + } + +template +inline +superlu_array_wrangler::superlu_array_wrangler(const uword n_elem) + : mem(nullptr) + { + arma_extra_debug_sigprint_this(this); + + (*this).set_size(n_elem); + } + +template +inline +void +superlu_array_wrangler::set_size(const uword n_elem) + { + arma_extra_debug_sigprint(); + + if(mem != nullptr) { (*this).reset(); } + + mem = (eT*)(superlu::malloc(n_elem * sizeof(eT))); + + arma_check_bad_alloc( (mem == nullptr), "superlu::malloc(): out of memory" ); + + arrayops::fill_zeros(mem, n_elem); + } + +template +inline +void +superlu_array_wrangler::reset() + { + arma_extra_debug_sigprint(); + + if(mem != nullptr) + { + superlu::free(mem); + mem = nullptr; + } + } + +template +inline +eT* +superlu_array_wrangler::get_ptr() + { + return mem; + } + + +// + + +template +inline +superlu_worker::~superlu_worker() + { + arma_extra_debug_sigprint_this(this); + + if(l != nullptr) { delete l; l = nullptr; } + if(u != nullptr) { delete u; u = nullptr; } + } + + +template +inline +superlu_worker::superlu_worker() + { + arma_extra_debug_sigprint_this(this); + } + + +template +inline +bool +superlu_worker::factorise(typename get_pod_type::result& out_rcond, const SpMat& A, const superlu_opts& user_opts) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + factorisation_valid = false; + + if(l != nullptr) { delete l; l = nullptr; } + if(u != nullptr) { delete u; u = nullptr; } + + l = new(std::nothrow) superlu_supermatrix_wrangler; + u = new(std::nothrow) superlu_supermatrix_wrangler; + + if( (l == nullptr) || (u == nullptr) ) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): could not construct SuperLU matrix"); + return false; + } + + superlu_supermatrix_wrangler& l_ref = (*l); + superlu_supermatrix_wrangler& u_ref = (*u); + + superlu::superlu_options_t options; + sp_auxlib::set_superlu_opts(options, user_opts); + + superlu_supermatrix_wrangler AA; + superlu_supermatrix_wrangler AAc; + + const bool status_AA = sp_auxlib::copy_to_supermatrix(AA.get_ref(), A); + + if(status_AA == false) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): could not construct SuperLU matrix"); + return false; + } + + (*this).perm_c.set_size(A.n_cols+1); // paranoia: increase array length by 1 + (*this).perm_r.set_size(A.n_rows+1); + + superlu_array_wrangler etree(A.n_cols+1); + + superlu::GlobalLU_t Glu; + arrayops::fill_zeros(reinterpret_cast(&Glu), sizeof(superlu::GlobalLU_t)); + + int panel_size = superlu::sp_ispec_environ(1); + int relax = superlu::sp_ispec_environ(2); + int lwork = 0; + int info = 0; + + arma_extra_debug_print("superlu::superlu::get_permutation_c()"); + superlu::get_permutation_c(options.ColPerm, AA.get_ptr(), perm_c.get_ptr()); + + arma_extra_debug_print("superlu::superlu::sp_preorder_mat()"); + superlu::sp_preorder_mat(&options, AA.get_ptr(), perm_c.get_ptr(), etree.get_ptr(), AAc.get_ptr()); + + arma_extra_debug_print("superlu::gstrf()"); + superlu::gstrf(&options, AAc.get_ptr(), relax, panel_size, etree.get_ptr(), NULL, lwork, perm_c.get_ptr(), perm_r.get_ptr(), l_ref.get_ptr(), u_ref.get_ptr(), &Glu, stat.get_ptr(), &info); + + if(info != 0) + { + arma_debug_warn_level(3, "superlu_worker()::factorise(): LU factorisation failed"); + return false; + } + + const T AA_norm = sp_auxlib::norm1(AA.get_ptr()); + const T AA_rcond = sp_auxlib::lu_rcond(l_ref.get_ptr(), u_ref.get_ptr(), AA_norm); + + out_rcond = AA_rcond; + + if(arma_isnan(AA_rcond)) { return false; } + // if(AA_rcond == T(0)) { return false; } + + factorisation_valid = true; + + return true; + } + + +template +inline +bool +superlu_worker::solve(Mat& X, const Mat& B) + { + arma_extra_debug_sigprint(); + + if(factorisation_valid == false) { return false; } + if( (l == nullptr) || (u == nullptr) ) { return false; } + + superlu_supermatrix_wrangler& l_ref = (*l); + superlu_supermatrix_wrangler& u_ref = (*u); + + X = B; + + superlu_supermatrix_wrangler XX; + + const bool status_XX = sp_auxlib::wrap_to_supermatrix(XX.get_ref(), X); + + if(status_XX == false) + { + arma_debug_warn_level(3, "superlu_worker()::solve(): could not construct SuperLU matrix"); + return false; + } + + superlu::trans_t trans = superlu::NOTRANS; + int info = 0; + + arma_extra_debug_print("superlu::gstrs()"); + superlu::gstrs(trans, l_ref.get_ptr(), u_ref.get_ptr(), perm_c.get_ptr(), perm_r.get_ptr(), XX.get_ptr(), stat.get_ptr(), &info); + + return (info == 0); + } + + +#endif + + //! @} diff --git a/src/armadillo_bits/span.hpp b/src/armadillo_bits/span.hpp index 778731e4..14774f14 100644 --- a/src/armadillo_bits/span.hpp +++ b/src/armadillo_bits/span.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -70,8 +72,8 @@ class span : public span_base<> } - // the "explicit" keyword is required here to prevent a C++11 compiler - // automatically converting {a,b} into an instance of span() when submatrices are specified + // the "explicit" keyword is required here to prevent automatic conversion of {a,b} + // into an instance of span() when submatrices are specified inline explicit span(const uword in_a, const uword in_b) diff --git a/src/armadillo_bits/spdiagview_bones.hpp b/src/armadillo_bits/spdiagview_bones.hpp index 521e7cef..238e8a38 100644 --- a/src/armadillo_bits/spdiagview_bones.hpp +++ b/src/armadillo_bits/spdiagview_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,7 @@ //! Class for storing data required to extract and set the diagonals of a sparse matrix template -class spdiagview : public SpBase > +class spdiagview : public SpBase< eT, spdiagview > { public: @@ -29,9 +31,9 @@ class spdiagview : public SpBase > arma_aligned const SpMat& m; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; const uword row_offset; const uword col_offset; @@ -39,7 +41,7 @@ class spdiagview : public SpBase > const uword n_rows; // equal to n_elem const uword n_elem; - static const uword n_cols = 1; + static constexpr uword n_cols = 1; protected: @@ -50,6 +52,7 @@ class spdiagview : public SpBase > public: inline ~spdiagview(); + inline spdiagview() = delete; inline void operator=(const spdiagview& x); @@ -86,6 +89,12 @@ class spdiagview : public SpBase > inline eT operator()(const uword in_n_row, const uword in_n_col) const; + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); @@ -97,10 +106,7 @@ class spdiagview : public SpBase > inline static void extract( Mat& out, const spdiagview& in); - private: - friend class SpMat; - spdiagview(); }; diff --git a/src/armadillo_bits/spdiagview_meat.hpp b/src/armadillo_bits/spdiagview_meat.hpp index f9f97062..603cadc9 100644 --- a/src/armadillo_bits/spdiagview_meat.hpp +++ b/src/armadillo_bits/spdiagview_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -186,6 +188,28 @@ spdiagview::operator= (const Base& o) const uword d_row_offset = d.row_offset; const uword d_col_offset = d.col_offset; + if(is_same_type< T1, Gen, gen_zeros> >::yes) + { + const Proxy P(o.get_ref()); + + arma_debug_check( (d_n_elem != P.get_n_elem()), "spdiagview: given object has incompatible size" ); + + (*this).zeros(); + + return; + } + + if(is_same_type< T1, Gen, gen_ones> >::yes) + { + const Proxy P(o.get_ref()); + + arma_debug_check( (d_n_elem != P.get_n_elem()), "spdiagview: given object has incompatible size" ); + + (*this).ones(); + + return; + } + const quasi_unwrap U(o.get_ref()); const Mat& x = U.M; @@ -216,6 +240,8 @@ spdiagview::operator= (const Base& o) if(has_zero) { tmp1.remove_zeros(); } + if(tmp1.n_nonzero == 0) { (*this).zeros(); return; } + SpMat tmp2; spglue_merge::diagview_merge(tmp2, d_m, tmp1); @@ -683,7 +709,7 @@ spdiagview::extract(SpMat& out, const spdiagview& d) const uword d_row_offset = d.row_offset; const uword d_col_offset = d.col_offset; - Col cache(d_n_elem); + Col cache(d_n_elem, arma_nozeros_indicator()); eT* cache_mem = cache.memptr(); uword d_n_nonzero = 0; @@ -790,7 +816,7 @@ inline SpMat_MapMat_val spdiagview::operator()(const uword i) { - arma_debug_check( (i >= n_elem), "spdiagview::operator(): out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "spdiagview::operator(): out of bounds" ); return (const_cast< SpMat& >(m)).at(i+row_offset, i+col_offset); } @@ -802,7 +828,7 @@ inline eT spdiagview::operator()(const uword i) const { - arma_debug_check( (i >= n_elem), "spdiagview::operator(): out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "spdiagview::operator(): out of bounds" ); return m.at(i+row_offset, i+col_offset); } @@ -834,7 +860,7 @@ inline SpMat_MapMat_val spdiagview::operator()(const uword row, const uword col) { - arma_debug_check( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); return (const_cast< SpMat& >(m)).at(row+row_offset, row+col_offset); } @@ -846,13 +872,68 @@ inline eT spdiagview::operator()(const uword row, const uword col) const { - arma_debug_check( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); + arma_debug_check_bounds( ((row >= n_elem) || (col > 0)), "spdiagview::operator(): out of bounds" ); return m.at(row+row_offset, row+col_offset); } +template +inline +void +spdiagview::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + if(old_val == eT(0)) + { + arma_debug_warn_level(1, "spdiagview::replace(): replacement not done, as old_val = 0"); + } + else + { + Mat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + } + + + +template +inline +void +spdiagview::clean(const typename get_pod_type::result threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +spdiagview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + SpMat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + template inline void @@ -860,7 +941,7 @@ spdiagview::fill(const eT val) { arma_extra_debug_sigprint(); - if( (row_offset == 0) && (col_offset == 0) ) + if( (row_offset == 0) && (col_offset == 0) && (m.sync_state != 1) ) { if(val == eT(0)) { diff --git a/src/armadillo_bits/spglue_elem_helper_meat.hpp b/src/armadillo_bits/spglue_elem_helper_meat.hpp deleted file mode 100644 index 7a88418d..00000000 --- a/src/armadillo_bits/spglue_elem_helper_meat.hpp +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -//! \addtogroup spglue_elem_helper -//! @{ - - - -template -arma_hot -inline -uword -spglue_elem_helper::max_n_nonzero_plus(const SpProxy& pa, const SpProxy& pb) - { - arma_extra_debug_sigprint(); - - // assuming that pa and pb have the same size, ie. - // pa.get_n_rows() == pb.get_n_rows() - // pa.get_n_cols() == pb.get_n_cols() - - typename SpProxy::const_iterator_type x_it = pa.begin(); - typename SpProxy::const_iterator_type x_end = pa.end(); - - typename SpProxy::const_iterator_type y_it = pb.begin(); - typename SpProxy::const_iterator_type y_end = pb.end(); - - uword count = 0; - - while( (x_it != x_end) || (y_it != y_end) ) - { - if(x_it == y_it) - { - ++x_it; - ++y_it; - } - else - { - const uword x_it_col = x_it.col(); - const uword x_it_row = x_it.row(); - - const uword y_it_col = y_it.col(); - const uword y_it_row = y_it.row(); - - if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end - { - ++x_it; - } - else - { - ++y_it; - } - } - - ++count; - } - - return count; - } - - - -template -arma_hot -inline -uword -spglue_elem_helper::max_n_nonzero_schur(const SpProxy& pa, const SpProxy& pb) - { - arma_extra_debug_sigprint(); - - // assuming that pa and pb have the same size, ie. - // pa.get_n_rows() == pb.get_n_rows() - // pa.get_n_cols() == pb.get_n_cols() - - typename SpProxy::const_iterator_type x_it = pa.begin(); - typename SpProxy::const_iterator_type x_end = pa.end(); - - typename SpProxy::const_iterator_type y_it = pb.begin(); - typename SpProxy::const_iterator_type y_end = pb.end(); - - uword count = 0; - - while( (x_it != x_end) || (y_it != y_end) ) - { - if(x_it == y_it) - { - ++x_it; - ++y_it; - - ++count; - } - else - { - const uword x_it_col = x_it.col(); - const uword x_it_row = x_it.row(); - - const uword y_it_col = y_it.col(); - const uword y_it_row = y_it.row(); - - if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end - { - ++x_it; - } - else - { - ++y_it; - } - } - } - - return count; - } - - - -//! @} diff --git a/src/armadillo_bits/spglue_join_bones.hpp b/src/armadillo_bits/spglue_join_bones.hpp index 9251ea64..93829b70 100644 --- a/src/armadillo_bits/spglue_join_bones.hpp +++ b/src/armadillo_bits/spglue_join_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class spglue_join_cols template struct traits { - static const bool is_row = false; - static const bool is_col = (T1::is_col && T2::is_col); - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; }; template @@ -53,9 +55,9 @@ class spglue_join_rows template struct traits { - static const bool is_row = (T1::is_row && T2::is_row); - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/spglue_join_meat.hpp b/src/armadillo_bits/spglue_join_meat.hpp index d6e58c26..4a3e2441 100644 --- a/src/armadillo_bits/spglue_join_meat.hpp +++ b/src/armadillo_bits/spglue_join_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -250,8 +252,8 @@ spglue_join_rows::apply_noalias(SpMat& out, const SpMat& A, const SpMat< // // OLD METHOD // - // umat locs(2, C_n_nz); - // Col vals( C_n_nz); + // umat locs(2, C_n_nz, arma_nozeros_indicator()); + // Col vals( C_n_nz, arma_nozeros_indicator()); // // uword* locs_mem = locs.memptr(); // eT* vals_mem = vals.memptr(); diff --git a/src/armadillo_bits/spglue_kron_bones.hpp b/src/armadillo_bits/spglue_kron_bones.hpp index 365f4fbd..e0d33b20 100644 --- a/src/armadillo_bits/spglue_kron_bones.hpp +++ b/src/armadillo_bits/spglue_kron_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class spglue_kron template struct traits { - static const bool is_row = (T1::is_row && T2::is_row); - static const bool is_col = (T1::is_col && T2::is_col); - static const bool is_xvec = false; + static constexpr bool is_row = (T1::is_row && T2::is_row); + static constexpr bool is_col = (T1::is_col && T2::is_col); + static constexpr bool is_xvec = false; }; template diff --git a/src/armadillo_bits/spglue_kron_meat.hpp b/src/armadillo_bits/spglue_kron_meat.hpp index 0784a5a0..b45f3e76 100644 --- a/src/armadillo_bits/spglue_kron_meat.hpp +++ b/src/armadillo_bits/spglue_kron_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -109,8 +111,8 @@ spglue_kron::apply_noalias(SpMat& out, const SpMat& A, const SpMat& // const SpMat& A = UA.M; // const SpMat& B = UB.M; // -// umat locs(2, A.n_nonzero * B.n_nonzero); -// Col vals( A.n_nonzero * B.n_nonzero); +// umat locs(2, A.n_nonzero * B.n_nonzero, arma_nozeros_indicator()); +// Col vals( A.n_nonzero * B.n_nonzero, arma_nozeros_indicator()); // // uword* locs_mem = locs.memptr(); // eT* vals_mem = vals.memptr(); diff --git a/src/armadillo_bits/spglue_max_bones.hpp b/src/armadillo_bits/spglue_max_bones.hpp index 85942fc7..156eeb52 100644 --- a/src/armadillo_bits/spglue_max_bones.hpp +++ b/src/armadillo_bits/spglue_max_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spglue_max_meat.hpp b/src/armadillo_bits/spglue_max_meat.hpp index 7782746e..4ee19a62 100644 --- a/src/armadillo_bits/spglue_max_meat.hpp +++ b/src/armadillo_bits/spglue_max_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,13 +58,9 @@ spglue_max::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& out, const SpProxy& pa, const SpProxy max_n_nonzero), "internal error: spglue_max::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -184,7 +184,7 @@ spglue_max::dense_sparse_max(Mat& out, const Base& X, const SpBase - arma_hot inline static void subview_merge(SpSubview& sv, const SpMat& B); + inline static void subview_merge(SpSubview& sv, const SpMat& B); template - arma_hot inline static void subview_merge(SpSubview& sv, const Mat& B); + inline static void subview_merge(SpSubview& sv, const Mat& B); template - arma_hot inline static void symmat_merge(SpMat& out, const SpMat& A, const SpMat& B); + inline static void symmat_merge(SpMat& out, const SpMat& A, const SpMat& B); template - arma_hot inline static void diagview_merge(SpMat& out, const SpMat& A, const SpMat& B); + inline static void diagview_merge(SpMat& out, const SpMat& A, const SpMat& B); }; diff --git a/src/armadillo_bits/spglue_merge_meat.hpp b/src/armadillo_bits/spglue_merge_meat.hpp index 45d539a0..18339da3 100644 --- a/src/armadillo_bits/spglue_merge_meat.hpp +++ b/src/armadillo_bits/spglue_merge_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void spglue_merge::subview_merge(SpSubview& sv, const SpMat& B) @@ -174,7 +175,7 @@ spglue_merge::subview_merge(SpSubview& sv, const SpMat& B) y_it_valid = (y_it != y_end); } - arma_check( (count != merge_n_nonzero), "spglue_merge::subview_merge(): internal error: count != merge_n_nonzero" ); + arma_check( (count != merge_n_nonzero), "internal error: spglue_merge::subview_merge(): count != merge_n_nonzero" ); const uword out_n_cols = out.n_cols; @@ -193,7 +194,6 @@ spglue_merge::subview_merge(SpSubview& sv, const SpMat& B) template -arma_hot inline void spglue_merge::subview_merge(SpSubview& sv, const Mat& B) @@ -362,7 +362,7 @@ spglue_merge::subview_merge(SpSubview& sv, const Mat& B) y_it_valid = (y_it != y_end); } - arma_check( (count != merge_n_nonzero), "spglue_merge::subview_merge(): internal error: count != merge_n_nonzero" ); + arma_check( (count != merge_n_nonzero), "internal error: spglue_merge::subview_merge(): count != merge_n_nonzero" ); const uword out_n_cols = out.n_cols; @@ -381,7 +381,6 @@ spglue_merge::subview_merge(SpSubview& sv, const Mat& B) template -arma_hot inline void spglue_merge::symmat_merge(SpMat& out, const SpMat& A, const SpMat& B) @@ -466,13 +465,14 @@ spglue_merge::symmat_merge(SpMat& out, const SpMat& A, const SpMat& template -arma_hot inline void spglue_merge::diagview_merge(SpMat& out, const SpMat& A, const SpMat& B) { arma_extra_debug_sigprint(); + // NOTE: assuming that B has non-zero elements only on the main diagonal + out.reserve(A.n_rows, A.n_cols, A.n_nonzero + B.n_nonzero); // worst case scenario typename SpMat::const_iterator x_it = A.begin(); @@ -485,7 +485,7 @@ spglue_merge::diagview_merge(SpMat& out, const SpMat& A, const SpMat while( (x_it != x_end) || (y_it != y_end) ) { - eT out_val; + eT out_val = eT(0); const uword x_it_col = x_it.col(); const uword x_it_row = x_it.row(); @@ -508,28 +508,29 @@ spglue_merge::diagview_merge(SpMat& out, const SpMat& A, const SpMat { if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end { - out_val = (*x_it); + if(x_it_col != x_it_row) { out_val = (*x_it); } // don't take values from the main diagonal of A ++x_it; } else { - out_val = (*y_it); + if(y_it_col == y_it_row) { out_val = (*y_it); use_y_loc = true; } // take values only from the main diagonal of B ++y_it; - - use_y_loc = true; } } - access::rw(out.values[count]) = out_val; - - const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; - const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; - - access::rw(out.row_indices[count]) = out_row; - access::rw(out.col_ptrs[out_col + 1])++; - ++count; + if(out_val != eT(0)) + { + access::rw(out.values[count]) = out_val; + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + } } const uword out_n_cols = out.n_cols; diff --git a/src/armadillo_bits/spglue_min_bones.hpp b/src/armadillo_bits/spglue_min_bones.hpp index c1db6e57..93e8c593 100644 --- a/src/armadillo_bits/spglue_min_bones.hpp +++ b/src/armadillo_bits/spglue_min_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spglue_min_meat.hpp b/src/armadillo_bits/spglue_min_meat.hpp index 6db86433..cdfc1971 100644 --- a/src/armadillo_bits/spglue_min_meat.hpp +++ b/src/armadillo_bits/spglue_min_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -56,13 +58,9 @@ spglue_min::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& out, const SpProxy& pa, const SpProxy max_n_nonzero), "internal error: spglue_min::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -184,7 +184,7 @@ spglue_min::dense_sparse_min(Mat& out, const Base& X, const SpBase - arma_hot inline static void apply(SpMat& out, const SpGlue& X); + inline static void apply(SpMat& out, const SpGlue& X); template - arma_hot inline static void apply_noalias(SpMat& result, const SpProxy& pa, const SpProxy& pb); + inline static void apply_noalias(SpMat& result, const SpProxy& pa, const SpProxy& pb); template - arma_hot inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); }; diff --git a/src/armadillo_bits/spglue_minus_meat.hpp b/src/armadillo_bits/spglue_minus_meat.hpp index 9d31a487..1ad71610 100644 --- a/src/armadillo_bits/spglue_minus_meat.hpp +++ b/src/armadillo_bits/spglue_minus_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void spglue_minus::apply(SpMat& out, const SpGlue& X) @@ -51,7 +52,6 @@ spglue_minus::apply(SpMat& out, const SpGlue -arma_hot inline void spglue_minus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) @@ -63,7 +63,7 @@ spglue_minus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy if(pa.get_n_nonzero() == 0) { out = pb.Q; out *= eT(-1); return; } if(pb.get_n_nonzero() == 0) { out = pa.Q; return; } - const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(pa, pb); + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); // Resize memory to upper bound out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); @@ -125,6 +125,8 @@ spglue_minus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy access::rw(out.col_ptrs[out_col + 1])++; ++count; } + + arma_check( (count > max_n_nonzero), "internal error: spglue_minus::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -156,7 +158,6 @@ spglue_minus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy template -arma_hot inline void spglue_minus::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) diff --git a/src/armadillo_bits/spglue_plus_bones.hpp b/src/armadillo_bits/spglue_plus_bones.hpp index f3ed2a23..b92cb71c 100644 --- a/src/armadillo_bits/spglue_plus_bones.hpp +++ b/src/armadillo_bits/spglue_plus_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,13 +27,13 @@ class spglue_plus public: template - arma_hot inline static void apply(SpMat& out, const SpGlue& X); + inline static void apply(SpMat& out, const SpGlue& X); template - arma_hot inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); template - arma_hot inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); }; diff --git a/src/armadillo_bits/spglue_plus_meat.hpp b/src/armadillo_bits/spglue_plus_meat.hpp index 83434ee1..b8eada0b 100644 --- a/src/armadillo_bits/spglue_plus_meat.hpp +++ b/src/armadillo_bits/spglue_plus_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void spglue_plus::apply(SpMat& out, const SpGlue& X) @@ -51,7 +52,6 @@ spglue_plus::apply(SpMat& out, const SpGlue -arma_hot inline void spglue_plus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) @@ -63,7 +63,7 @@ spglue_plus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy< if(pa.get_n_nonzero() == 0) { out = pb.Q; return; } if(pb.get_n_nonzero() == 0) { out = pa.Q; return; } - const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(pa, pb); + const uword max_n_nonzero = pa.get_n_nonzero() + pb.get_n_nonzero(); // Resize memory to upper bound out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); @@ -125,6 +125,8 @@ spglue_plus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy< access::rw(out.col_ptrs[out_col + 1])++; ++count; } + + arma_check( (count > max_n_nonzero), "internal error: spglue_plus::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -156,7 +158,6 @@ spglue_plus::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy< template -arma_hot inline void spglue_plus::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) diff --git a/src/armadillo_bits/spglue_relational_bones.hpp b/src/armadillo_bits/spglue_relational_bones.hpp index 4018e548..f84caf46 100644 --- a/src/armadillo_bits/spglue_relational_bones.hpp +++ b/src/armadillo_bits/spglue_relational_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -47,4 +49,32 @@ class spglue_rel_gt +class spglue_rel_and + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + +class spglue_rel_or + : public traits_glue_or + { + public: + + template + inline static void apply(SpMat& out, const mtSpGlue& X); + + template + inline static void apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB); + }; + + + //! @} diff --git a/src/armadillo_bits/spglue_relational_meat.hpp b/src/armadillo_bits/spglue_relational_meat.hpp index 495a7380..92564abc 100644 --- a/src/armadillo_bits/spglue_relational_meat.hpp +++ b/src/armadillo_bits/spglue_relational_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -58,7 +60,7 @@ spglue_rel_lt::apply_noalias(SpMat& out, const SpProxy& PA, const SpP arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator<"); - const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(PA, PB); + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); // Resize memory to upper bound out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); @@ -120,6 +122,8 @@ spglue_rel_lt::apply_noalias(SpMat& out, const SpProxy& PA, const SpP access::rw(out.col_ptrs[out_col + 1])++; ++count; } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_lt::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -150,6 +154,10 @@ spglue_rel_lt::apply_noalias(SpMat& out, const SpProxy& PA, const SpP +// + + + template inline void @@ -189,7 +197,7 @@ spglue_rel_gt::apply_noalias(SpMat& out, const SpProxy& PA, const SpP arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator>"); - const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_plus(PA, PB); + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); // Resize memory to upper bound out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); @@ -251,6 +259,259 @@ spglue_rel_gt::apply_noalias(SpMat& out, const SpProxy& PA, const SpP access::rw(out.col_ptrs[out_col + 1])++; ++count; } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_gt::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_rel_and::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_and::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_and::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_and::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator&&"); + + if( (PA.get_n_nonzero() == 0) || (PB.get_n_nonzero() == 0) ) + { + out.zeros(PA.get_n_rows(), PA.get_n_cols()); + return; + } + + const uword max_n_nonzero = (std::min)(PA.get_n_nonzero(), PB.get_n_nonzero()); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + const uword x_it_row = x_it.row(); + const uword x_it_col = x_it.col(); + + const uword y_it_row = y_it.row(); + const uword y_it_col = y_it.col(); + + if(x_it == y_it) + { + access::rw(out.values[count]) = uword(1); + + access::rw(out.row_indices[count]) = x_it_row; + access::rw(out.col_ptrs[x_it_col + 1])++; + ++count; + + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else + { + ++y_it; + } + } + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_and::apply_noalias(): count > max_n_nonzero" ); + } + + const uword out_n_cols = out.n_cols; + + uword* col_ptrs = access::rwp(out.col_ptrs); + + // Fix column pointers to be cumulative. + for(uword c = 1; c <= out_n_cols; ++c) + { + col_ptrs[c] += col_ptrs[c - 1]; + } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } + } + + + +// + + + +template +inline +void +spglue_rel_or::apply(SpMat& out, const mtSpGlue& X) + { + arma_extra_debug_sigprint(); + + const SpProxy PA(X.A); + const SpProxy PB(X.B); + + const bool is_alias = PA.is_alias(out) || PB.is_alias(out); + + if(is_alias == false) + { + spglue_rel_or::apply_noalias(out, PA, PB); + } + else + { + SpMat tmp; + + spglue_rel_or::apply_noalias(tmp, PA, PB); + + out.steal_mem(tmp); + } + } + + + +template +inline +void +spglue_rel_or::apply_noalias(SpMat& out, const SpProxy& PA, const SpProxy& PB) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + arma_debug_assert_same_size(PA.get_n_rows(), PA.get_n_cols(), PB.get_n_rows(), PB.get_n_cols(), "operator||"); + + const uword max_n_nonzero = PA.get_n_nonzero() + PB.get_n_nonzero(); + + // Resize memory to upper bound + out.reserve(PA.get_n_rows(), PA.get_n_cols(), max_n_nonzero); + + // Now iterate across both matrices. + typename SpProxy::const_iterator_type x_it = PA.begin(); + typename SpProxy::const_iterator_type x_end = PA.end(); + + typename SpProxy::const_iterator_type y_it = PB.begin(); + typename SpProxy::const_iterator_type y_end = PB.end(); + + uword count = 0; + + while( (x_it != x_end) || (y_it != y_end) ) + { + const uword x_it_col = x_it.col(); + const uword x_it_row = x_it.row(); + + const uword y_it_col = y_it.col(); + const uword y_it_row = y_it.row(); + + bool use_y_loc = false; + + if(x_it == y_it) + { + ++x_it; + ++y_it; + } + else + { + if((x_it_col < y_it_col) || ((x_it_col == y_it_col) && (x_it_row < y_it_row))) // if y is closer to the end + { + ++x_it; + } + else + { + ++y_it; + + use_y_loc = true; + } + } + + access::rw(out.values[count]) = uword(1); + + const uword out_row = (use_y_loc == false) ? x_it_row : y_it_row; + const uword out_col = (use_y_loc == false) ? x_it_col : y_it_col; + + access::rw(out.row_indices[count]) = out_row; + access::rw(out.col_ptrs[out_col + 1])++; + ++count; + + arma_check( (count > max_n_nonzero), "internal error: spglue_rel_or::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; diff --git a/src/armadillo_bits/spglue_schur_bones.hpp b/src/armadillo_bits/spglue_schur_bones.hpp index 92a1319d..605de3a8 100644 --- a/src/armadillo_bits/spglue_schur_bones.hpp +++ b/src/armadillo_bits/spglue_schur_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -25,13 +27,13 @@ class spglue_schur public: template - arma_hot inline static void apply(SpMat& out, const SpGlue& X); + inline static void apply(SpMat& out, const SpGlue& X); template - arma_hot inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); + inline static void apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb); template - arma_hot inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); + inline static void apply_noalias(SpMat& out, const SpMat& A, const SpMat& B); }; diff --git a/src/armadillo_bits/spglue_schur_meat.hpp b/src/armadillo_bits/spglue_schur_meat.hpp index 58bc4d8f..1ad8a8f9 100644 --- a/src/armadillo_bits/spglue_schur_meat.hpp +++ b/src/armadillo_bits/spglue_schur_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void spglue_schur::apply(SpMat& out, const SpGlue& X) @@ -51,7 +52,6 @@ spglue_schur::apply(SpMat& out, const SpGlue -arma_hot inline void spglue_schur::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy& pb) @@ -66,7 +66,7 @@ spglue_schur::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy return; } - const uword max_n_nonzero = spglue_elem_helper::max_n_nonzero_schur(pa, pb); + const uword max_n_nonzero = (std::min)(pa.get_n_nonzero(), pb.get_n_nonzero()); // Resize memory to upper bound out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); @@ -115,6 +115,8 @@ spglue_schur::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy ++y_it; } } + + arma_check( (count > max_n_nonzero), "internal error: spglue_schur::apply_noalias(): count > max_n_nonzero" ); } const uword out_n_cols = out.n_cols; @@ -146,7 +148,6 @@ spglue_schur::apply_noalias(SpMat& out, const SpProxy& pa, const SpProxy template -arma_hot inline void spglue_schur::apply_noalias(SpMat& out, const SpMat& A, const SpMat& B) @@ -181,42 +182,34 @@ spglue_schur_misc::dense_schur_sparse(SpMat& out, const arma_debug_assert_same_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "element-wise multiplication"); - // count new size - uword new_n_nonzero = 0; + const uword max_n_nonzero = pb.get_n_nonzero(); + + // Resize memory to upper bound. + out.reserve(pa.get_n_rows(), pa.get_n_cols(), max_n_nonzero); + + uword count = 0; typename SpProxy::const_iterator_type it = pb.begin(); typename SpProxy::const_iterator_type it_end = pb.end(); while(it != it_end) { - if( ((*it) * pa.at(it.row(), it.col())) != eT(0) ) { ++new_n_nonzero; } + const uword it_row = it.row(); + const uword it_col = it.col(); - ++it; - } - - // Resize memory accordingly. - out.reserve(pa.get_n_rows(), pa.get_n_cols(), new_n_nonzero); - - uword count = 0; - - typename SpProxy::const_iterator_type it2 = pb.begin(); - - while(it2 != it_end) - { - const uword it2_row = it2.row(); - const uword it2_col = it2.col(); - - const eT val = (*it2) * pa.at(it2_row, it2_col); + const eT val = (*it) * pa.at(it_row, it_col); if(val != eT(0)) { access::rw( out.values[count]) = val; - access::rw( out.row_indices[count]) = it2_row; - access::rw(out.col_ptrs[it2_col + 1])++; + access::rw( out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; ++count; } - ++it2; + ++it; + + arma_check( (count > max_n_nonzero), "internal error: spglue_schur_misc::dense_schur_sparse(): count > max_n_nonzero" ); } // Fix column pointers. @@ -224,6 +217,21 @@ spglue_schur_misc::dense_schur_sparse(SpMat& out, const { access::rw(out.col_ptrs[c]) += out.col_ptrs[c - 1]; } + + if(count < max_n_nonzero) + { + if(count <= (max_n_nonzero/2)) + { + out.mem_resize(count); + } + else + { + // quick resize without reallocating memory and copying data + access::rw( out.n_nonzero) = count; + access::rw( out.values[count]) = eT(0); + access::rw(out.row_indices[count]) = uword(0); + } + } } diff --git a/src/armadillo_bits/spglue_times_bones.hpp b/src/armadillo_bits/spglue_times_bones.hpp index 8364a2cc..63c21b3d 100644 --- a/src/armadillo_bits/spglue_times_bones.hpp +++ b/src/armadillo_bits/spglue_times_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,9 +28,9 @@ class spglue_times template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template @@ -38,28 +40,7 @@ class spglue_times inline static void apply(SpMat& out, const SpGlue,T2,spglue_times>& X); template - arma_hot inline static void apply_noalias(SpMat& c, const SpMat& x, const SpMat& y); - }; - - - -class spglue_times_misc - { - public: - - template - struct traits - { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; - }; - - template - inline static void sparse_times_dense(Mat& out, const T1& x, const T2& y); - - template - inline static void dense_times_sparse(Mat& out, const T1& x, const T2& y); + inline static void apply_noalias(SpMat& c, const SpMat& x, const SpMat& y); }; @@ -71,19 +52,13 @@ class spglue_times_mixed template struct traits { - static const bool is_row = T1::is_row; - static const bool is_col = T2::is_col; - static const bool is_xvec = false; + static constexpr bool is_row = T1::is_row; + static constexpr bool is_col = T2::is_col; + static constexpr bool is_xvec = false; }; template inline static void apply(SpMat::eT>& out, const mtSpGlue::eT, T1, T2, spglue_times_mixed>& expr); - - template - inline static void sparse_times_dense(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); - - template - inline static void dense_times_sparse(Mat< typename promote_type::result >& out, const T1& X, const T2& Y); }; diff --git a/src/armadillo_bits/spglue_times_meat.hpp b/src/armadillo_bits/spglue_times_meat.hpp index 00ac98c3..852dcad6 100644 --- a/src/armadillo_bits/spglue_times_meat.hpp +++ b/src/armadillo_bits/spglue_times_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -82,7 +84,6 @@ spglue_times::apply(SpMat& out, const SpGlue -arma_hot inline void spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y) @@ -93,9 +94,9 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y const uword x_n_cols = x.n_cols; const uword y_n_rows = y.n_rows; const uword y_n_cols = y.n_cols; - + arma_debug_assert_mul_size(x_n_rows, x_n_cols, y_n_rows, y_n_cols, "matrix multiplication"); - + // First we must determine the structure of the new matrix (column pointers). // This follows the algorithm described in 'Sparse Matrix Multiplication // Package (SMMP)' (R.E. Bank and C.C. Douglas, 2001). Their description of @@ -108,11 +109,8 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y //SpMat c(x_n_rows, y_n_cols); // Initializes col_ptrs to 0. c.zeros(x_n_rows, y_n_cols); - //if( (x.n_elem == 0) || (y.n_elem == 0) ) - if( (x.n_nonzero == 0) || (y.n_nonzero == 0) ) - { - return; - } + //if( (x.n_elem == 0) || (y.n_elem == 0) ) { return; } + if( (x.n_nonzero == 0) || (y.n_nonzero == 0) ) { return; } // Auxiliary storage which denotes when items have been found. podarray index(x_n_rows); @@ -120,7 +118,7 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y typename SpMat::const_iterator y_it = y.begin(); typename SpMat::const_iterator y_end = y.end(); - + // SYMBMM: calculate column pointers for resultant matrix to obtain a good // upper bound on the number of nonzero elements. uword cur_col_length = 0; @@ -143,20 +141,20 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y last_ind = x_it_row; ++cur_col_length; } - + ++x_it; } - + const uword old_col = y_it.col(); ++y_it; - + // See if column incremented. if(old_col != y_it.col()) { // Set column pointer (this is not a cumulative count; that is done later). access::rw(c.col_ptrs[old_col + 1]) = cur_col_length; cur_col_length = 0; - + // Return index markers to zero. Use last_ind for traversal. while(last_ind != x_n_rows + 1) { @@ -167,17 +165,20 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y } } while(y_it != y_end); - + // Accumulate column pointers. for(uword i = 0; i < c.n_cols; ++i) { access::rw(c.col_ptrs[i + 1]) += c.col_ptrs[i]; } - - // Now that we know a decent bound on the number of nonzero elements, allocate - // the memory and fill it. - c.mem_resize(c.col_ptrs[c.n_cols]); - + + // Now that we know a decent bound on the number of nonzero elements, + // allocate the memory and fill it. + + const uword max_n_nonzero = c.col_ptrs[c.n_cols]; + + c.mem_resize(max_n_nonzero); + // Now the implementation of the NUMBMM algorithm. uword cur_pos = 0; // Current position in c matrix. podarray sums(x_n_rows); // Partial sums. @@ -198,15 +199,12 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y access::rw(c.col_ptrs[cur_col]) = cur_pos; ++cur_col; } - - if(cur_col == c.n_cols) - { - break; - } - + + if(cur_col == c.n_cols) { break; } + // Update current column pointer. access::rw(c.col_ptrs[cur_col]) = cur_pos; - + // Check all elements in this column. typename SpMat::const_iterator y_col_it = y.begin_col_no_sync(cur_col); @@ -217,9 +215,9 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y // Check all elements in the column of the other matrix corresponding to // the row of this column. typename SpMat::const_iterator x_col_it = x.begin_col_no_sync(y_col_it_row); - + const eT y_value = (*y_col_it); - + while(x_col_it.col() == y_col_it_row) { const uword x_col_it_row = x_col_it.row(); @@ -228,26 +226,26 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y // Add to partial sum. const eT x_value = (*x_col_it); sums[x_col_it_row] += (x_value * y_value); - + // Add point if it hasn't already been marked. if(index[x_col_it_row] == x_n_rows) { index[x_col_it_row] = last_ind; last_ind = x_col_it_row; } - + ++x_col_it; } - + ++y_col_it; } - + // Now sort the indices that were used in this column. uword cur_index = 0; while(last_ind != x_n_rows + 1) { const uword tmp = last_ind; - + // Check that it wasn't a "fake" nonzero element. if(sums[tmp] != eT(0)) { @@ -259,9 +257,9 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y last_ind = index[tmp]; index[tmp] = x_n_rows; } - + // Now sort the indices. - if (cur_index != 0) + if(cur_index != 0) { op_sort::direct_sort_ascending(sorted_indices.memptr(), cur_index); @@ -278,189 +276,21 @@ spglue_times::apply_noalias(SpMat& c, const SpMat& x, const SpMat& y // Move to next column. ++cur_col; } - - // Update last column pointer and resize to actual memory size. - access::rw(c.col_ptrs[c.n_cols]) = cur_pos; - c.mem_resize(cur_pos); - } - - - -// -// -// - - - -template -inline -void -spglue_times_misc::sparse_times_dense(Mat& out, const T1& x, const T2& y) - { - arma_extra_debug_sigprint(); - typedef typename T1::elem_type eT; + // Update last column pointer and resize to actual memory size. - if(is_op_diagmat::value) - { - const SpMat tmp(y); - - out = x * tmp; - } - else - { - const unwrap_spmat UA(x); - const quasi_unwrap UB(y); - - const SpMat& A = UA.M; - const Mat& B = UB.M; - - const uword A_n_rows = A.n_rows; - const uword A_n_cols = A.n_cols; - - const uword B_n_rows = B.n_rows; - const uword B_n_cols = B.n_cols; - - arma_debug_assert_mul_size(A_n_rows, A_n_cols, B_n_rows, B_n_cols, "matrix multiplication"); - - if(B_n_cols >= (B_n_rows / uword(100))) - { - arma_extra_debug_print("using transpose-based multiplication"); - - const SpMat At = A.st(); - const Mat Bt = B.st(); - - if(A_n_rows == B_n_cols) - { - spglue_times_misc::dense_times_sparse(out, Bt, At); - - op_strans::apply_mat(out, out); // since 'out' is square-sized, this will do an inplace transpose - } - else - { - Mat tmp; - - spglue_times_misc::dense_times_sparse(tmp, Bt, At); - - op_strans::apply_mat(out, tmp); - } - } - else - { - arma_extra_debug_print("using standard multiplication"); - - out.zeros(A_n_rows, B_n_cols); - - typename SpMat::const_iterator A_it = A.begin(); - typename SpMat::const_iterator A_it_end = A.end(); - - while(A_it != A_it_end) - { - const eT A_it_val = (*A_it); - const uword A_it_row = A_it.row(); - const uword A_it_col = A_it.col(); - - for(uword col = 0; col < B_n_cols; ++col) - { - out.at(A_it_row, col) += A_it_val * B.at(A_it_col, col); - } - - ++A_it; - } - } - } - } - - - -template -inline -void -spglue_times_misc::dense_times_sparse(Mat& out, const T1& x, const T2& y) - { - arma_extra_debug_sigprint(); + // access::rw(c.col_ptrs[c.n_cols]) = cur_pos; + // c.mem_resize(cur_pos); - typedef typename T1::elem_type eT; + access::rw(c.col_ptrs[c.n_cols]) = cur_pos; - if(is_op_diagmat::value) - { - const SpMat tmp(x); - - out = tmp * y; - } - else - { - const Proxy pa(x); - const SpProxy pb(y); - - arma_debug_assert_mul_size(pa.get_n_rows(), pa.get_n_cols(), pb.get_n_rows(), pb.get_n_cols(), "matrix multiplication"); - - out.zeros(pa.get_n_rows(), pb.get_n_cols()); - - if( (pa.get_n_elem() > 0) && (pb.get_n_nonzero() > 0) ) - { - if( (arma_config::openmp) && (mp_thread_limit::in_parallel() == false) && (pa.get_n_rows() <= (pa.get_n_cols() / uword(100))) ) - { - #if defined(ARMA_USE_OPENMP) - { - arma_extra_debug_print("using parallelised multiplication"); - - const quasi_unwrap::stored_type> UX(pa.Q); - const unwrap_spmat::stored_type> UY(pb.Q); - - const Mat& X = UX.M; - const SpMat& Y = UY.M; - - const uword Y_n_cols = Y.n_cols; - const int n_threads = mp_thread_limit::get(); - - #pragma omp parallel for schedule(static) num_threads(n_threads) - for(uword i=0; i < Y_n_cols; ++i) - { - const uword col_offset_1 = Y.col_ptrs[i ]; - const uword col_offset_2 = Y.col_ptrs[i+1]; - - const uword col_offset_delta = col_offset_2 - col_offset_1; - - const uvec indices(const_cast(&(Y.row_indices[col_offset_1])), col_offset_delta, false, false); - const Col Y_col(const_cast< eT*>(&( Y.values[col_offset_1])), col_offset_delta, false, false); - - out.col(i) = X.cols(indices) * Y_col; - } - } - #endif - } - else - { - arma_extra_debug_print("using standard multiplication"); - - typename SpProxy::const_iterator_type y_it = pb.begin(); - typename SpProxy::const_iterator_type y_it_end = pb.end(); - - const uword out_n_rows = out.n_rows; - - while(y_it != y_it_end) - { - const eT y_it_val = (*y_it); - const uword y_it_col = y_it.col(); - const uword y_it_row = y_it.row(); - - eT* out_col = out.colptr(y_it_col); - - for(uword row = 0; row < out_n_rows; ++row) - { - out_col[row] += pa.at(row, y_it_row) * y_it_val; - } - - ++y_it; - } - } - } - } + if(cur_pos < max_n_nonzero) { c.mem_resize(cur_pos); } } +// +// // @@ -536,146 +366,4 @@ spglue_times_mixed::apply(SpMat::eT>& out, const mtS -template -inline -void -spglue_times_mixed::sparse_times_dense(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT1; - typedef typename T2::elem_type eT2; - - typedef typename promote_type::result out_eT; - - promote_type::check(); - - if( (is_same_type::no) && (is_same_type::yes) ) - { - // upgrade T1 - - const unwrap_spmat UA(X); - const quasi_unwrap UB(Y); - - const SpMat& A = UA.M; - const Mat& B = UB.M; - - SpMat AA(arma_layout_indicator(), A); - - for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } - - const Mat& BB = reinterpret_cast< const Mat& >(B); - - spglue_times_misc::sparse_times_dense(out, AA, BB); - } - else - if( (is_same_type::yes) && (is_same_type::no) ) - { - // upgrade T2 - - const unwrap_spmat UA(X); - const quasi_unwrap UB(Y); - - const SpMat& A = UA.M; - const Mat& B = UB.M; - - const SpMat& AA = reinterpret_cast< const SpMat& >(A); - - const Mat BB = conv_to< Mat >::from(B); - - spglue_times_misc::sparse_times_dense(out, AA, BB); - } - else - { - // upgrade T1 and T2 - - const unwrap_spmat UA(X); - const quasi_unwrap UB(Y); - - const SpMat& A = UA.M; - const Mat& B = UB.M; - - SpMat AA(arma_layout_indicator(), A); - - for(uword i=0; i < A.n_nonzero; ++i) { access::rw(AA.values[i]) = out_eT(A.values[i]); } - - const Mat BB = conv_to< Mat >::from(B); - - spglue_times_misc::sparse_times_dense(out, AA, BB); - } - } - - - -template -inline -void -spglue_times_mixed::dense_times_sparse(Mat< typename promote_type::result >& out, const T1& X, const T2& Y) - { - arma_extra_debug_sigprint(); - - typedef typename T1::elem_type eT1; - typedef typename T2::elem_type eT2; - - typedef typename promote_type::result out_eT; - - promote_type::check(); - - if( (is_same_type::no) && (is_same_type::yes) ) - { - // upgrade T1 - - const quasi_unwrap UA(X); - const unwrap_spmat UB(Y); - - const Mat& A = UA.M; - const SpMat& B = UB.M; - - const Mat AA = conv_to< Mat >::from(A); - - const SpMat& BB = reinterpret_cast< const SpMat& >(B); - - spglue_times_misc::dense_times_sparse(out, AA, BB); - } - else - if( (is_same_type::yes) && (is_same_type::no) ) - { - // upgrade T2 - - const quasi_unwrap UA(X); - const unwrap_spmat UB(Y); - - const Mat& A = UA.M; - const SpMat& B = UB.M; - - const Mat& AA = reinterpret_cast< const Mat& >(A); - - SpMat BB(arma_layout_indicator(), B); - - for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } - - spglue_times_misc::dense_times_sparse(out, AA, BB); - } - else - { - // upgrade T1 and T2 - - const quasi_unwrap UA(X); - const unwrap_spmat UB(Y); - - const Mat& A = UA.M; - const SpMat& B = UB.M; - - const Mat AA = conv_to< Mat >::from(A); - - SpMat BB(arma_layout_indicator(), B); - - for(uword i=0; i < B.n_nonzero; ++i) { access::rw(BB.values[i]) = out_eT(B.values[i]); } - - spglue_times_misc::dense_times_sparse(out, AA, BB); - } - } - - - //! @} diff --git a/src/armadillo_bits/spop_diagmat_bones.hpp b/src/armadillo_bits/spop_diagmat_bones.hpp index 96415cb1..41b1ae30 100644 --- a/src/armadillo_bits/spop_diagmat_bones.hpp +++ b/src/armadillo_bits/spop_diagmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_diagmat_meat.hpp b/src/armadillo_bits/spop_diagmat_meat.hpp index d8f1f2ac..a6f9faf9 100644 --- a/src/armadillo_bits/spop_diagmat_meat.hpp +++ b/src/armadillo_bits/spop_diagmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -425,7 +427,7 @@ spop_diagmat2::apply_noalias(SpMat& out, const SpMat& X, const uword row } else // generate a diagonal matrix out of a matrix { - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "diagmat(): requested diagonal out of bounds" diff --git a/src/armadillo_bits/spop_htrans_bones.hpp b/src/armadillo_bits/spop_htrans_bones.hpp index b188735c..609c64b5 100644 --- a/src/armadillo_bits/spop_htrans_bones.hpp +++ b/src/armadillo_bits/spop_htrans_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,16 +29,16 @@ class spop_htrans template struct traits { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; }; template - inline static void apply(SpMat& out, const SpOp& in, const typename arma_not_cx::result* junk = 0); + inline static void apply(SpMat& out, const SpOp& in, const typename arma_not_cx::result* junk = nullptr); template - inline static void apply(SpMat& out, const SpOp& in, const typename arma_cx_only::result* junk = 0); + inline static void apply(SpMat& out, const SpOp& in, const typename arma_cx_only::result* junk = nullptr); }; diff --git a/src/armadillo_bits/spop_htrans_meat.hpp b/src/armadillo_bits/spop_htrans_meat.hpp index 771d3f99..624d3995 100644 --- a/src/armadillo_bits/spop_htrans_meat.hpp +++ b/src/armadillo_bits/spop_htrans_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_max_bones.hpp b/src/armadillo_bits/spop_max_bones.hpp index ee73b051..4c8b60aa 100644 --- a/src/armadillo_bits/spop_max_bones.hpp +++ b/src/armadillo_bits/spop_max_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,10 +31,10 @@ class spop_max // template - inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static typename T1::elem_type vector_max(const T1& X, const typename arma_not_cx::result* junk = 0); + inline static typename T1::elem_type vector_max(const T1& X, const typename arma_not_cx::result* junk = nullptr); template inline static typename arma_not_cx::result max(const SpBase& X); @@ -43,10 +45,10 @@ class spop_max // template - inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = nullptr); template - inline static typename T1::elem_type vector_max(const T1& X, const typename arma_cx_only::result* junk = 0); + inline static typename T1::elem_type vector_max(const T1& X, const typename arma_cx_only::result* junk = nullptr); template inline static typename arma_cx_only::result max(const SpBase& X); diff --git a/src/armadillo_bits/spop_max_meat.hpp b/src/armadillo_bits/spop_max_meat.hpp index 9944c0cb..8f40a0e0 100644 --- a/src/armadillo_bits/spop_max_meat.hpp +++ b/src/armadillo_bits/spop_max_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -71,8 +73,8 @@ spop_max::apply_proxy if(dim == 0) // find the maximum in each column { - Row value(p_n_cols, fill::zeros); - urowvec count(p_n_cols, fill::zeros); + Row value(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); while(it != it_end) { @@ -93,8 +95,8 @@ spop_max::apply_proxy else if(dim == 1) // find the maximum in each row { - Col value(p_n_rows, fill::zeros); - ucolvec count(p_n_rows, fill::zeros); + Col value(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); while(it != it_end) { @@ -150,7 +152,7 @@ spop_max::vector_max } else { - return std::max(eT(0), op_max::direct_max(p.get_values(), p.get_n_nonzero())); + return (std::max)(eT(0), op_max::direct_max(p.get_values(), p.get_n_nonzero())); } } else @@ -175,7 +177,7 @@ spop_max::vector_max } else { - return std::max(eT(0), result); + return (std::max)(eT(0), result); } } } @@ -213,9 +215,9 @@ spop_max::max(const SpBase& X) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { - if ((*it) > max_val) { max_val = *it; } + if((*it) > max_val) { max_val = *it; } ++it; } @@ -234,7 +236,7 @@ spop_max::max(const SpBase& X) } else { - return std::max(eT(0), max_val); + return (std::max)(eT(0), max_val); } } @@ -272,9 +274,9 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { - if ((*it) > max_val) + if((*it) > max_val) { max_val = *it; index_of_max_val = it.row() + it.col() * n_rows; @@ -291,14 +293,14 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) // Convert to actual position in matrix. const uword row = P.get_row_indices()[index_of_max_val]; uword col = 0; - while (P.get_col_ptrs()[++col] <= index_of_max_val) { } + while(P.get_col_ptrs()[++col] <= index_of_max_val) { } index_of_max_val = (col - 1) * n_rows + row; } if(n_elem != n_nonzero) { - max_val = std::max(eT(0), max_val); + max_val = (std::max)(eT(0), max_val); // If the max_val is a nonzero element, we need its actual position in the matrix. if(max_val == eT(0)) @@ -312,25 +314,25 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { // Have we moved more than one position from the last place? - if ((it.col() == last_col) && (it.row() - last_row > 1)) + if((it.col() == last_col) && (it.row() - last_row > 1)) { index_of_max_val = it.col() * n_rows + last_row + 1; break; } - else if ((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) { index_of_max_val = last_col * n_rows + last_row + 1; break; } - else if ((it.col() == last_col + 1) && (it.row() > 0)) + else if((it.col() == last_col + 1) && (it.row() > 0)) { index_of_max_val = it.col() * n_rows; break; } - else if (it.col() > last_col + 1) + else if(it.col() > last_col + 1) { index_of_max_val = (last_col + 1) * n_rows; break; @@ -373,8 +375,8 @@ spop_max::apply_proxy if(dim == 0) // find the maximum in each column { - Row rawval(p_n_cols, fill::zeros); - Row< T> absval(p_n_cols, fill::zeros); + Row rawval(p_n_cols, arma_zeros_indicator()); + Row< T> absval(p_n_cols, arma_zeros_indicator()); while(it != it_end) { @@ -397,8 +399,8 @@ spop_max::apply_proxy else if(dim == 1) // find the maximum in each row { - Col rawval(p_n_rows, fill::zeros); - Col< T> absval(p_n_rows, fill::zeros); + Col rawval(p_n_rows, arma_zeros_indicator()); + Col< T> absval(p_n_rows, arma_zeros_indicator()); while(it != it_end) { @@ -536,11 +538,11 @@ spop_max::max(const SpBase& X) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { const T tmp_val = std::abs(*it); - if (tmp_val > max_val) + if(tmp_val > max_val) { max_val = tmp_val; ret_val = *it; @@ -603,11 +605,11 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { const T tmp_val = std::abs(*it); - if (tmp_val > max_val) + if(tmp_val > max_val) { max_val = tmp_val; index_of_max_val = it.row() + it.col() * n_rows; @@ -624,14 +626,14 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) // Convert to actual position in matrix. const uword row = P.get_row_indices()[index_of_max_val]; uword col = 0; - while (P.get_col_ptrs()[++col] <= index_of_max_val) { } + while(P.get_col_ptrs()[++col] <= index_of_max_val) { } index_of_max_val = (col - 1) * n_rows + row; } if(n_elem != n_nonzero) { - max_val = std::max(T(0), max_val); + max_val = (std::max)(T(0), max_val); // If the max_val is a nonzero element, we need its actual position in the matrix. if(max_val == T(0)) @@ -645,25 +647,25 @@ spop_max::max_with_index(const SpProxy& P, uword& index_of_max_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { // Have we moved more than one position from the last place? - if ((it.col() == last_col) && (it.row() - last_row > 1)) + if((it.col() == last_col) && (it.row() - last_row > 1)) { index_of_max_val = it.col() * n_rows + last_row + 1; break; } - else if ((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) { index_of_max_val = last_col * n_rows + last_row + 1; break; } - else if ((it.col() == last_col + 1) && (it.row() > 0)) + else if((it.col() == last_col + 1) && (it.row() > 0)) { index_of_max_val = it.col() * n_rows; break; } - else if (it.col() > last_col + 1) + else if(it.col() > last_col + 1) { index_of_max_val = (last_col + 1) * n_rows; break; diff --git a/src/armadillo_bits/spop_mean_bones.hpp b/src/armadillo_bits/spop_mean_bones.hpp index 2667868e..3d3e1021 100644 --- a/src/armadillo_bits/spop_mean_bones.hpp +++ b/src/armadillo_bits/spop_mean_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_mean_meat.hpp b/src/armadillo_bits/spop_mean_meat.hpp index 9f693bd1..dd979166 100644 --- a/src/armadillo_bits/spop_mean_meat.hpp +++ b/src/armadillo_bits/spop_mean_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -77,7 +79,7 @@ spop_mean::apply_noalias_fast if(dim == 0) // find the mean in each column { - Row acc(p_n_cols, fill::zeros); + Row acc(p_n_cols, arma_zeros_indicator()); eT* acc_mem = acc.memptr(); @@ -108,7 +110,7 @@ spop_mean::apply_noalias_fast else if(dim == 1) // find the mean in each row { - Col acc(p_n_rows, fill::zeros); + Col acc(p_n_rows, arma_zeros_indicator()); eT* acc_mem = acc.memptr(); @@ -123,7 +125,7 @@ spop_mean::apply_noalias_fast out = acc; } - if(out.is_finite() == false) + if(out.internal_has_nonfinite()) { spop_mean::apply_noalias_slow(out, p, dim); } @@ -331,7 +333,7 @@ spop_mean::iterator_mean(T1& it, const T1& end, const uword n_zero, const eT jun const uword it_begin_pos = it.pos(); - while (it != end) + while(it != end) { acc += (*it); ++it; @@ -360,7 +362,7 @@ spop_mean::iterator_mean_robust(T1& it, const T1& end, const uword n_zero, const const uword it_begin_pos = it.pos(); - while (it != end) + while(it != end) { r_mean += ((*it - r_mean) / T(n_zero + (it.pos() - it_begin_pos) + 1)); ++it; diff --git a/src/armadillo_bits/spop_min_bones.hpp b/src/armadillo_bits/spop_min_bones.hpp index cda3d2f4..1fd8ed3d 100644 --- a/src/armadillo_bits/spop_min_bones.hpp +++ b/src/armadillo_bits/spop_min_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -29,10 +31,10 @@ class spop_min // template - inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = 0); + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_not_cx::result* junk = nullptr); template - inline static typename T1::elem_type vector_min(const T1& X, const typename arma_not_cx::result* junk = 0); + inline static typename T1::elem_type vector_min(const T1& X, const typename arma_not_cx::result* junk = nullptr); template inline static typename arma_not_cx::result min(const SpBase& X); @@ -43,10 +45,10 @@ class spop_min // template - inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = 0); + inline static void apply_proxy(SpMat& out, const SpProxy& p, const uword dim, const typename arma_cx_only::result* junk = nullptr); template - inline static typename T1::elem_type vector_min(const T1& X, const typename arma_cx_only::result* junk = 0); + inline static typename T1::elem_type vector_min(const T1& X, const typename arma_cx_only::result* junk = nullptr); template inline static typename arma_cx_only::result min(const SpBase& X); diff --git a/src/armadillo_bits/spop_min_meat.hpp b/src/armadillo_bits/spop_min_meat.hpp index 7da09b12..b47c33ca 100644 --- a/src/armadillo_bits/spop_min_meat.hpp +++ b/src/armadillo_bits/spop_min_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -71,8 +73,8 @@ spop_min::apply_proxy if(dim == 0) // find the minimum in each column { - Row value(p_n_cols, fill::zeros); - urowvec count(p_n_cols, fill::zeros); + Row value(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); while(it != it_end) { @@ -93,8 +95,8 @@ spop_min::apply_proxy else if(dim == 1) // find the minimum in each row { - Col value(p_n_rows, fill::zeros); - ucolvec count(p_n_rows, fill::zeros); + Col value(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); while(it != it_end) { @@ -150,7 +152,7 @@ spop_min::vector_min } else { - return std::min(eT(0), op_min::direct_min(p.get_values(), p.get_n_nonzero())); + return (std::min)(eT(0), op_min::direct_min(p.get_values(), p.get_n_nonzero())); } } else @@ -175,7 +177,7 @@ spop_min::vector_min } else { - return std::min(eT(0), result); + return (std::min)(eT(0), result); } } } @@ -213,9 +215,9 @@ spop_min::min(const SpBase& X) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { - if ((*it) < min_val) { min_val = *it; } + if((*it) < min_val) { min_val = *it; } ++it; } @@ -234,7 +236,7 @@ spop_min::min(const SpBase& X) } else { - return std::min(eT(0), min_val); + return (std::min)(eT(0), min_val); } } @@ -272,9 +274,9 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { - if ((*it) < min_val) + if((*it) < min_val) { min_val = *it; index_of_min_val = it.row() + it.col() * n_rows; @@ -291,14 +293,14 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) // Convert to actual position in matrix. const uword row = P.get_row_indices()[index_of_min_val]; uword col = 0; - while (P.get_col_ptrs()[++col] < index_of_min_val + 1) { } + while(P.get_col_ptrs()[++col] < index_of_min_val + 1) { } index_of_min_val = (col - 1) * n_rows + row; } if(n_elem != n_nonzero) { - min_val = std::min(eT(0), min_val); + min_val = (std::min)(eT(0), min_val); // If the min_val is a nonzero element, we need its actual position in the matrix. if(min_val == eT(0)) @@ -312,25 +314,25 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { // Have we moved more than one position from the last place? - if ((it.col() == last_col) && (it.row() - last_row > 1)) + if((it.col() == last_col) && (it.row() - last_row > 1)) { index_of_min_val = it.col() * n_rows + last_row + 1; break; } - else if ((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) { index_of_min_val = last_col * n_rows + last_row + 1; break; } - else if ((it.col() == last_col + 1) && (it.row() > 0)) + else if((it.col() == last_col + 1) && (it.row() > 0)) { index_of_min_val = it.col() * n_rows; break; } - else if (it.col() > last_col + 1) + else if(it.col() > last_col + 1) { index_of_min_val = (last_col + 1) * n_rows; break; @@ -373,9 +375,9 @@ spop_min::apply_proxy if(dim == 0) // find the minimum in each column { - Row rawval(p_n_cols, fill::zeros); - Row< T> absval(p_n_cols, fill::zeros); - urowvec count(p_n_cols, fill::zeros); + Row rawval(p_n_cols, arma_zeros_indicator()); + Row< T> absval(p_n_cols, arma_zeros_indicator()); + urowvec count(p_n_cols, arma_zeros_indicator()); while(it != it_end) { @@ -415,9 +417,9 @@ spop_min::apply_proxy else if(dim == 1) // find the minimum in each row { - Col rawval(p_n_rows, fill::zeros); - Col< T> absval(p_n_rows, fill::zeros); - ucolvec count(p_n_rows, fill::zeros); + Col rawval(p_n_rows, arma_zeros_indicator()); + Col< T> absval(p_n_rows, arma_zeros_indicator()); + ucolvec count(p_n_rows, arma_zeros_indicator()); while(it != it_end) { @@ -572,11 +574,11 @@ spop_min::min(const SpBase& X) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { const T tmp_val = std::abs(*it); - if (tmp_val < min_val) + if(tmp_val < min_val) { min_val = tmp_val; ret_val = *it; @@ -639,11 +641,11 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { const T tmp_val = std::abs(*it); - if (tmp_val < min_val) + if(tmp_val < min_val) { min_val = tmp_val; index_of_min_val = it.row() + it.col() * n_rows; @@ -660,14 +662,14 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) // Convert to actual position in matrix. const uword row = P.get_row_indices()[index_of_min_val]; uword col = 0; - while (P.get_col_ptrs()[++col] < index_of_min_val + 1) { } + while(P.get_col_ptrs()[++col] < index_of_min_val + 1) { } index_of_min_val = (col - 1) * n_rows + row; } if(n_elem != n_nonzero) { - min_val = std::min(T(0), min_val); + min_val = (std::min)(T(0), min_val); // If the min_val is a nonzero element, we need its actual position in the matrix. if(min_val == T(0)) @@ -681,25 +683,25 @@ spop_min::min_with_index(const SpProxy& P, uword& index_of_min_val) it_type it = P.begin(); it_type it_end = P.end(); - while (it != it_end) + while(it != it_end) { // Have we moved more than one position from the last place? - if ((it.col() == last_col) && (it.row() - last_row > 1)) + if((it.col() == last_col) && (it.row() - last_row > 1)) { index_of_min_val = it.col() * n_rows + last_row + 1; break; } - else if ((it.col() >= last_col + 1) && (last_row < n_rows - 1)) + else if((it.col() >= last_col + 1) && (last_row < n_rows - 1)) { index_of_min_val = last_col * n_rows + last_row + 1; break; } - else if ((it.col() == last_col + 1) && (it.row() > 0)) + else if((it.col() == last_col + 1) && (it.row() > 0)) { index_of_min_val = it.col() * n_rows; break; } - else if (it.col() > last_col + 1) + else if(it.col() > last_col + 1) { index_of_min_val = (last_col + 1) * n_rows; break; diff --git a/src/armadillo_bits/spop_misc_bones.hpp b/src/armadillo_bits/spop_misc_bones.hpp index d63540ff..42117f9b 100644 --- a/src/armadillo_bits/spop_misc_bones.hpp +++ b/src/armadillo_bits/spop_misc_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_misc_meat.hpp b/src/armadillo_bits/spop_misc_meat.hpp index ab4611db..1ef51cc4 100644 --- a/src/armadillo_bits/spop_misc_meat.hpp +++ b/src/armadillo_bits/spop_misc_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -326,8 +328,8 @@ spop_repelem::apply(SpMat& out, const SpOp 0) && (out_n_cols > 0) && (out_nnz > 0) ) { - umat locs(2, out_nnz); - Col vals( out_nnz); + umat locs(2, out_nnz, arma_nozeros_indicator()); + Col vals( out_nnz, arma_nozeros_indicator()); uword* locs_mem = locs.memptr(); eT* vals_mem = vals.memptr(); @@ -524,7 +526,7 @@ spop_diagvec::apply(SpMat& out, const SpOp 0) ? a : 0; const uword col_offset = (b == 0) ? a : 0; - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= X.n_rows)) || ((col_offset > 0) && (col_offset >= X.n_cols)), "diagvec(): requested diagonal out of bounds" @@ -532,7 +534,7 @@ spop_diagvec::apply(SpMat& out, const SpOp cache(len); + Col cache(len, arma_nozeros_indicator()); eT* cache_mem = cache.memptr(); uword n_nonzero = 0; diff --git a/src/armadillo_bits/spop_norm_bones.hpp b/src/armadillo_bits/spop_norm_bones.hpp new file mode 100644 index 00000000..1d944518 --- /dev/null +++ b/src/armadillo_bits/spop_norm_bones.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_norm +//! @{ + + +class spop_norm + : public traits_op_default + { + public: + + template inline static typename get_pod_type::result mat_norm_1(const SpMat& X); + + template inline static typename get_pod_type::result mat_norm_2(const SpMat& X, const typename arma_real_only::result* junk = nullptr); + template inline static typename get_pod_type::result mat_norm_2(const SpMat& X, const typename arma_cx_only::result* junk = nullptr); + + template inline static typename get_pod_type::result mat_norm_inf(const SpMat& X); + + template inline static typename get_pod_type::result vec_norm_k(const eT* mem, const uword N, const uword k); + }; + + +//! @} diff --git a/src/armadillo_bits/spop_norm_meat.hpp b/src/armadillo_bits/spop_norm_meat.hpp new file mode 100644 index 00000000..402ea290 --- /dev/null +++ b/src/armadillo_bits/spop_norm_meat.hpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup op_norm +//! @{ + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_1(const SpMat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 0), 1) ); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_2(const SpMat& X, const typename arma_real_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + // norm = sqrt( largest eigenvalue of (A^H)*A ), where ^H is the conjugate transpose + // http://math.stackexchange.com/questions/4368/computing-the-largest-eigenvalue-of-a-very-large-sparse-matrix + + typedef typename get_pod_type::result T; + + const SpMat& A = X; + const SpMat B = trans(A); + + const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); + + Col eigval; + eigs_sym(eigval, C, 1); + + return (eigval.n_elem > 0) ? T(std::sqrt(eigval[0])) : T(0); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_2(const SpMat& X, const typename arma_cx_only::result* junk) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + typedef typename get_pod_type::result T; + + // we're calling eigs_gen(), which currently requires ARPACK + #if !defined(ARMA_USE_ARPACK) + { + arma_stop_logic_error("norm(): use of ARPACK must be enabled for norm of complex matrices"); + return T(0); + } + #endif + + const SpMat& A = X; + const SpMat B = trans(A); + + const SpMat C = (A.n_rows <= A.n_cols) ? (A*B) : (B*A); + + Col eigval; + eigs_gen(eigval, C, 1); + + return (eigval.n_elem > 0) ? T(std::sqrt(std::real(eigval[0]))) : T(0); + } + + + +template +inline +typename get_pod_type::result +spop_norm::mat_norm_inf(const SpMat& X) + { + arma_extra_debug_sigprint(); + + // TODO: this can be sped up with a dedicated implementation + return as_scalar( max( sum(abs(X), 1), 0) ); + } + + + +template +inline +typename get_pod_type::result +spop_norm::vec_norm_k(const eT* mem, const uword N, const uword k) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (k == 0), "norm(): unsupported vector norm type" ); + + // create a fake dense vector to allow reuse of code for dense vectors + Col fake_vector( access::rwp(mem), N, false ); + + const Proxy< Col > P_fake_vector(fake_vector); + + if(k == uword(1)) { return op_norm::vec_norm_1(P_fake_vector); } + if(k == uword(2)) { return op_norm::vec_norm_2(P_fake_vector); } + + return op_norm::vec_norm_k(P_fake_vector, int(k)); + } + + + +//! @} diff --git a/src/armadillo_bits/spop_normalise_bones.hpp b/src/armadillo_bits/spop_normalise_bones.hpp index bf98d91d..839b9ca7 100644 --- a/src/armadillo_bits/spop_normalise_bones.hpp +++ b/src/armadillo_bits/spop_normalise_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,7 +29,7 @@ class spop_normalise inline static void apply(SpMat& out, const SpOp& expr); template - inline static void apply_direct(SpMat& out, const SpMat& X, const uword p, const uword dim); + inline static void apply_direct(SpMat& out, const SpMat& X, const uword p); }; diff --git a/src/armadillo_bits/spop_normalise_meat.hpp b/src/armadillo_bits/spop_normalise_meat.hpp index 84bcef42..96a1759d 100644 --- a/src/armadillo_bits/spop_normalise_meat.hpp +++ b/src/armadillo_bits/spop_normalise_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,15 +28,38 @@ spop_normalise::apply(SpMat& out, const SpOp 1), "normalise(): parameter 'dim' must be 0 or 1" ); + arma_debug_check( (p == 0), "normalise(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "normalise(): parameter 'dim' must be 0 or 1" ); const unwrap_spmat U(expr.m); - spop_normalise::apply_direct(out, U.M, p, dim); + const SpMat& X = U.M; + + X.sync(); + + if( X.is_empty() || (X.n_nonzero == 0) ) { out.zeros(X.n_rows, X.n_cols); return; } + + if(dim == 0) + { + spop_normalise::apply_direct(out, X, p); + } + else + if(dim == 1) + { + SpMat tmp1; + SpMat tmp2; + + spop_strans::apply_noalias(tmp1, X); + + spop_normalise::apply_direct(tmp2, tmp1, p); + + spop_strans::apply_noalias(out, tmp2); + } } @@ -42,162 +67,65 @@ spop_normalise::apply(SpMat& out, const SpOp inline void -spop_normalise::apply_direct(SpMat& out, const SpMat& X, const uword p, const uword dim) +spop_normalise::apply_direct(SpMat& out, const SpMat& X, const uword p) { arma_extra_debug_sigprint(); typedef typename get_pod_type::result T; - X.sync(); + SpMat tmp(arma_reserve_indicator(), X.n_rows, X.n_cols, X.n_nonzero); - if( X.is_empty() || (X.n_nonzero == 0) ) { return; } + bool has_zero = false; - if(dim == 0) + podarray norm_vals(X.n_cols); + + T* norm_vals_mem = norm_vals.memptr(); + + for(uword col=0; col < X.n_cols; ++col) { - podarray norm_vals(X.n_cols); - - T* norm_vals_mem = norm_vals.memptr(); - - for(uword i=0; i < norm_vals.n_elem; ++i) - { - const uword col_offset = X.col_ptrs[i ]; - const uword next_col_offset = X.col_ptrs[i + 1]; - - const eT* start_ptr = &X.values[ col_offset]; - const eT* end_ptr = &X.values[next_col_offset]; - - const uword n_elem = end_ptr - start_ptr; - - const Col fake_vec(const_cast(start_ptr), n_elem, false, false); - - const T norm_val = norm(fake_vec, p); - - norm_vals_mem[i] = (norm_val != T(0)) ? norm_val : T(1); - } + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; - const uword N = X.n_nonzero; + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; - umat locs(2, N); - Col vals( N); + const uword n_elem = end_ptr - start_ptr; - uword* locs_mem = locs.memptr(); - eT* vals_mem = vals.memptr(); + const Col fake_vec(const_cast(start_ptr), n_elem, false, false); - typename SpMat::const_iterator it = X.begin(); + const T norm_val = norm(fake_vec, p); - uword new_n_nonzero = 0; - - for(uword i=0; i < N; ++i) - { - const uword row = it.row(); - const uword col = it.col(); - - const eT val = (*it) / norm_vals_mem[col]; - - if(val != eT(0)) - { - (*vals_mem) = val; vals_mem++; - - (*locs_mem) = row; locs_mem++; - (*locs_mem) = col; locs_mem++; - - new_n_nonzero++; - } - - ++it; - } - - const umat tmp_locs(locs.memptr(), 2, new_n_nonzero, false, false); - const Col tmp_vals(vals.memptr(), new_n_nonzero, false, false); - - SpMat tmp(tmp_locs, tmp_vals, X.n_rows, X.n_cols, false, false); - - out.steal_mem(tmp); + norm_vals_mem[col] = (norm_val != T(0)) ? norm_val : T(1); } - else - if(dim == 1) + + const uword N = X.n_nonzero; + + typename SpMat::const_iterator it = X.begin(); + + for(uword i=0; i < N; ++i) { - podarray< T> norm_vals(X.n_rows); - podarray row_vals(X.n_cols); // worst case scenario - - T* norm_vals_mem = norm_vals.memptr(); - eT* row_vals_mem = row_vals.memptr(); - - for(uword i=0; i < norm_vals.n_elem; ++i) - { - // typename SpMat::const_row_iterator row_it = X.begin_row(i); - // typename SpMat::const_row_iterator row_it_end = X.end_row(i); - // - // uword count = 0; - // - // for(; row_it != row_it_end; ++row_it) - // { - // row_vals_mem[count] = (*row_it); - // ++count; - // } - - - // using the .at() accessor, as it's faster than const_row_iterator for accessing a single row - - uword count = 0; - - for(uword col=0; col < X.n_cols; ++col) - { - const eT val = X.at(i,col); - - if(val != eT(0)) - { - row_vals_mem[count] = val; - ++count; - } - } - - const Row fake_vec(row_vals_mem, count, false, false); - - const T norm_val = norm(fake_vec, p); - - norm_vals_mem[i] = (norm_val != T(0)) ? norm_val : T(1); - } + const uword row = it.row(); + const uword col = it.col(); - const uword N = X.n_nonzero; + const eT val = (*it) / norm_vals_mem[col]; - umat locs(2, N); - Col vals( N); + if(val == eT(0)) { has_zero = true; } - uword* locs_mem = locs.memptr(); - eT* vals_mem = vals.memptr(); + access::rw(tmp.values[i]) = val; + access::rw(tmp.row_indices[i]) = row; + access::rw(tmp.col_ptrs[col + 1])++; - typename SpMat::const_iterator it = X.begin(); - - uword new_n_nonzero = 0; - - for(uword i=0; i < N; ++i) - { - const uword row = it.row(); - const uword col = it.col(); - - const eT val = (*it) / norm_vals_mem[row]; - - if(val != eT(0)) - { - (*vals_mem) = val; vals_mem++; - - (*locs_mem) = row; locs_mem++; - (*locs_mem) = col; locs_mem++; - - new_n_nonzero++; - } - - ++it; - } - - const umat tmp_locs(locs.memptr(), 2, new_n_nonzero, false, false); - const Col tmp_vals(vals.memptr(), new_n_nonzero, false, false); - - SpMat tmp(tmp_locs, tmp_vals, X.n_rows, X.n_cols, false, false); - - out.steal_mem(tmp); + ++it; } + + for(uword c=0; c < tmp.n_cols; ++c) + { + access::rw(tmp.col_ptrs[c + 1]) += tmp.col_ptrs[c]; + } + + if(has_zero) { tmp.remove_zeros(); } + + out.steal_mem(tmp); } diff --git a/src/armadillo_bits/spop_repmat_bones.hpp b/src/armadillo_bits/spop_repmat_bones.hpp index ad956743..7ee68432 100644 --- a/src/armadillo_bits/spop_repmat_bones.hpp +++ b/src/armadillo_bits/spop_repmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_repmat_meat.hpp b/src/armadillo_bits/spop_repmat_meat.hpp index 8b1ca516..4c09a3ea 100644 --- a/src/armadillo_bits/spop_repmat_meat.hpp +++ b/src/armadillo_bits/spop_repmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -126,8 +128,8 @@ spop_repmat::apply_noalias(SpMat& out, const uword A_n_rows, const uword A_n // // if( (out_n_rows > 0) && (out_n_cols > 0) && (out_nnz > 0) ) // { -// umat locs(2, out_nnz); -// Col vals( out_nnz); +// umat locs(2, out_nnz, arma_nozeros_indicator()); +// Col vals( out_nnz, arma_nozeros_indicator()); // // uword* locs_mem = locs.memptr(); // eT* vals_mem = vals.memptr(); diff --git a/src/armadillo_bits/spop_reverse_bones.hpp b/src/armadillo_bits/spop_reverse_bones.hpp index 4b55e3f2..3ea80b6e 100644 --- a/src/armadillo_bits/spop_reverse_bones.hpp +++ b/src/armadillo_bits/spop_reverse_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_reverse_meat.hpp b/src/armadillo_bits/spop_reverse_meat.hpp index ca366e59..2ba6e45c 100644 --- a/src/armadillo_bits/spop_reverse_meat.hpp +++ b/src/armadillo_bits/spop_reverse_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -40,7 +42,7 @@ spop_reverse::apply_spmat(SpMat& out, const SpMat& X, const uword dim) return; } - umat locs(2, N); + umat locs(2, N, arma_nozeros_indicator()); uword* locs_mem = locs.memptr(); @@ -108,8 +110,8 @@ spop_reverse::apply_proxy(SpMat& out, const T1& X, const return; } - umat locs(2, N); - Col vals( N); + umat locs(2, N, arma_nozeros_indicator()); + Col vals( N, arma_nozeros_indicator()); uword* locs_mem = locs.memptr(); eT* vals_mem = vals.memptr(); diff --git a/src/armadillo_bits/spop_strans_bones.hpp b/src/armadillo_bits/spop_strans_bones.hpp index 9655e861..7d52ccae 100644 --- a/src/armadillo_bits/spop_strans_bones.hpp +++ b/src/armadillo_bits/spop_strans_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -27,9 +29,9 @@ class spop_strans template struct traits { - static const bool is_row = T1::is_col; // deliberately swapped - static const bool is_col = T1::is_row; - static const bool is_xvec = T1::is_xvec; + static constexpr bool is_row = T1::is_col; // deliberately swapped + static constexpr bool is_col = T1::is_row; + static constexpr bool is_xvec = T1::is_xvec; }; template diff --git a/src/armadillo_bits/spop_strans_meat.hpp b/src/armadillo_bits/spop_strans_meat.hpp index c84b56e6..4cb83c9c 100644 --- a/src/armadillo_bits/spop_strans_meat.hpp +++ b/src/armadillo_bits/spop_strans_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_sum_bones.hpp b/src/armadillo_bits/spop_sum_bones.hpp index 1fa7f833..2e4e5582 100644 --- a/src/armadillo_bits/spop_sum_bones.hpp +++ b/src/armadillo_bits/spop_sum_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -24,7 +26,7 @@ class spop_sum public: template - arma_hot inline static void apply(SpMat& out, const SpOp& in); + inline static void apply(SpMat& out, const SpOp& in); }; diff --git a/src/armadillo_bits/spop_sum_meat.hpp b/src/armadillo_bits/spop_sum_meat.hpp index 6f652858..63badfd9 100644 --- a/src/armadillo_bits/spop_sum_meat.hpp +++ b/src/armadillo_bits/spop_sum_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ template -arma_hot inline void spop_sum::apply(SpMat& out, const SpOp& in) @@ -47,7 +48,7 @@ spop_sum::apply(SpMat& out, const SpOp& in) if(dim == 0) // find the sum in each column { - Row acc(p_n_cols, fill::zeros); + Row acc(p_n_cols, arma_zeros_indicator()); eT* acc_mem = acc.memptr(); @@ -80,7 +81,7 @@ spop_sum::apply(SpMat& out, const SpOp& in) else if(dim == 1) // find the sum in each row { - Col acc(p_n_rows, fill::zeros); + Col acc(p_n_rows, arma_zeros_indicator()); eT* acc_mem = acc.memptr(); diff --git a/src/armadillo_bits/spop_symmat_bones.hpp b/src/armadillo_bits/spop_symmat_bones.hpp index 0d57e15d..cf130f29 100644 --- a/src/armadillo_bits/spop_symmat_bones.hpp +++ b/src/armadillo_bits/spop_symmat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_symmat_meat.hpp b/src/armadillo_bits/spop_symmat_meat.hpp index f36827cc..2ce7cbaa 100644 --- a/src/armadillo_bits/spop_symmat_meat.hpp +++ b/src/armadillo_bits/spop_symmat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_trimat_bones.hpp b/src/armadillo_bits/spop_trimat_bones.hpp index 82579987..5b3aeccc 100644 --- a/src/armadillo_bits/spop_trimat_bones.hpp +++ b/src/armadillo_bits/spop_trimat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -33,4 +35,32 @@ class spop_trimat +class spop_trimatu_ext + : public traits_op_default + { + public: + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + +class spop_trimatl_ext + : public traits_op_default + { + public: + + template + inline static void apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset); + + template + inline static void apply(SpMat& out, const SpOp& in); + }; + + + //! @} diff --git a/src/armadillo_bits/spop_trimat_meat.hpp b/src/armadillo_bits/spop_trimat_meat.hpp index 86439ed1..548b8e8a 100644 --- a/src/armadillo_bits/spop_trimat_meat.hpp +++ b/src/armadillo_bits/spop_trimat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -139,4 +141,226 @@ spop_trimat::apply(SpMat& out, const SpOp +inline +void +spop_trimatu_ext::apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatu(): requested diagonal is out of bounds" ); + + if(A.n_nonzero == 0) { out.zeros(n_rows, n_cols); return; } + + out.reserve(n_rows, n_cols, A.n_nonzero); // upper bound on n_nonzero + + uword count = 0; + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < n_cols; ++i) + { + const uword col = i + col_offset; + + if(i < N) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + const uword end_row = i + row_offset; + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + + if(it_row <= end_row) + { + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + else + { + break; + } + } + } + else + { + if(col < n_cols) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(out.col_ptrs[i + 1]) += out.col_ptrs[i]; + } + + if(count < A.n_nonzero) { out.mem_resize(count); } + } + + + +template +inline +void +spop_trimatu_ext::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& A = U.M; + + arma_debug_check( (A.is_square() == false), "trimatu(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + if(U.is_alias(out)) + { + SpMat tmp; + spop_trimatu_ext::apply_noalias(tmp, A, row_offset, col_offset); + out.steal_mem(tmp); + } + else + { + spop_trimatu_ext::apply_noalias(out, A, row_offset, col_offset); + } + } + + + +// + + + +template +inline +void +spop_trimatl_ext::apply_noalias(SpMat& out, const SpMat& A, const uword row_offset, const uword col_offset) + { + arma_extra_debug_sigprint(); + + const uword n_rows = A.n_rows; + const uword n_cols = A.n_cols; + + arma_debug_check_bounds( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "trimatl(): requested diagonal is out of bounds" ); + + if(A.n_nonzero == 0) { out.zeros(n_rows, n_cols); return; } + + out.reserve(n_rows, n_cols, A.n_nonzero); // upper bound on n_nonzero + + uword count = 0; + + if(col_offset > 0) + { + typename SpMat::const_col_iterator it = A.begin_col_no_sync(0); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col_offset-1); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + + const uword N = (std::min)(n_rows - row_offset, n_cols - col_offset); + + for(uword i=0; i < N; ++i) + { + const uword start_row = i + row_offset; + const uword col = i + col_offset; + + typename SpMat::const_col_iterator it = A.begin_col_no_sync(col); + typename SpMat::const_col_iterator it_end = A.end_col_no_sync(col); + + for(; it != it_end; ++it) + { + const uword it_row = it.row(); + + if(it_row >= start_row) + { + const uword it_col = it.col(); + + access::rw(out.values[count]) = (*it); + access::rw(out.row_indices[count]) = it_row; + access::rw(out.col_ptrs[it_col + 1])++; + ++count; + } + } + } + + for(uword i=0; i < n_cols; ++i) + { + access::rw(out.col_ptrs[i + 1]) += out.col_ptrs[i]; + } + + if(count < A.n_nonzero) { out.mem_resize(count); } + } + + + +template +inline +void +spop_trimatl_ext::apply(SpMat& out, const SpOp& in) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + + const unwrap_spmat U(in.m); + const SpMat& A = U.M; + + arma_debug_check( (A.is_square() == false), "trimatl(): given matrix must be square sized" ); + + const uword row_offset = in.aux_uword_a; + const uword col_offset = in.aux_uword_b; + + if(U.is_alias(out)) + { + SpMat tmp; + spop_trimatl_ext::apply_noalias(tmp, A, row_offset, col_offset); + out.steal_mem(tmp); + } + else + { + spop_trimatl_ext::apply_noalias(out, A, row_offset, col_offset); + } + } + + + //! @} diff --git a/src/armadillo_bits/spop_var_bones.hpp b/src/armadillo_bits/spop_var_bones.hpp index b24faa1f..09f0e243 100644 --- a/src/armadillo_bits/spop_var_bones.hpp +++ b/src/armadillo_bits/spop_var_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -48,11 +50,11 @@ class spop_var // Calculate the variance using iterators, for non-complex numbers. template - inline static eT iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_not_cx::result* junk2 = 0); + inline static eT iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_not_cx::result* junk2 = nullptr); // Calculate the variance using iterators, for complex numbers. template - inline static typename get_pod_type::result iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_cx_only::result* junk2 = 0); + inline static typename get_pod_type::result iterator_var(T1& it, const T1& end, const uword n_zero, const uword norm_type, const eT junk1, const typename arma_cx_only::result* junk2 = nullptr); }; diff --git a/src/armadillo_bits/spop_var_meat.hpp b/src/armadillo_bits/spop_var_meat.hpp index f418f2fa..8d01a445 100644 --- a/src/armadillo_bits/spop_var_meat.hpp +++ b/src/armadillo_bits/spop_var_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -251,7 +253,7 @@ spop_var::direct_var T acc2 = T(0); eT acc3 = eT(0); - for (uword i = 0; i < length; ++i) + for(uword i = 0; i < length; ++i) { const eT tmp = acc1 - X[i]; @@ -315,7 +317,7 @@ spop_var::iterator_var const uword it_begin_pos = it.pos(); - while (it != end) + while(it != end) { const eT tmp = mean - (*it); @@ -326,12 +328,12 @@ spop_var::iterator_var } const uword n_nonzero = (it.pos() - it_begin_pos); - if (n_nonzero == 0) + if(n_nonzero == 0) { return eT(0); } - if (n_nonzero + n_zero == 1) + if(n_nonzero + n_zero == 1) { return eT(0); // only one element } @@ -376,7 +378,7 @@ spop_var::iterator_var const uword it_begin_pos = it.pos(); - while (it != end) + while(it != end) { eT tmp = mean - (*it); @@ -387,12 +389,12 @@ spop_var::iterator_var } const uword n_nonzero = (it.pos() - it_begin_pos); - if (n_nonzero == 0) + if(n_nonzero == 0) { return T(0); } - if (n_nonzero + n_zero == 1) + if(n_nonzero + n_zero == 1) { return T(0); // only one element } diff --git a/src/armadillo_bits/spop_vecnorm_bones.hpp b/src/armadillo_bits/spop_vecnorm_bones.hpp new file mode 100644 index 00000000..eaeecd88 --- /dev/null +++ b/src/armadillo_bits/spop_vecnorm_bones.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_vecnorm +//! @{ + + +class spop_vecnorm + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& expr); + + template + inline static void apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword k); + }; + + +// + + +class spop_vecnorm_ext + : public traits_op_xvec + { + public: + + template + inline static void apply(SpMat& out, const mtSpOp& expr); + + template + inline static void apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword method_id); + }; + + +//! @} diff --git a/src/armadillo_bits/spop_vecnorm_meat.hpp b/src/armadillo_bits/spop_vecnorm_meat.hpp new file mode 100644 index 00000000..56ef7d91 --- /dev/null +++ b/src/armadillo_bits/spop_vecnorm_meat.hpp @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spop_vecnorm +//! @{ + + + +template +inline +void +spop_vecnorm::apply(SpMat& out, const mtSpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword k = expr.aux_uword_a; + const uword dim = expr.aux_uword_b; + + arma_debug_check( (k == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.m); + const SpMat& X = U.M; + + X.sync(); + + if(dim == 0) + { + Mat tmp; + + spop_vecnorm::apply_direct(tmp, X, k); + + out = tmp; + } + else + if(dim == 1) + { + Mat< T> tmp; + SpMat Xt; + + spop_strans::apply_noalias(Xt, X); + + spop_vecnorm::apply_direct(tmp, Xt, k); + + out = tmp.t(); + } + } + + + +template +inline +void +spop_vecnorm::apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword k) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(1, X.n_cols); + + T* out_mem = out.memptr(); + + for(uword col=0; col < X.n_cols; ++col) + { + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; + + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; + + const uword n_elem = end_ptr - start_ptr; + + T out_val = T(0); + + if(n_elem > 0) + { + const Col tmp(const_cast(start_ptr), n_elem, false, false); + + const Proxy< Col > P(tmp); + + if(k == uword(1)) { out_val = op_norm::vec_norm_1(P); } + if(k == uword(2)) { out_val = op_norm::vec_norm_2(P); } + } + + out_mem[col] = out_val; + } + } + + + +// + + + +template +inline +void +spop_vecnorm_ext::apply(SpMat& out, const mtSpOp& expr) + { + arma_extra_debug_sigprint(); + + typedef typename T1::elem_type eT; + typedef typename T1::pod_type T; + + const uword method_id = expr.aux_uword_a; + const uword dim = expr.aux_uword_b; + + arma_debug_check( (method_id == 0), "vecnorm(): unsupported vector norm type" ); + arma_debug_check( (dim > 1), "vecnorm(): parameter 'dim' must be 0 or 1" ); + + const unwrap_spmat U(expr.m); + const SpMat& X = U.M; + + X.sync(); + + if(dim == 0) + { + Mat tmp; + + spop_vecnorm_ext::apply_direct(tmp, X, method_id); + + out = tmp; + } + else + if(dim == 1) + { + Mat< T> tmp; + SpMat Xt; + + spop_strans::apply_noalias(Xt, X); + + spop_vecnorm_ext::apply_direct(tmp, Xt, method_id); + + out = tmp.t(); + } + } + + + +template +inline +void +spop_vecnorm_ext::apply_direct(Mat< typename get_pod_type::result >& out, const SpMat& X, const uword method_id) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + out.zeros(1, X.n_cols); + + T* out_mem = out.memptr(); + + for(uword col=0; col < X.n_cols; ++col) + { + const uword col_offset = X.col_ptrs[col ]; + const uword next_col_offset = X.col_ptrs[col + 1]; + + const eT* start_ptr = &X.values[ col_offset]; + const eT* end_ptr = &X.values[next_col_offset]; + + const uword n_elem = end_ptr - start_ptr; + + T out_val = T(0); + + if(n_elem > 0) + { + const Col tmp(const_cast(start_ptr), n_elem, false, false); + + const Proxy< Col > P(tmp); + + if(method_id == uword(1)) + { + out_val = op_norm::vec_norm_max(P); + } + else + if(method_id == uword(2)) + { + const T tmp_val = op_norm::vec_norm_min(P); + + out_val = (n_elem < X.n_rows) ? T((std::min)(T(0), tmp_val)) : T(tmp_val); + } + } + + out_mem[col] = out_val; + } + } + + + +//! @} diff --git a/src/armadillo_bits/spop_vectorise_bones.hpp b/src/armadillo_bits/spop_vectorise_bones.hpp index 7ea5e4ae..3e38b256 100644 --- a/src/armadillo_bits/spop_vectorise_bones.hpp +++ b/src/armadillo_bits/spop_vectorise_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spop_vectorise_meat.hpp b/src/armadillo_bits/spop_vectorise_meat.hpp index 2f05b40d..56733901 100644 --- a/src/armadillo_bits/spop_vectorise_meat.hpp +++ b/src/armadillo_bits/spop_vectorise_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/spsolve_factoriser_bones.hpp b/src/armadillo_bits/spsolve_factoriser_bones.hpp new file mode 100644 index 00000000..4616e26e --- /dev/null +++ b/src/armadillo_bits/spsolve_factoriser_bones.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spsolve_factoriser +//! @{ + + + +class spsolve_factoriser + { + private: + + void_ptr worker_ptr = nullptr; + uword elem_type_indicator = 0; + uword n_rows = 0; + double rcond_value = double(0); + + template inline void delete_worker(); + + inline void cleanup(); + + + public: + + inline ~spsolve_factoriser(); + inline spsolve_factoriser(); + + inline void reset(); + + inline double rcond() const; + + template inline bool factorise(const SpBase& A_expr, const spsolve_opts_base& settings = spsolve_opts_none(), const typename arma_blas_type_only::result* junk = nullptr); + + template inline bool solve(Mat& X, const Base& B_expr, const typename arma_blas_type_only::result* junk = nullptr); + + inline spsolve_factoriser(const spsolve_factoriser&) = delete; + inline void operator= (const spsolve_factoriser&) = delete; + }; + + + +//! @} diff --git a/src/armadillo_bits/spsolve_factoriser_meat.hpp b/src/armadillo_bits/spsolve_factoriser_meat.hpp new file mode 100644 index 00000000..4450a15c --- /dev/null +++ b/src/armadillo_bits/spsolve_factoriser_meat.hpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup spsolve_factoriser +//! @{ + + + +template +inline +void +spsolve_factoriser::delete_worker() + { + arma_extra_debug_sigprint(); + + if(worker_ptr != nullptr) + { + worker_type* ptr = reinterpret_cast(worker_ptr); + + delete ptr; + + worker_ptr = nullptr; + } + } + + + +inline +void +spsolve_factoriser::cleanup() + { + arma_extra_debug_sigprint(); + + #if defined(ARMA_USE_SUPERLU) + { + if(elem_type_indicator == 1) { delete_worker< superlu_worker< float> >(); } + else if(elem_type_indicator == 2) { delete_worker< superlu_worker< double> >(); } + else if(elem_type_indicator == 3) { delete_worker< superlu_worker< cx_float> >(); } + else if(elem_type_indicator == 4) { delete_worker< superlu_worker >(); } + } + #endif + + worker_ptr = nullptr; + elem_type_indicator = 0; + n_rows = 0; + rcond_value = double(0); + } + + + +inline +spsolve_factoriser::~spsolve_factoriser() + { + arma_extra_debug_sigprint_this(this); + + cleanup(); + } + + + +inline +spsolve_factoriser::spsolve_factoriser() + { + arma_extra_debug_sigprint_this(this); + } + + + +inline +void +spsolve_factoriser::reset() + { + arma_extra_debug_sigprint(); + + cleanup(); + } + + + +inline +double +spsolve_factoriser::rcond() const + { + arma_extra_debug_sigprint(); + + return rcond_value; + } + + + +template +inline +bool +spsolve_factoriser::factorise + ( + const SpBase& A_expr, + const spsolve_opts_base& settings, + const typename arma_blas_type_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::elem_type eT; + typedef typename get_pod_type::result T; + + typedef superlu_worker worker_type; + + // + + cleanup(); + + // + + const unwrap_spmat U(A_expr.get_ref()); + const SpMat& A = U.M; + + if(A.is_square() == false) + { + arma_debug_warn_level(1, "spsolve_factoriser::factorise(): solving under-determined / over-determined systems is currently not supported"); + return false; + } + + n_rows = A.n_rows; + + // + + superlu_opts superlu_opts_default; + + const superlu_opts& opts = (settings.id == 1) ? static_cast(settings) : superlu_opts_default; + + if( (opts.pivot_thresh < double(0)) || (opts.pivot_thresh > double(1)) ) + { + arma_debug_warn_level(1, "spsolve_factoriser::factorise(): pivot_thresh must be in the [0,1] interval" ); + return false; + } + + // + + worker_ptr = new(std::nothrow) worker_type; + + if(worker_ptr == nullptr) + { + arma_debug_warn_level(3, "spsolve_factoriser::factorise(): could not construct worker object"); + return false; + } + + // + + if( is_float::value) { elem_type_indicator = 1; } + else if( is_double::value) { elem_type_indicator = 2; } + else if( is_cx_float::value) { elem_type_indicator = 3; } + else if(is_cx_double::value) { elem_type_indicator = 4; } + + // + + worker_type* local_worker_ptr = reinterpret_cast(worker_ptr); + worker_type& local_worker_ref = (*local_worker_ptr); + + // + + T local_rcond_value = T(0); + + const bool status = local_worker_ref.factorise(local_rcond_value, A, opts); + + rcond_value = double(local_rcond_value); + + if( (status == false) || arma_isnan(local_rcond_value) || ((opts.allow_ugly == false) && (local_rcond_value < std::numeric_limits::epsilon())) ) + { + arma_debug_warn_level(3, "spsolve_factoriser::factorise(): factorisation failed; rcond: ", local_rcond_value); + delete_worker(); + return false; + } + + return true; + } + #else + { + arma_ignore(A_expr); + arma_ignore(settings); + arma_stop_logic_error("spsolve_factoriser::factorise(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +template +inline +bool +spsolve_factoriser::solve + ( + Mat& X, + const Base& B_expr, + const typename arma_blas_type_only::result* junk + ) + { + arma_extra_debug_sigprint(); + arma_ignore(junk); + + #if defined(ARMA_USE_SUPERLU) + { + typedef typename T1::elem_type eT; + + typedef superlu_worker worker_type; + + if(worker_ptr == nullptr) + { + arma_debug_warn_level(2, "spsolve_factoriser::solve(): no factorisation available"); + X.soft_reset(); + return false; + } + + bool type_mismatch = false; + + if( (is_float::value) && (elem_type_indicator != 1) ) { type_mismatch = true; } + else if( (is_double::value) && (elem_type_indicator != 2) ) { type_mismatch = true; } + else if( (is_cx_float::value) && (elem_type_indicator != 3) ) { type_mismatch = true; } + else if((is_cx_double::value) && (elem_type_indicator != 4) ) { type_mismatch = true; } + + if(type_mismatch) + { + arma_debug_warn_level(1, "spsolve_factoriser::solve(): matrix type mismatch"); + X.soft_reset(); + return false; + } + + const quasi_unwrap U(B_expr.get_ref()); + const Mat& B = U.M; + + if(n_rows != B.n_rows) + { + arma_debug_warn_level(1, "spsolve_factoriser::solve(): matrix size mismatch"); + X.soft_reset(); + return false; + } + + const bool is_alias = U.is_alias(X); + + Mat tmp; + Mat& out = is_alias ? tmp : X; + + worker_type* local_worker_ptr = reinterpret_cast(worker_ptr); + worker_type& local_worker_ref = (*local_worker_ptr); + + const bool status = local_worker_ref.solve(out,B); + + if(is_alias) { X.steal_mem(tmp); } + + if(status == false) + { + arma_debug_warn_level(3, "spsolve_factoriser::solve(): solution not found"); + X.soft_reset(); + return false; + } + + return true; + } + #else + { + arma_ignore(X); + arma_ignore(B_expr); + arma_stop_logic_error("spsolve_factoriser::solve(): use of SuperLU must be enabled"); + return false; + } + #endif + } + + + +//! @} diff --git a/src/armadillo_bits/strip.hpp b/src/armadillo_bits/strip.hpp index c2936089..73850c58 100644 --- a/src/armadillo_bits/strip.hpp +++ b/src/armadillo_bits/strip.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -31,7 +33,7 @@ struct strip_diagmat arma_extra_debug_sigprint(); } - static const bool do_diagmat = false; + static constexpr bool do_diagmat = false; const T1& M; }; @@ -50,7 +52,7 @@ struct strip_diagmat< Op > arma_extra_debug_sigprint(); } - static const bool do_diagmat = true; + static constexpr bool do_diagmat = true; const T1& M; }; @@ -71,19 +73,19 @@ struct strip_inv const T1& M; - static const bool do_inv = false; - static const bool do_inv_sympd = false; + static constexpr bool do_inv_gen = false; + static constexpr bool do_inv_spd = false; }; template -struct strip_inv< Op > +struct strip_inv< Op > { typedef T1 stored_type; inline - strip_inv(const Op& X) + strip_inv(const Op& X) : M(X.m) { arma_extra_debug_sigprint(); @@ -91,19 +93,19 @@ struct strip_inv< Op > const T1& M; - static const bool do_inv = true; - static const bool do_inv_sympd = false; + static constexpr bool do_inv_gen = true; + static constexpr bool do_inv_spd = false; }; template -struct strip_inv< Op > +struct strip_inv< Op > { typedef T1 stored_type; inline - strip_inv(const Op& X) + strip_inv(const Op& X) : M(X.m) { arma_extra_debug_sigprint(); @@ -111,8 +113,8 @@ struct strip_inv< Op > const T1& M; - static const bool do_inv = true; - static const bool do_inv_sympd = true; + static constexpr bool do_inv_gen = false; + static constexpr bool do_inv_spd = true; }; @@ -124,9 +126,9 @@ struct strip_trimat const T1& M; - static const bool do_trimat = false; - static const bool do_triu = false; - static const bool do_tril = false; + static constexpr bool do_trimat = false; + static constexpr bool do_triu = false; + static constexpr bool do_tril = false; inline strip_trimat(const T1& X) @@ -145,9 +147,10 @@ struct strip_trimat< Op > const T1& M; - static const bool do_trimat = true; - const bool do_triu; - const bool do_tril; + static constexpr bool do_trimat = true; + + const bool do_triu; + const bool do_tril; inline strip_trimat(const Op& X) @@ -161,4 +164,68 @@ struct strip_trimat< Op > +// + + + +template +struct sp_strip_trans + { + typedef T1 stored_type; + + inline + sp_strip_trans(const T1& X) + : M(X) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = false; + static constexpr bool do_strans = false; + + const T1& M; + }; + + + +template +struct sp_strip_trans< SpOp > + { + typedef T1 stored_type; + + inline + sp_strip_trans(const SpOp& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = true; + static constexpr bool do_strans = false; + + const T1& M; + }; + + + +template +struct sp_strip_trans< SpOp > + { + typedef T1 stored_type; + + inline + sp_strip_trans(const SpOp& X) + : M(X.m) + { + arma_extra_debug_sigprint(); + } + + static constexpr bool do_htrans = false; + static constexpr bool do_strans = true; + + const T1& M; + }; + + + //! @} diff --git a/src/armadillo_bits/subview_bones.hpp b/src/armadillo_bits/subview_bones.hpp index 1a21335c..95553a10 100644 --- a/src/armadillo_bits/subview_bones.hpp +++ b/src/armadillo_bits/subview_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,9 +21,9 @@ //! Class for storing data required to construct or apply operations to a submatrix -//! (i.e. where the submatrix starts and ends as well as a reference/pointer to the original matrix), +//! (ie. where the submatrix starts and ends as well as a reference/pointer to the original matrix), template -class subview : public Base > +class subview : public Base< eT, subview > { public: @@ -30,9 +32,9 @@ class subview : public Base > arma_aligned const Mat& m; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; const uword aux_row1; const uword aux_col1; @@ -49,6 +51,10 @@ class subview : public Base > public: inline ~subview(); + inline subview() = delete; + + inline subview(const subview& in); + inline subview( subview&& in); template inline void inplace_op(const eT val ); template inline void inplace_op(const Base& x, const char* identifier); @@ -83,6 +89,9 @@ class subview : public Base > template inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + inline void operator=(const std::initializer_list& list); + inline void operator=(const std::initializer_list< std::initializer_list >& list); + inline static void extract(Mat& out, const subview& in); @@ -101,6 +110,8 @@ class subview : public Base > inline void clean(const pod_type threshold); + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); @@ -108,19 +119,25 @@ class subview : public Base > inline void randu(); inline void randn(); - inline eT at_alt (const uword ii) const; + arma_warn_unused inline eT at_alt (const uword ii) const; - inline eT& operator[](const uword ii); - inline eT operator[](const uword ii) const; + arma_warn_unused inline eT& operator[](const uword ii); + arma_warn_unused inline eT operator[](const uword ii) const; - inline eT& operator()(const uword ii); - inline eT operator()(const uword ii) const; + arma_warn_unused inline eT& operator()(const uword ii); + arma_warn_unused inline eT operator()(const uword ii) const; - inline eT& operator()(const uword in_row, const uword in_col); - inline eT operator()(const uword in_row, const uword in_col) const; + arma_warn_unused inline eT& operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; - inline eT& at(const uword in_row, const uword in_col); - inline eT at(const uword in_row, const uword in_col) const; + arma_warn_unused inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + arma_warn_unused inline eT& front(); + arma_warn_unused inline eT front() const; + + arma_warn_unused inline eT& back(); + arma_warn_unused inline eT back() const; arma_inline eT* colptr(const uword in_col); arma_inline const eT* colptr(const uword in_col) const; @@ -128,11 +145,13 @@ class subview : public Base > template inline bool check_overlap(const subview& x) const; - inline arma_warn_unused bool is_vec() const; - inline arma_warn_unused bool is_finite() const; + arma_warn_unused inline bool is_vec() const; + arma_warn_unused inline bool is_finite() const; + arma_warn_unused inline bool is_zero(const pod_type tol = 0) const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; inline subview_row row(const uword row_num); inline const subview_row row(const uword row_num) const; @@ -170,13 +189,11 @@ class subview : public Base > template inline subview_each2< subview, 0, T1 > each_col(const Base& indices); template inline subview_each2< subview, 1, T1 > each_row(const Base& indices); - #if defined(ARMA_USE_CXX11) inline void each_col(const std::function< void( Col&) >& F); inline void each_col(const std::function< void(const Col&) >& F) const; inline void each_row(const std::function< void( Row&) >& F); inline void each_row(const std::function< void(const Row&) >& F) const; - #endif inline diagview diag(const sword in_id = 0); inline const diagview diag(const sword in_id = 0) const; @@ -195,15 +212,15 @@ class subview : public Base > inline iterator(const iterator& X); inline iterator(subview& in_sv, const uword in_row, const uword in_col); - inline arma_warn_unused eT& operator*(); + arma_warn_unused inline eT& operator*(); - inline iterator& operator++(); - inline arma_warn_unused iterator operator++(int); + inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); - inline arma_warn_unused bool operator==(const iterator& rhs) const; - inline arma_warn_unused bool operator!=(const iterator& rhs) const; - inline arma_warn_unused bool operator==(const const_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_iterator& rhs) const; + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; typedef std::forward_iterator_tag iterator_category; typedef eT value_type; @@ -230,15 +247,15 @@ class subview : public Base > inline const_iterator(const const_iterator& X); inline const_iterator(const subview& in_sv, const uword in_row, const uword in_col); - inline arma_warn_unused const eT& operator*(); + arma_warn_unused inline const eT& operator*(); - inline const_iterator& operator++(); - inline arma_warn_unused const_iterator operator++(int); + inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); - inline arma_warn_unused bool operator==(const iterator& rhs) const; - inline arma_warn_unused bool operator!=(const iterator& rhs) const; - inline arma_warn_unused bool operator==(const const_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_iterator& rhs) const; + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; // So that we satisfy the STL iterator types. typedef std::forward_iterator_tag iterator_category; @@ -267,15 +284,15 @@ class subview : public Base > inline row_iterator(const row_iterator& X); inline row_iterator(subview& in_sv, const uword in_row, const uword in_col); - inline arma_warn_unused eT& operator* (); + arma_warn_unused inline eT& operator* (); - inline row_iterator& operator++(); - inline arma_warn_unused row_iterator operator++(int); + inline row_iterator& operator++(); + arma_warn_unused inline row_iterator operator++(int); - inline arma_warn_unused bool operator!=(const row_iterator& X) const; - inline arma_warn_unused bool operator==(const row_iterator& X) const; - inline arma_warn_unused bool operator!=(const const_row_iterator& X) const; - inline arma_warn_unused bool operator==(const const_row_iterator& X) const; + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; typedef std::forward_iterator_tag iterator_category; typedef eT value_type; @@ -284,7 +301,6 @@ class subview : public Base > typedef eT& reference; arma_aligned Mat* M; - arma_aligned eT* current_ptr; arma_aligned uword current_row; arma_aligned uword current_col; @@ -302,15 +318,15 @@ class subview : public Base > inline const_row_iterator(const const_row_iterator& X); inline const_row_iterator(const subview& in_sv, const uword in_row, const uword in_col); - inline arma_warn_unused const eT& operator*() const; + arma_warn_unused inline const eT& operator*() const; - inline const_row_iterator& operator++(); - inline arma_warn_unused const_row_iterator operator++(int); + inline const_row_iterator& operator++(); + arma_warn_unused inline const_row_iterator operator++(int); - inline arma_warn_unused bool operator!=(const row_iterator& X) const; - inline arma_warn_unused bool operator==(const row_iterator& X) const; - inline arma_warn_unused bool operator!=(const const_row_iterator& X) const; - inline arma_warn_unused bool operator==(const const_row_iterator& X) const; + arma_warn_unused inline bool operator!=(const row_iterator& X) const; + arma_warn_unused inline bool operator==(const row_iterator& X) const; + arma_warn_unused inline bool operator!=(const const_row_iterator& X) const; + arma_warn_unused inline bool operator==(const const_row_iterator& X) const; typedef std::forward_iterator_tag iterator_category; typedef eT value_type; @@ -319,7 +335,6 @@ class subview : public Base > typedef const eT& reference; arma_aligned const Mat* M; - arma_aligned const eT* current_ptr; arma_aligned uword current_row; arma_aligned uword current_col; @@ -338,10 +353,7 @@ class subview : public Base > inline const_iterator cend() const; - private: - friend class Mat; - subview(); }; @@ -354,27 +366,28 @@ class subview_col : public subview typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; const eT* colmem; inline void operator= (const subview& x); inline void operator= (const subview_col& x); inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); - template - inline void operator= (const Base& x); + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); template inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); - arma_inline const Op,op_htrans> t() const; - arma_inline const Op,op_htrans> ht() const; - arma_inline const Op,op_strans> st() const; + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; - arma_inline const Op,op_strans> as_row() const; + arma_warn_unused arma_inline const Op,op_strans> as_row() const; inline void fill(const eT val); inline void zeros(); @@ -412,29 +425,90 @@ class subview_col : public subview inline subview_col tail(const uword N); inline const subview_col tail(const uword N) const; - inline arma_warn_unused eT min() const; - inline arma_warn_unused eT max() const; + arma_warn_unused inline eT min() const; + arma_warn_unused inline eT max() const; inline eT min(uword& index_of_min_val) const; inline eT max(uword& index_of_max_val) const; - inline arma_warn_unused uword index_min() const; - inline arma_warn_unused uword index_max() const; + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; + + inline subview_col(const subview_col& in); + inline subview_col( subview_col&& in); protected: inline subview_col(const Mat& in_m, const uword in_col); inline subview_col(const Mat& in_m, const uword in_col, const uword in_row1, const uword in_n_rows); + inline subview_col() = delete; - private: - friend class Mat; friend class Col; friend class subview; + }; + + + +template +class subview_cols : public subview + { + public: + + typedef eT elem_type; + typedef typename get_pod_type::result pod_type; + + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; + + inline subview_cols(const subview_cols& in); + inline subview_cols( subview_cols&& in); + + inline void operator= (const subview& x); + inline void operator= (const subview_cols& x); + inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); + inline void operator= (const std::initializer_list< std::initializer_list >& list); + + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); + + template + inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); + + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; + + arma_warn_unused arma_inline const Op,op_vectorise_col> as_col() const; + + arma_warn_unused inline eT at_alt (const uword ii) const; + + arma_warn_unused inline eT& operator[](const uword ii); + arma_warn_unused inline eT operator[](const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword ii); + arma_warn_unused inline eT operator()(const uword ii) const; + + arma_warn_unused inline eT& operator()(const uword in_row, const uword in_col); + arma_warn_unused inline eT operator()(const uword in_row, const uword in_col) const; + + arma_warn_unused inline eT& at(const uword in_row, const uword in_col); + arma_warn_unused inline eT at(const uword in_row, const uword in_col) const; + + arma_inline eT* colptr(const uword in_col); + arma_inline const eT* colptr(const uword in_col) const; + + protected: + + inline subview_cols(const Mat& in_m, const uword in_col1, const uword in_n_cols); + inline subview_cols() = delete; - subview_col(); + friend class Mat; + friend class subview; }; @@ -447,25 +521,26 @@ class subview_row : public subview typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = true; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = true; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; inline void operator= (const subview& x); inline void operator= (const subview_row& x); inline void operator= (const eT val); + inline void operator= (const std::initializer_list& list); - template - inline void operator= (const Base& x); + template inline void operator= (const Base& x); + template inline void operator= (const SpBase& x); template inline typename enable_if2< is_same_type::value, void>::result operator=(const Gen& x); - arma_inline const Op,op_htrans> t() const; - arma_inline const Op,op_htrans> ht() const; - arma_inline const Op,op_strans> st() const; + arma_warn_unused arma_inline const Op,op_htrans> t() const; + arma_warn_unused arma_inline const Op,op_htrans> ht() const; + arma_warn_unused arma_inline const Op,op_strans> st() const; - arma_inline const Op,op_strans> as_col() const; + arma_warn_unused arma_inline const Op,op_strans> as_col() const; inline eT at_alt (const uword i) const; @@ -496,8 +571,8 @@ class subview_row : public subview inline subview_row tail(const uword N); inline const subview_row tail(const uword N) const; - inline arma_warn_unused uword index_min() const; - inline arma_warn_unused uword index_max() const; + arma_warn_unused inline uword index_min() const; + arma_warn_unused inline uword index_max() const; inline typename subview::row_iterator begin(); inline typename subview::const_row_iterator begin() const; @@ -507,40 +582,42 @@ class subview_row : public subview inline typename subview::const_row_iterator end() const; inline typename subview::const_row_iterator cend() const; + inline subview_row(const subview_row& in); + inline subview_row( subview_row&& in); + + protected: inline subview_row(const Mat& in_m, const uword in_row); inline subview_row(const Mat& in_m, const uword in_row, const uword in_col1, const uword in_n_cols); + inline subview_row() = delete; - private: - friend class Mat; friend class Row; friend class subview; - - subview_row(); }; template -class subview_row_strans : public Base > +class subview_row_strans : public Base< eT, subview_row_strans > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row& sv_row; - const uword n_rows; // equal to n_elem - const uword n_elem; - static const uword n_cols = 1; + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; inline explicit subview_row_strans(const subview_row& in_sv_row); @@ -559,22 +636,23 @@ class subview_row_strans : public Base > template -class subview_row_htrans : public Base > +class subview_row_htrans : public Base< eT, subview_row_htrans > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const subview_row& sv_row; - const uword n_rows; // equal to n_elem - const uword n_elem; - static const uword n_cols = 1; + const uword n_rows; // equal to n_elem + const uword n_elem; + + static constexpr uword n_cols = 1; inline explicit subview_row_htrans(const subview_row& in_sv_row); diff --git a/src/armadillo_bits/subview_cube_bones.hpp b/src/armadillo_bits/subview_cube_bones.hpp index f47b7e10..ae71e677 100644 --- a/src/armadillo_bits/subview_cube_bones.hpp +++ b/src/armadillo_bits/subview_cube_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,9 +21,9 @@ //! Class for storing data required to construct or apply operations to a subcube -//! (i.e. where the subcube starts and ends as well as a reference/pointer to the original cube), +//! (ie. where the subcube starts and ends as well as a reference/pointer to the original cube), template -class subview_cube : public BaseCube > +class subview_cube : public BaseCube< eT, subview_cube > { public: @@ -49,6 +51,14 @@ class subview_cube : public BaseCube > public: inline ~subview_cube(); + inline subview_cube() = delete; + + inline subview_cube(const subview_cube& in); + inline subview_cube( subview_cube&& in); + + template inline void inplace_op(const eT val ); + template inline void inplace_op(const BaseCube& x, const char* identifier); + template inline void inplace_op(const subview_cube& x, const char* identifier); inline void operator= (const eT val); inline void operator+= (const eT val); @@ -95,25 +105,27 @@ class subview_cube : public BaseCube > template inline void transform(functor F); template inline void imbue(functor F); - #if defined(ARMA_USE_CXX11) inline void each_slice(const std::function< void( Mat&) >& F); inline void each_slice(const std::function< void(const Mat&) >& F) const; - #endif inline void replace(const eT old_val, const eT new_val); inline void clean(const pod_type threshold); + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); inline void randu(); inline void randn(); - inline arma_warn_unused bool is_finite() const; + arma_warn_unused inline bool is_finite() const; + arma_warn_unused inline bool is_zero(const pod_type tol = 0) const; - inline arma_warn_unused bool has_inf() const; - inline arma_warn_unused bool has_nan() const; + arma_warn_unused inline bool has_inf() const; + arma_warn_unused inline bool has_nan() const; + arma_warn_unused inline bool has_nonfinite() const; inline eT at_alt (const uword i) const; @@ -132,8 +144,10 @@ class subview_cube : public BaseCube > arma_inline eT* slice_colptr(const uword in_slice, const uword in_col); arma_inline const eT* slice_colptr(const uword in_slice, const uword in_col) const; - inline bool check_overlap(const subview_cube& x) const; - inline bool check_overlap(const Mat& x) const; + template + inline bool check_overlap(const subview_cube& x) const; + + inline bool check_overlap(const Mat& x) const; class const_iterator; @@ -146,15 +160,15 @@ class subview_cube : public BaseCube > inline iterator(const iterator& X); inline iterator(subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice); - inline arma_warn_unused eT& operator*(); + arma_warn_unused inline eT& operator*(); - inline iterator& operator++(); - inline arma_warn_unused iterator operator++(int); + inline iterator& operator++(); + arma_warn_unused inline iterator operator++(int); - inline arma_warn_unused bool operator==(const iterator& rhs) const; - inline arma_warn_unused bool operator!=(const iterator& rhs) const; - inline arma_warn_unused bool operator==(const const_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_iterator& rhs) const; + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; typedef std::forward_iterator_tag iterator_category; typedef eT value_type; @@ -185,15 +199,15 @@ class subview_cube : public BaseCube > inline const_iterator(const const_iterator& X); inline const_iterator(const subview_cube& in_sv, const uword in_row, const uword in_col, const uword in_slice); - inline arma_warn_unused const eT& operator*(); + arma_warn_unused inline const eT& operator*(); - inline const_iterator& operator++(); - inline arma_warn_unused const_iterator operator++(int); + inline const_iterator& operator++(); + arma_warn_unused inline const_iterator operator++(int); - inline arma_warn_unused bool operator==(const iterator& rhs) const; - inline arma_warn_unused bool operator!=(const iterator& rhs) const; - inline arma_warn_unused bool operator==(const const_iterator& rhs) const; - inline arma_warn_unused bool operator!=(const const_iterator& rhs) const; + arma_warn_unused inline bool operator==(const iterator& rhs) const; + arma_warn_unused inline bool operator!=(const iterator& rhs) const; + arma_warn_unused inline bool operator==(const const_iterator& rhs) const; + arma_warn_unused inline bool operator!=(const const_iterator& rhs) const; // So that we satisfy the STL iterator types. typedef std::forward_iterator_tag iterator_category; @@ -225,12 +239,8 @@ class subview_cube : public BaseCube > inline const_iterator cend() const; - private: - friend class Mat; friend class Cube; - - subview_cube(); }; diff --git a/src/armadillo_bits/subview_cube_each_bones.hpp b/src/armadillo_bits/subview_cube_each_bones.hpp index 96551727..29d81fd2 100644 --- a/src/armadillo_bits/subview_cube_each_bones.hpp +++ b/src/armadillo_bits/subview_cube_each_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,19 +28,17 @@ class subview_cube_each_common const Cube& P; - inline void check_size(const Mat& A) const; + template + inline void check_size(const Mat& A) const; protected: arma_inline subview_cube_each_common(const Cube& in_p); + inline subview_cube_each_common() = delete; - arma_cold inline const std::string incompat_size_string(const Mat& A) const; - - - private: - - subview_cube_each_common(); + template + arma_cold inline const std::string incompat_size_string(const Mat& A) const; }; @@ -50,6 +50,7 @@ class subview_cube_each1 : public subview_cube_each_common protected: arma_inline subview_cube_each1(const Cube& in_p); + inline subview_cube_each1() = delete; public: @@ -65,8 +66,6 @@ class subview_cube_each1 : public subview_cube_each_common template inline void operator*= (const Base& x); - private: - friend class Cube; }; @@ -78,6 +77,7 @@ class subview_cube_each2 : public subview_cube_each_common protected: inline subview_cube_each2(const Cube& in_p, const Base& in_indices); + inline subview_cube_each2() = delete; public: @@ -95,8 +95,6 @@ class subview_cube_each2 : public subview_cube_each_common template inline void operator/= (const Base& x); - private: - friend class Cube; }; @@ -108,7 +106,7 @@ class subview_cube_each1_aux template static inline Cube operator_plus(const subview_cube_each1& X, const Base& Y); - + template static inline Cube operator_minus(const subview_cube_each1& X, const Base& Y); @@ -139,7 +137,7 @@ class subview_cube_each2_aux template static inline Cube operator_plus(const subview_cube_each2& X, const Base& Y); - + template static inline Cube operator_minus(const subview_cube_each2& X, const Base& Y); diff --git a/src/armadillo_bits/subview_cube_each_meat.hpp b/src/armadillo_bits/subview_cube_each_meat.hpp index e819ebe6..a3069182 100644 --- a/src/armadillo_bits/subview_cube_each_meat.hpp +++ b/src/armadillo_bits/subview_cube_each_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -33,9 +35,10 @@ subview_cube_each_common::subview_cube_each_common(const Cube& in_p) template +template inline void -subview_cube_each_common::check_size(const Mat& A) const +subview_cube_each_common::check_size(const Mat& A) const { if(arma_config::debug) { @@ -49,10 +52,10 @@ subview_cube_each_common::check_size(const Mat& A) const template -arma_cold +template inline const std::string -subview_cube_each_common::incompat_size_string(const Mat& A) const +subview_cube_each_common::incompat_size_string(const Mat& A) const { std::ostringstream tmp; @@ -132,7 +135,7 @@ subview_cube_each1::operator+= (const Base& in) const uword p_n_elem_slice = p.n_elem_slice; const eT* A_mem = A.memptr(); - + for(uword i=0; i < p_n_slices; ++i) { arrayops::inplace_plus( p.slice_memptr(i), A_mem, p_n_elem_slice ); } } @@ -285,7 +288,7 @@ subview_cube_each2::operator= (const Base& in) const uword p_n_slices = p.n_slices; const uword p_n_elem_slice = p.n_elem_slice; - + const uword* indices_mem = U.M.memptr(); const uword N = U.M.n_elem; @@ -295,7 +298,7 @@ subview_cube_each2::operator= (const Base& in) { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::copy(p.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -324,7 +327,7 @@ subview_cube_each2::operator+= (const Base& in) const uword p_n_slices = p.n_slices; const uword p_n_elem_slice = p.n_elem_slice; - + const uword* indices_mem = U.M.memptr(); const uword N = U.M.n_elem; @@ -334,7 +337,7 @@ subview_cube_each2::operator+= (const Base& in) { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_plus(p.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -363,7 +366,7 @@ subview_cube_each2::operator-= (const Base& in) const uword p_n_slices = p.n_slices; const uword p_n_elem_slice = p.n_elem_slice; - + const uword* indices_mem = U.M.memptr(); const uword N = U.M.n_elem; @@ -373,7 +376,7 @@ subview_cube_each2::operator-= (const Base& in) { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_minus(p.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -402,7 +405,7 @@ subview_cube_each2::operator%= (const Base& in) const uword p_n_slices = p.n_slices; const uword p_n_elem_slice = p.n_elem_slice; - + const uword* indices_mem = U.M.memptr(); const uword N = U.M.n_elem; @@ -412,7 +415,7 @@ subview_cube_each2::operator%= (const Base& in) { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_mul(p.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -441,7 +444,7 @@ subview_cube_each2::operator/= (const Base& in) const uword p_n_slices = p.n_slices; const uword p_n_elem_slice = p.n_elem_slice; - + const uword* indices_mem = U.M.memptr(); const uword N = U.M.n_elem; @@ -451,7 +454,7 @@ subview_cube_each2::operator/= (const Base& in) { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_div(p.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -482,7 +485,7 @@ subview_cube_each1_aux::operator_plus const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -519,7 +522,7 @@ subview_cube_each1_aux::operator_minus const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -556,7 +559,7 @@ subview_cube_each1_aux::operator_minus const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(X.get_ref()); const Mat& A = tmp.M; @@ -593,7 +596,7 @@ subview_cube_each1_aux::operator_schur const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -630,7 +633,7 @@ subview_cube_each1_aux::operator_div const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -667,7 +670,7 @@ subview_cube_each1_aux::operator_div const uword p_n_cols = p.n_cols; const uword p_n_slices = p.n_slices; - Cube out(p_n_rows, p_n_cols, p_n_slices); + Cube out(p_n_rows, p_n_cols, p_n_slices, arma_nozeros_indicator()); const unwrap tmp(X.get_ref()); const Mat& A = tmp.M; @@ -703,7 +706,7 @@ subview_cube_each1_aux::operator_times const unwrap tmp(Y.get_ref()); const Mat& M = tmp.M; - Cube out(C.n_rows, M.n_cols, C.n_slices); + Cube out(C.n_rows, M.n_cols, C.n_slices, arma_nozeros_indicator()); for(uword i=0; i < C.n_slices; ++i) { @@ -734,7 +737,7 @@ subview_cube_each1_aux::operator_times const Cube& C = Y.P; - Cube out(M.n_rows, C.n_cols, C.n_slices); + Cube out(M.n_rows, C.n_cols, C.n_slices, arma_nozeros_indicator()); for(uword i=0; i < C.n_slices; ++i) { @@ -790,7 +793,7 @@ subview_cube_each2_aux::operator_plus { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_plus(out.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -835,7 +838,7 @@ subview_cube_each2_aux::operator_minus { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_minus(out.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -879,7 +882,7 @@ subview_cube_each2_aux::operator_minus { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); Mat out_slice( out.slice_memptr(slice), p_n_rows, p_n_cols, false, true); const Mat p_slice(const_cast(p.slice_memptr(slice)), p_n_rows, p_n_cols, false, true); @@ -927,7 +930,7 @@ subview_cube_each2_aux::operator_schur { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_mul(out.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -972,7 +975,7 @@ subview_cube_each2_aux::operator_div { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); arrayops::inplace_div(out.slice_memptr(slice), A_mem, p_n_elem_slice); } @@ -1016,7 +1019,7 @@ subview_cube_each2_aux::operator_div { const uword slice = indices_mem[i]; - arma_debug_check( (slice >= p_n_slices), "each_slice(): index out of bounds" ); + arma_debug_check_bounds( (slice >= p_n_slices), "each_slice(): index out of bounds" ); Mat out_slice( out.slice_memptr(slice), p_n_rows, p_n_cols, false, true); const Mat p_slice(const_cast(p.slice_memptr(slice)), p_n_rows, p_n_cols, false, true); diff --git a/src/armadillo_bits/subview_cube_meat.hpp b/src/armadillo_bits/subview_cube_meat.hpp index bfbf2e49..039e80c4 100644 --- a/src/armadillo_bits/subview_cube_meat.hpp +++ b/src/armadillo_bits/subview_cube_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,7 +24,7 @@ template inline subview_cube::~subview_cube() { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); } @@ -49,68 +51,178 @@ subview_cube::subview_cube , n_slices (in_n_slices) , n_elem (n_elem_slice * in_n_slices) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); } template inline -void -subview_cube::operator= (const eT val) +subview_cube::subview_cube(const subview_cube& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , aux_slice1 (in.aux_slice1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem_slice(in.n_elem_slice) + , n_slices (in.n_slices ) + , n_elem (in.n_elem ) { - arma_extra_debug_sigprint(); - - if(n_elem != 1) - { - arma_debug_assert_same_size(n_rows, n_cols, n_slices, 1, 1, 1, "copy into subcube"); - } + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +subview_cube::subview_cube(subview_cube&& in) + : m (in.m ) + , aux_row1 (in.aux_row1 ) + , aux_col1 (in.aux_col1 ) + , aux_slice1 (in.aux_slice1 ) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem_slice(in.n_elem_slice) + , n_slices (in.n_slices ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); - Cube& Q = const_cast< Cube& >(m); + // for paranoia - Q.at(aux_row1, aux_col1, aux_slice1) = val; + access::rw(in.aux_row1 ) = 0; + access::rw(in.aux_col1 ) = 0; + access::rw(in.aux_slice1 ) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem_slice) = 0; + access::rw(in.n_slices ) = 0; + access::rw(in.n_elem ) = 0; } template +template inline void -subview_cube::operator+= (const eT val) +subview_cube::inplace_op(const eT val) { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; - const uword local_n_slices = n_slices; + subview_cube& t = *this; - for(uword slice = 0; slice < local_n_slices; ++slice) + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) { - for(uword col = 0; col < local_n_cols; ++col) - { - arrayops::inplace_plus( slice_colptr(slice,col), val, local_n_rows ); - } + if(is_same_type::yes) { arrayops::inplace_plus ( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( slice_colptr(s,c), val, t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( slice_colptr(s,c), val, t_n_rows ); } } } + + + template +template inline void -subview_cube::operator-= (const eT val) +subview_cube::inplace_op(const BaseCube& in, const char* identifier) { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; - const uword local_n_slices = n_slices; + const ProxyCube P(in.get_ref()); - for(uword slice = 0; slice < local_n_slices; ++slice) + subview_cube& t = *this; + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + arma_debug_assert_same_size(t, P, identifier); + + const bool use_mp = arma_config::openmp && ProxyCube::use_mp && mp_gate::eval(t.n_elem); + const bool has_overlap = P.has_overlap(t); + + if(has_overlap) { arma_extra_debug_print("aliasing or overlap detected"); } + + if( (is_Cube::stored_type>::value) || (use_mp) || (has_overlap) ) { - for(uword col = 0; col < local_n_cols; ++col) + const unwrap_cube_check::stored_type> tmp(P.Q, has_overlap); + const Cube& B = tmp.M; + + if( (is_same_type::yes) && (t.aux_row1 == 0) && (t_n_rows == t.m.n_rows) ) { - arrayops::inplace_minus( slice_colptr(slice,col), val, local_n_rows ); + for(uword s=0; s < t_n_slices; ++s) + { + arrayops::copy( t.slice_colptr(s,0), B.slice_colptr(s,0), t.n_elem_slice ); + } + } + else + { + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + if(is_same_type::yes) { arrayops::copy ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( t.slice_colptr(s,c), B.slice_colptr(s,c), t_n_rows ); } + } + } + } + else // use the Proxy + { + if(ProxyCube::use_at) + { + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + eT* t_col_data = t.slice_colptr(s,c); + + for(uword r=0; r < t_n_rows; ++r) + { + const eT tmp = P.at(r,c,s); + + if(is_same_type::yes) { (*t_col_data) = tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) += tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) -= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) *= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) /= tmp; t_col_data++; } + } + } + } + else + { + typename ProxyCube::ea_type Pea = P.get_ea(); + + uword count = 0; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) + { + eT* t_col_data = t.slice_colptr(s,c); + + for(uword r=0; r < t_n_rows; ++r) + { + const eT tmp = Pea[count]; count++; + + if(is_same_type::yes) { (*t_col_data) = tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) += tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) -= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) *= tmp; t_col_data++; } + if(is_same_type::yes) { (*t_col_data) /= tmp; t_col_data++; } + } + } } } } @@ -118,22 +230,42 @@ subview_cube::operator-= (const eT val) template +template inline void -subview_cube::operator*= (const eT val) +subview_cube::inplace_op(const subview_cube& x, const char* identifier) { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; - const uword local_n_slices = n_slices; + if(check_overlap(x)) + { + const Cube tmp(x); + + if(is_same_type::yes) { (*this).operator= (tmp); } + if(is_same_type::yes) { (*this).operator+=(tmp); } + if(is_same_type::yes) { (*this).operator-=(tmp); } + if(is_same_type::yes) { (*this).operator%=(tmp); } + if(is_same_type::yes) { (*this).operator/=(tmp); } + + return; + } - for(uword slice = 0; slice < local_n_slices; ++slice) + subview_cube& t = *this; + + arma_debug_assert_same_size(t, x, identifier); + + const uword t_n_rows = t.n_rows; + const uword t_n_cols = t.n_cols; + const uword t_n_slices = t.n_slices; + + for(uword s=0; s < t_n_slices; ++s) + for(uword c=0; c < t_n_cols; ++c) { - for(uword col = 0; col < local_n_cols; ++col) - { - arrayops::inplace_mul( slice_colptr(slice,col), val, local_n_rows ); - } + if(is_same_type::yes) { arrayops::copy ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_plus ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_minus( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_mul ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } + if(is_same_type::yes) { arrayops::inplace_div ( t.slice_colptr(s,c), x.slice_colptr(s,c), t_n_rows ); } } } @@ -142,51 +274,79 @@ subview_cube::operator*= (const eT val) template inline void -subview_cube::operator/= (const eT val) +subview_cube::operator= (const eT val) { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; - const uword local_n_slices = n_slices; - - for(uword slice = 0; slice < local_n_slices; ++slice) + if(n_elem != 1) { - for(uword col = 0; col < local_n_cols; ++col) - { - arrayops::inplace_div( slice_colptr(slice,col), val, local_n_rows ); - } + arma_debug_assert_same_size(n_rows, n_cols, n_slices, 1, 1, 1, "copy into subcube"); } + + Cube& Q = const_cast< Cube& >(m); + + Q.at(aux_row1, aux_col1, aux_slice1) = val; } template -template inline void -subview_cube::operator= (const BaseCube& in) +subview_cube::operator+= (const eT val) { arma_extra_debug_sigprint(); - const unwrap_cube tmp(in.get_ref()); + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator-= (const eT val) + { + arma_extra_debug_sigprint(); - const Cube& x = tmp.M; - subview_cube& t = *this; + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator*= (const eT val) + { + arma_extra_debug_sigprint(); - arma_debug_assert_same_size(t, x, "copy into subcube"); + inplace_op(val); + } + + + +template +inline +void +subview_cube::operator/= (const eT val) + { + arma_extra_debug_sigprint(); - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; + inplace_op(val); + } + + + +template +template +inline +void +subview_cube::operator= (const BaseCube& in) + { + arma_extra_debug_sigprint(); - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::copy( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(in, "copy into subcube"); } @@ -199,24 +359,7 @@ subview_cube::operator+= (const BaseCube& in) { arma_extra_debug_sigprint(); - const unwrap_cube tmp(in.get_ref()); - - const Cube& x = tmp.M; - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "addition"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_plus( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(in, "addition"); } @@ -229,24 +372,7 @@ subview_cube::operator-= (const BaseCube& in) { arma_extra_debug_sigprint(); - const unwrap_cube tmp(in.get_ref()); - - const Cube& x = tmp.M; - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "subtraction"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_minus( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(in, "subtraction"); } @@ -259,24 +385,7 @@ subview_cube::operator%= (const BaseCube& in) { arma_extra_debug_sigprint(); - const unwrap_cube tmp(in.get_ref()); - - const Cube& x = tmp.M; - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "element-wise multiplication"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_mul( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(in, "element-wise multiplication"); } @@ -289,24 +398,7 @@ subview_cube::operator/= (const BaseCube& in) { arma_extra_debug_sigprint(); - const unwrap_cube tmp(in.get_ref()); - - const Cube& x = tmp.M; - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "element-wise division"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_div( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(in, "element-wise division"); } @@ -319,30 +411,7 @@ subview_cube::operator= (const subview_cube& x) { arma_extra_debug_sigprint(); - if(check_overlap(x)) - { - const Cube tmp(x); - - (*this).operator=(tmp); - - return; - } - - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "copy into subcube"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::copy( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(x, "copy into subcube"); } @@ -354,30 +423,7 @@ subview_cube::operator+= (const subview_cube& x) { arma_extra_debug_sigprint(); - if(check_overlap(x)) - { - const Cube tmp(x); - - (*this).operator+=(tmp); - - return; - } - - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "addition"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_plus( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(x, "addition"); } @@ -389,30 +435,7 @@ subview_cube::operator-= (const subview_cube& x) { arma_extra_debug_sigprint(); - if(check_overlap(x)) - { - const Cube tmp(x); - - (*this).operator-=(tmp); - - return; - } - - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "subtraction"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_minus( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(x, "subtraction"); } @@ -424,30 +447,7 @@ subview_cube::operator%= (const subview_cube& x) { arma_extra_debug_sigprint(); - if(check_overlap(x)) - { - const Cube tmp(x); - - (*this).operator%=(tmp); - - return; - } - - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "element-wise multiplication"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_mul( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(x, "element-wise multiplication"); } @@ -459,30 +459,7 @@ subview_cube::operator/= (const subview_cube& x) { arma_extra_debug_sigprint(); - if(check_overlap(x)) - { - const Cube tmp(x); - - (*this).operator/=(tmp); - - return; - } - - subview_cube& t = *this; - - arma_debug_assert_same_size(t, x, "element-wise division"); - - const uword t_n_rows = t.n_rows; - const uword t_n_cols = t.n_cols; - const uword t_n_slices = t.n_slices; - - for(uword slice = 0; slice < t_n_slices; ++slice) - { - for(uword col = 0; col < t_n_cols; ++col) - { - arrayops::inplace_div( t.slice_colptr(slice,col), x.slice_colptr(slice,col), t_n_rows ); - } - } + inplace_op(x, "element-wise division"); } @@ -495,7 +472,7 @@ subview_cube::operator= (const Base& in) { arma_extra_debug_sigprint(); - const unwrap tmp(in.get_ref()); + const quasi_unwrap tmp(in.get_ref()); const Mat& x = tmp.M; subview_cube& t = *this; @@ -600,7 +577,7 @@ subview_cube::operator+= (const Base& in) { arma_extra_debug_sigprint(); - const unwrap tmp(in.get_ref()); + const quasi_unwrap tmp(in.get_ref()); const Mat& x = tmp.M; subview_cube& t = *this; @@ -703,7 +680,7 @@ subview_cube::operator-= (const Base& in) { arma_extra_debug_sigprint(); - const unwrap tmp(in.get_ref()); + const quasi_unwrap tmp(in.get_ref()); const Mat& x = tmp.M; subview_cube& t = *this; @@ -806,7 +783,7 @@ subview_cube::operator%= (const Base& in) { arma_extra_debug_sigprint(); - const unwrap tmp(in.get_ref()); + const quasi_unwrap tmp(in.get_ref()); const Mat& x = tmp.M; subview_cube& t = *this; @@ -909,7 +886,7 @@ subview_cube::operator/= (const Base& in) { arma_extra_debug_sigprint(); - const unwrap tmp(in.get_ref()); + const quasi_unwrap tmp(in.get_ref()); const Mat& x = tmp.M; subview_cube& t = *this; @@ -1134,59 +1111,55 @@ subview_cube::imbue(functor F) -#if defined(ARMA_USE_CXX11) +//! apply a lambda function to each slice, where each slice is interpreted as a matrix +template +inline +void +subview_cube::each_slice(const std::function< void(Mat&) >& F) + { + arma_extra_debug_sigprint(); + + Mat tmp1(n_rows, n_cols, arma_nozeros_indicator()); + Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); - //! apply a lambda function to each slice, where each slice is interpreted as a matrix - template - inline - void - subview_cube::each_slice(const std::function< void(Mat&) >& F) + for(uword slice_id=0; slice_id < n_slices; ++slice_id) { - arma_extra_debug_sigprint(); + for(uword col_id=0; col_id < n_cols; ++col_id) + { + arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); + } - Mat tmp1(n_rows, n_cols); - Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); + F(tmp2); - for(uword slice_id=0; slice_id < n_slices; ++slice_id) + for(uword col_id=0; col_id < n_cols; ++col_id) { - for(uword col_id=0; col_id < n_cols; ++col_id) - { - arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); - } - - F(tmp2); - - for(uword col_id=0; col_id < n_cols; ++col_id) - { - arrayops::copy( slice_colptr(slice_id, col_id), tmp1.colptr(col_id), n_rows ); - } + arrayops::copy( slice_colptr(slice_id, col_id), tmp1.colptr(col_id), n_rows ); } } + } + + + +template +inline +void +subview_cube::each_slice(const std::function< void(const Mat&) >& F) const + { + arma_extra_debug_sigprint(); + Mat tmp1(n_rows, n_cols, arma_nozeros_indicator()); + const Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); - - template - inline - void - subview_cube::each_slice(const std::function< void(const Mat&) >& F) const + for(uword slice_id=0; slice_id < n_slices; ++slice_id) { - arma_extra_debug_sigprint(); - - Mat tmp1(n_rows, n_cols); - const Mat tmp2('j', tmp1.memptr(), n_rows, n_cols); - - for(uword slice_id=0; slice_id < n_slices; ++slice_id) + for(uword col_id=0; col_id < n_cols; ++col_id) { - for(uword col_id=0; col_id < n_cols; ++col_id) - { - arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); - } - - F(tmp2); + arrayops::copy( tmp1.colptr(col_id), slice_colptr(slice_id, col_id), n_rows ); } + + F(tmp2); } - -#endif + } @@ -1234,6 +1207,38 @@ subview_cube::clean(const typename get_pod_type::result threshold) +template +inline +void +subview_cube::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview_cube::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview_cube::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "subview_cube::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + arrayops::clamp( slice_colptr(slice,col), local_n_rows, min_val, max_val ); + } + } + } + + + template inline void @@ -1336,12 +1341,13 @@ subview_cube::randn() template inline -arma_warn_unused bool subview_cube::is_finite() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; const uword local_n_slices = n_slices; @@ -1361,12 +1367,37 @@ subview_cube::is_finite() const template inline -arma_warn_unused +bool +subview_cube::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::is_zero(slice_colptr(slice,col), local_n_rows, tol) == false) { return false; } + } + } + + return true; + } + + + +template +inline bool subview_cube::has_inf() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; const uword local_n_slices = n_slices; @@ -1386,12 +1417,13 @@ subview_cube::has_inf() const template inline -arma_warn_unused bool subview_cube::has_nan() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; const uword local_n_slices = n_slices; @@ -1409,6 +1441,32 @@ subview_cube::has_nan() const +template +inline +bool +subview_cube::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + const uword local_n_slices = n_slices; + + for(uword slice = 0; slice < local_n_slices; ++slice) + { + for(uword col = 0; col < local_n_cols; ++col) + { + if(arrayops::is_finite(slice_colptr(slice,col), local_n_rows) == false) { return true; } + } + } + + return false; + } + + + template inline eT @@ -1462,7 +1520,7 @@ inline eT& subview_cube::operator()(const uword i) { - arma_debug_check( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); const uword in_slice = i / n_elem_slice; const uword offset = in_slice * n_elem_slice; @@ -1483,7 +1541,7 @@ inline eT subview_cube::operator()(const uword i) const { - arma_debug_check( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "subview_cube::operator(): index out of bounds" ); const uword in_slice = i / n_elem_slice; const uword offset = in_slice * n_elem_slice; @@ -1504,7 +1562,7 @@ arma_inline eT& subview_cube::operator()(const uword in_row, const uword in_col, const uword in_slice) { - arma_debug_check( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); + arma_debug_check_bounds( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; @@ -1518,7 +1576,7 @@ arma_inline eT subview_cube::operator()(const uword in_row, const uword in_col, const uword in_slice) const { - arma_debug_check( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); + arma_debug_check_bounds( ( (in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices) ), "subview_cube::operator(): location out of bounds" ); const uword index = (in_slice + aux_slice1)*m.n_elem_slice + (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; @@ -1572,51 +1630,44 @@ subview_cube::slice_colptr(const uword in_slice, const uword in_col) const template +template inline bool -subview_cube::check_overlap(const subview_cube& x) const +subview_cube::check_overlap(const subview_cube& x) const { - const subview_cube& t = *this; + if(is_same_type::value == false) { return false; } - if(&t.m != &x.m) - { - return false; - } - else - { - if( (t.n_elem == 0) || (x.n_elem == 0) ) - { - return false; - } - else - { - const uword t_row_start = t.aux_row1; - const uword t_row_end_p1 = t_row_start + t.n_rows; - - const uword t_col_start = t.aux_col1; - const uword t_col_end_p1 = t_col_start + t.n_cols; - - const uword t_slice_start = t.aux_slice1; - const uword t_slice_end_p1 = t_slice_start + t.n_slices; - - - const uword x_row_start = x.aux_row1; - const uword x_row_end_p1 = x_row_start + x.n_rows; - - const uword x_col_start = x.aux_col1; - const uword x_col_end_p1 = x_col_start + x.n_cols; - - const uword x_slice_start = x.aux_slice1; - const uword x_slice_end_p1 = x_slice_start + x.n_slices; - - - const bool outside_rows = ( (x_row_start >= t_row_end_p1 ) || (t_row_start >= x_row_end_p1 ) ); - const bool outside_cols = ( (x_col_start >= t_col_end_p1 ) || (t_col_start >= x_col_end_p1 ) ); - const bool outside_slices = ( (x_slice_start >= t_slice_end_p1) || (t_slice_start >= x_slice_end_p1) ); - - return ( (outside_rows == false) && (outside_cols == false) && (outside_slices == false) ); - } - } + const subview_cube& t = (*this); + + if(void_ptr(&(t.m)) != void_ptr(&(x.m))) { return false; } + + if( (t.n_elem == 0) || (x.n_elem == 0) ) { return false; } + + const uword t_row_start = t.aux_row1; + const uword t_row_end_p1 = t_row_start + t.n_rows; + + const uword t_col_start = t.aux_col1; + const uword t_col_end_p1 = t_col_start + t.n_cols; + + const uword t_slice_start = t.aux_slice1; + const uword t_slice_end_p1 = t_slice_start + t.n_slices; + + + const uword x_row_start = x.aux_row1; + const uword x_row_end_p1 = x_row_start + x.n_rows; + + const uword x_col_start = x.aux_col1; + const uword x_col_end_p1 = x_col_start + x.n_cols; + + const uword x_slice_start = x.aux_slice1; + const uword x_slice_end_p1 = x_slice_start + x.n_slices; + + + const bool outside_rows = ( (x_row_start >= t_row_end_p1 ) || (t_row_start >= x_row_end_p1 ) ); + const bool outside_cols = ( (x_col_start >= t_col_end_p1 ) || (t_col_start >= x_col_end_p1 ) ); + const bool outside_slices = ( (x_slice_start >= t_slice_end_p1) || (t_slice_start >= x_slice_end_p1) ); + + return ( (outside_rows == false) && (outside_cols == false) && (outside_slices == false) ); } @@ -1633,7 +1684,7 @@ subview_cube::check_overlap(const Mat& x) const for(uword slice = t_aux_slice1; slice < t_aux_slice2_plus_1; ++slice) { - if(t.m.mat_ptrs[slice] != NULL) + if(t.m.mat_ptrs[slice] != nullptr) { const Mat& y = *(t.m.mat_ptrs[slice]); @@ -1661,15 +1712,22 @@ subview_cube::extract(Cube& out, const subview_cube& in) const uword n_cols = in.n_cols; const uword n_slices = in.n_slices; - arma_extra_debug_print(arma_str::format("out.n_rows = %d out.n_cols = %d out.n_slices = %d in.m.n_rows = %d in.m.n_cols = %d in.m.n_slices = %d") % out.n_rows % out.n_cols % out.n_slices % in.m.n_rows % in.m.n_cols % in.m.n_slices); - + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u out.n_slices = %u in.m.n_rows = %u in.m.n_cols = %u in.m.n_slices = %u") % out.n_rows % out.n_cols % out.n_slices % in.m.n_rows % in.m.n_cols % in.m.n_slices); - for(uword slice = 0; slice < n_slices; ++slice) + if( (in.aux_row1 == 0) && (n_rows == in.m.n_rows) ) { - for(uword col = 0; col < n_cols; ++col) + for(uword s=0; s < n_slices; ++s) { - arrayops::copy( out.slice_colptr(slice,col), in.slice_colptr(slice,col), n_rows ); + arrayops::copy( out.slice_colptr(s,0), in.slice_colptr(s,0), in.n_elem_slice ); } + + return; + } + + for(uword s=0; s < n_slices; ++s) + for(uword c=0; c < n_cols; ++c) + { + arrayops::copy( out.slice_colptr(s,c), in.slice_colptr(s,c), n_rows ); } } @@ -2342,15 +2400,15 @@ subview_cube::cend() const template inline subview_cube::iterator::iterator() - : M (NULL) - , current_ptr (NULL) - , current_row (0 ) - , current_col (0 ) - , current_slice(0 ) - , aux_row1 (0 ) - , aux_col1 (0 ) - , aux_row2_p1 (0 ) - , aux_col2_p1 (0 ) + : M (nullptr) + , current_ptr (nullptr) + , current_row (0 ) + , current_col (0 ) + , current_slice(0 ) + , aux_row1 (0 ) + , aux_col1 (0 ) + , aux_row2_p1 (0 ) + , aux_col2_p1 (0 ) { arma_extra_debug_sigprint(); // Technically this iterator is invalid (it does not point to a valid element) @@ -2396,7 +2454,6 @@ subview_cube::iterator::iterator(subview_cube& in_sv, const uword in_row template inline -arma_warn_unused eT& subview_cube::iterator::operator*() { @@ -2437,7 +2494,6 @@ subview_cube::iterator::operator++() template inline -arma_warn_unused typename subview_cube::iterator subview_cube::iterator::operator++(int) { @@ -2452,7 +2508,6 @@ subview_cube::iterator::operator++(int) template inline -arma_warn_unused bool subview_cube::iterator::operator==(const iterator& rhs) const { @@ -2463,7 +2518,6 @@ subview_cube::iterator::operator==(const iterator& rhs) const template inline -arma_warn_unused bool subview_cube::iterator::operator!=(const iterator& rhs) const { @@ -2474,7 +2528,6 @@ subview_cube::iterator::operator!=(const iterator& rhs) const template inline -arma_warn_unused bool subview_cube::iterator::operator==(const const_iterator& rhs) const { @@ -2485,7 +2538,6 @@ subview_cube::iterator::operator==(const const_iterator& rhs) const template inline -arma_warn_unused bool subview_cube::iterator::operator!=(const const_iterator& rhs) const { @@ -2503,8 +2555,8 @@ subview_cube::iterator::operator!=(const const_iterator& rhs) const template inline subview_cube::const_iterator::const_iterator() - : M (NULL) - , current_ptr (NULL) + : M (nullptr) + , current_ptr (nullptr) , current_row (0 ) , current_col (0 ) , current_slice(0 ) @@ -2575,7 +2627,6 @@ subview_cube::const_iterator::const_iterator(const subview_cube& in_sv, template inline -arma_warn_unused const eT& subview_cube::const_iterator::operator*() { @@ -2616,7 +2667,6 @@ subview_cube::const_iterator::operator++() template inline -arma_warn_unused typename subview_cube::const_iterator subview_cube::const_iterator::operator++(int) { @@ -2631,7 +2681,6 @@ subview_cube::const_iterator::operator++(int) template inline -arma_warn_unused bool subview_cube::const_iterator::operator==(const iterator& rhs) const { @@ -2642,7 +2691,6 @@ subview_cube::const_iterator::operator==(const iterator& rhs) const template inline -arma_warn_unused bool subview_cube::const_iterator::operator!=(const iterator& rhs) const { @@ -2653,7 +2701,6 @@ subview_cube::const_iterator::operator!=(const iterator& rhs) const template inline -arma_warn_unused bool subview_cube::const_iterator::operator==(const const_iterator& rhs) const { @@ -2664,7 +2711,6 @@ subview_cube::const_iterator::operator==(const const_iterator& rhs) const template inline -arma_warn_unused bool subview_cube::const_iterator::operator!=(const const_iterator& rhs) const { diff --git a/src/armadillo_bits/subview_cube_slices_bones.hpp b/src/armadillo_bits/subview_cube_slices_bones.hpp index 47a74c3f..e19890f1 100644 --- a/src/armadillo_bits/subview_cube_slices_bones.hpp +++ b/src/armadillo_bits/subview_cube_slices_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,15 +22,14 @@ template -class subview_cube_slices : public BaseCube > +class subview_cube_slices : public BaseCube< eT, subview_cube_slices > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - arma_aligned const Cube& m; - + arma_aligned const Cube& m; arma_aligned const Base& base_si; @@ -40,6 +41,7 @@ class subview_cube_slices : public BaseCube > public: inline ~subview_cube_slices(); + inline subview_cube_slices() = delete; inline void inplace_rand(const uword rand_mode); @@ -82,11 +84,7 @@ class subview_cube_slices : public BaseCube > inline static void div_inplace(Cube& out, const subview_cube_slices& in); - - private: - friend class Cube; - subview_cube_slices(); }; diff --git a/src/armadillo_bits/subview_cube_slices_meat.hpp b/src/armadillo_bits/subview_cube_slices_meat.hpp index 67f6e570..f520da07 100644 --- a/src/armadillo_bits/subview_cube_slices_meat.hpp +++ b/src/armadillo_bits/subview_cube_slices_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -59,7 +61,7 @@ subview_cube_slices::inplace_rand(const uword rand_mode) arma_debug_check ( ( (si.is_vec() == false) && (si.is_empty() == false) ), - "Cube::slices(): given object is not a vector" + "Cube::slices(): given object must be a vector" ); const uword* si_mem = si.memptr(); @@ -69,7 +71,7 @@ subview_cube_slices::inplace_rand(const uword rand_mode) { const uword i = si_mem[si_count]; - arma_debug_check( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); eT* m_slice_ptr = m_local.slice_memptr(i); @@ -99,7 +101,7 @@ subview_cube_slices::inplace_op(const eT val) arma_debug_check ( ( (si.is_vec() == false) && (si.is_empty() == false) ), - "Cube::slices(): given object is not a vector" + "Cube::slices(): given object must be a vector" ); const uword* si_mem = si.memptr(); @@ -109,7 +111,7 @@ subview_cube_slices::inplace_op(const eT val) { const uword i = si_mem[si_count]; - arma_debug_check( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); eT* m_slice_ptr = m_local.slice_memptr(i); @@ -142,7 +144,7 @@ subview_cube_slices::inplace_op(const BaseCube& x) arma_debug_check ( ( (si.is_vec() == false) && (si.is_empty() == false) ), - "Cube::slices(): given object is not a vector" + "Cube::slices(): given object must be a vector" ); const uword* si_mem = si.memptr(); @@ -157,7 +159,7 @@ subview_cube_slices::inplace_op(const BaseCube& x) { const uword i = si_mem[si_count]; - arma_debug_check( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); eT* m_slice_ptr = m_local.slice_memptr(i); const eT* X_slice_ptr = X.slice_memptr(si_count); @@ -470,7 +472,7 @@ subview_cube_slices::extract(Cube& out, const subview_cube_slices::extract(Cube& out, const subview_cube_slices= m_n_slices), "Cube::slices(): index out of bounds" ); + arma_debug_check_bounds( (i >= m_n_slices), "Cube::slices(): index out of bounds" ); eT* out_slice_ptr = out.slice_memptr(si_count); const eT* m_slice_ptr = m_local.slice_memptr(i); diff --git a/src/armadillo_bits/subview_each_bones.hpp b/src/armadillo_bits/subview_each_bones.hpp index c2871143..dcb58cd8 100644 --- a/src/armadillo_bits/subview_each_bones.hpp +++ b/src/armadillo_bits/subview_each_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -28,24 +30,22 @@ class subview_each_common const parent& P; - inline void check_size(const Mat& A) const; + template + inline void check_size(const Mat& A) const; protected: arma_inline subview_each_common(const parent& in_P); + inline subview_each_common() = delete; arma_inline const Mat& get_mat_ref_helper(const Mat & X) const; arma_inline const Mat& get_mat_ref_helper(const subview& X) const; arma_inline const Mat& get_mat_ref() const; - arma_cold inline const std::string incompat_size_string(const Mat& A) const; - - - private: - - subview_each_common(); + template + arma_cold inline const std::string incompat_size_string(const Mat& A) const; }; @@ -64,6 +64,7 @@ class subview_each1 : public subview_each_common typedef typename parent::elem_type eT; inline ~subview_each1(); + inline subview_each1() = delete; // deliberately returning void template inline void operator= (const Base& x); @@ -73,8 +74,6 @@ class subview_each1 : public subview_each_common template inline void operator/= (const Base& x); - private: - friend class Mat; friend class subview; }; @@ -96,7 +95,9 @@ class subview_each2 : public subview_each_common typedef typename parent::elem_type eT; inline void check_indices(const Mat& indices) const; + inline ~subview_each2(); + inline subview_each2() = delete; // deliberately returning void template inline void operator= (const Base& x); @@ -106,8 +107,6 @@ class subview_each2 : public subview_each_common template inline void operator/= (const Base& x); - private: - friend class Mat; friend class subview; }; @@ -120,7 +119,7 @@ class subview_each1_aux template static inline Mat operator_plus(const subview_each1& X, const Base& Y); - + template static inline Mat operator_minus(const subview_each1& X, const Base& Y); @@ -145,7 +144,7 @@ class subview_each2_aux template static inline Mat operator_plus(const subview_each2& X, const Base& Y); - + template static inline Mat operator_minus(const subview_each2& X, const Base& Y); diff --git a/src/armadillo_bits/subview_each_meat.hpp b/src/armadillo_bits/subview_each_meat.hpp index 16c14baf..12d263ef 100644 --- a/src/armadillo_bits/subview_each_meat.hpp +++ b/src/armadillo_bits/subview_each_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -63,11 +65,12 @@ subview_each_common::get_mat_ref() const template +template inline void -subview_each_common::check_size(const Mat& A) const +subview_each_common::check_size(const Mat& A) const { - if(arma_config::debug == true) + if(arma_config::debug) { if(mode == 0) { @@ -89,10 +92,10 @@ subview_each_common::check_size(const Mat -arma_cold +template inline const std::string -subview_each_common::incompat_size_string(const Mat& A) const +subview_each_common::incompat_size_string(const Mat& A) const { std::ostringstream tmp; @@ -190,7 +193,7 @@ subview_each1::operator+= (const Base& in) const eT* A_mem = A.memptr(); const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - + if(mode == 0) // each column { for(uword i=0; i < p_n_cols; ++i) @@ -227,7 +230,7 @@ subview_each1::operator-= (const Base& in) const eT* A_mem = A.memptr(); const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - + if(mode == 0) // each column { for(uword i=0; i < p_n_cols; ++i) @@ -395,7 +398,7 @@ subview_each2::operator= (const Base& in) { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::copy( p.colptr(col), A_mem, p_n_rows ); } @@ -406,7 +409,7 @@ subview_each2::operator= (const Base& in) { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); for(uword col=0; col < p_n_cols; ++col) { @@ -451,7 +454,7 @@ subview_each2::operator+= (const Base& in) { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_plus( p.colptr(col), A_mem, p_n_rows ); } @@ -462,7 +465,7 @@ subview_each2::operator+= (const Base& in) { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); p.row(row) += A; } @@ -504,7 +507,7 @@ subview_each2::operator-= (const Base& in) { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_minus( p.colptr(col), A_mem, p_n_rows ); } @@ -515,7 +518,7 @@ subview_each2::operator-= (const Base& in) { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); p.row(row) -= A; } @@ -557,7 +560,7 @@ subview_each2::operator%= (const Base& in) { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_mul( p.colptr(col), A_mem, p_n_rows ); } @@ -568,7 +571,7 @@ subview_each2::operator%= (const Base& in) { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); p.row(row) %= A; } @@ -610,7 +613,7 @@ subview_each2::operator/= (const Base& in) { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_div( p.colptr(col), A_mem, p_n_rows ); } @@ -621,7 +624,7 @@ subview_each2::operator/= (const Base& in) { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); p.row(row) /= A; } @@ -654,7 +657,7 @@ subview_each1_aux::operator_plus const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -716,7 +719,7 @@ subview_each1_aux::operator_minus const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -778,7 +781,7 @@ subview_each1_aux::operator_minus const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(X.get_ref()); const Mat& A = tmp.M; @@ -840,7 +843,7 @@ subview_each1_aux::operator_schur const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -902,7 +905,7 @@ subview_each1_aux::operator_div const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(Y.get_ref()); const Mat& A = tmp.M; @@ -964,7 +967,7 @@ subview_each1_aux::operator_div const uword p_n_rows = p.n_rows; const uword p_n_cols = p.n_cols; - Mat out(p_n_rows, p_n_cols); + Mat out(p_n_rows, p_n_cols, arma_nozeros_indicator()); const quasi_unwrap tmp(X.get_ref()); const Mat& A = tmp.M; @@ -1053,7 +1056,7 @@ subview_each2_aux::operator_plus { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_plus( out.colptr(col), A_mem, p_n_rows ); } @@ -1065,7 +1068,7 @@ subview_each2_aux::operator_plus { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) += A; } @@ -1115,7 +1118,7 @@ subview_each2_aux::operator_minus { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_minus( out.colptr(col), A_mem, p_n_rows ); } @@ -1127,7 +1130,7 @@ subview_each2_aux::operator_minus { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) -= A; } @@ -1177,7 +1180,7 @@ subview_each2_aux::operator_minus { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); const eT* p_mem = p.colptr(col); eT* out_mem = out.colptr(col); @@ -1195,7 +1198,7 @@ subview_each2_aux::operator_minus { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) = A - p.row(row); } @@ -1245,7 +1248,7 @@ subview_each2_aux::operator_schur { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_mul( out.colptr(col), A_mem, p_n_rows ); } @@ -1257,7 +1260,7 @@ subview_each2_aux::operator_schur { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) %= A; } @@ -1307,7 +1310,7 @@ subview_each2_aux::operator_div { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); arrayops::inplace_div( out.colptr(col), A_mem, p_n_rows ); } @@ -1319,7 +1322,7 @@ subview_each2_aux::operator_div { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) /= A; } @@ -1369,7 +1372,7 @@ subview_each2_aux::operator_div { const uword col = indices_mem[i]; - arma_debug_check( (col >= p_n_cols), "each_col(): index out of bounds" ); + arma_debug_check_bounds( (col >= p_n_cols), "each_col(): index out of bounds" ); const eT* p_mem = p.colptr(col); eT* out_mem = out.colptr(col); @@ -1387,7 +1390,7 @@ subview_each2_aux::operator_div { const uword row = indices_mem[i]; - arma_debug_check( (row >= p_n_rows), "each_row(): index out of bounds" ); + arma_debug_check_bounds( (row >= p_n_rows), "each_row(): index out of bounds" ); out.row(row) = A / p.row(row); } diff --git a/src/armadillo_bits/subview_elem1_bones.hpp b/src/armadillo_bits/subview_elem1_bones.hpp index 7ce66fb1..2ac3cdad 100644 --- a/src/armadillo_bits/subview_elem1_bones.hpp +++ b/src/armadillo_bits/subview_elem1_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class subview_elem1 : public Base > +class subview_elem1 : public Base< eT, subview_elem1 > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = true; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = true; + static constexpr bool is_xvec = false; arma_aligned const Mat fake_m; arma_aligned const Mat& m; @@ -45,6 +47,7 @@ class subview_elem1 : public Base > public: inline ~subview_elem1(); + inline subview_elem1() = delete; template inline void inplace_op(const eT val); template inline void inplace_op(const subview_elem1& x ); @@ -56,6 +59,10 @@ class subview_elem1 : public Base > inline void replace(const eT old_val, const eT new_val); + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); @@ -93,13 +100,8 @@ class subview_elem1 : public Base > inline static void div_inplace(Mat& out, const subview_elem1& in); - - private: - friend class Mat; friend class Cube; - - subview_elem1(); }; diff --git a/src/armadillo_bits/subview_elem1_meat.hpp b/src/armadillo_bits/subview_elem1_meat.hpp index 8a991e04..d1b67128 100644 --- a/src/armadillo_bits/subview_elem1_meat.hpp +++ b/src/armadillo_bits/subview_elem1_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -33,8 +35,6 @@ subview_elem1::subview_elem1(const Mat& in_m, const Base& i , a(in_a) { arma_extra_debug_sigprint(); - - // TODO: refactor to unwrap 'in_a' instead of storing a ref to it; this will allow removal of carrying T1 around and repetition of size checks } @@ -70,7 +70,7 @@ subview_elem1::inplace_op(const eT val) arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -82,7 +82,7 @@ subview_elem1::inplace_op(const eT val) const uword ii = aa_mem[iq]; const uword jj = aa_mem[jq]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = val; m_mem[jj] = val; } if(is_same_type::yes) { m_mem[ii] += val; m_mem[jj] += val; } @@ -95,7 +95,7 @@ subview_elem1::inplace_op(const eT val) { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = val; } if(is_same_type::yes) { m_mem[ii] += val; } @@ -143,7 +143,7 @@ subview_elem1::inplace_op(const subview_elem1& x) arma_debug_check ( ( ((s_aa.is_vec() == false) && (s_aa.is_empty() == false)) || ((x_aa.is_vec() == false) && (x_aa.is_empty() == false)) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* s_aa_mem = s_aa.memptr(); @@ -169,7 +169,7 @@ subview_elem1::inplace_op(const subview_elem1& x) const uword x_ii = x_aa_mem[iq]; const uword x_jj = x_aa_mem[jq]; - arma_debug_check + arma_debug_check_bounds ( (s_ii >= s_m_n_elem) || (s_jj >= s_m_n_elem) || (x_ii >= x_m_n_elem) || (x_jj >= x_m_n_elem), "Mat::elem(): index out of bounds" @@ -187,7 +187,7 @@ subview_elem1::inplace_op(const subview_elem1& x) const uword s_ii = s_aa_mem[iq]; const uword x_ii = x_aa_mem[iq]; - arma_debug_check + arma_debug_check_bounds ( ( (s_ii >= s_m_n_elem) || (x_ii >= x_m_n_elem) ), "Mat::elem(): index out of bounds" @@ -223,7 +223,7 @@ subview_elem1::inplace_op(const Base& x) arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -245,7 +245,7 @@ subview_elem1::inplace_op(const Base& x) const uword ii = aa_mem[iq]; const uword jj = aa_mem[jq]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = X[iq]; m_mem[jj] = X[jq]; } if(is_same_type::yes) { m_mem[ii] += X[iq]; m_mem[jj] += X[jq]; } @@ -258,7 +258,7 @@ subview_elem1::inplace_op(const Base& x) { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = X[iq]; } if(is_same_type::yes) { m_mem[ii] += X[iq]; } @@ -282,7 +282,7 @@ subview_elem1::inplace_op(const Base& x) const uword ii = aa_mem[iq]; const uword jj = aa_mem[jq]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = X[iq]; m_mem[jj] = X[jq]; } if(is_same_type::yes) { m_mem[ii] += X[iq]; m_mem[jj] += X[jq]; } @@ -295,7 +295,7 @@ subview_elem1::inplace_op(const Base& x) { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_mem[ii] = X[iq]; } if(is_same_type::yes) { m_mem[ii] += X[iq]; } @@ -361,7 +361,7 @@ subview_elem1::replace(const eT old_val, const eT new_val) arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -373,7 +373,7 @@ subview_elem1::replace(const eT old_val, const eT new_val) { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); eT& val = m_mem[ii]; @@ -386,7 +386,7 @@ subview_elem1::replace(const eT old_val, const eT new_val) { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem), "Mat::elem(): index out of bounds" ); eT& val = m_mem[ii]; @@ -397,6 +397,38 @@ subview_elem1::replace(const eT old_val, const eT new_val) +template +inline +void +subview_elem1::clean(const pod_type threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem1::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + template inline void @@ -451,7 +483,7 @@ subview_elem1::randu() arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -463,7 +495,7 @@ subview_elem1::randu() const uword ii = aa_mem[iq]; const uword jj = aa_mem[jq]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); const eT val1 = eT(arma_rng::randu()); const eT val2 = eT(arma_rng::randu()); @@ -476,7 +508,7 @@ subview_elem1::randu() { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); m_mem[ii] = eT(arma_rng::randu()); } @@ -502,7 +534,7 @@ subview_elem1::randn() arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -514,7 +546,7 @@ subview_elem1::randn() const uword ii = aa_mem[iq]; const uword jj = aa_mem[jq]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); arma_rng::randn::dual_val( m_mem[ii], m_mem[jj] ); } @@ -523,7 +555,7 @@ subview_elem1::randn() { const uword ii = aa_mem[iq]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); m_mem[ii] = eT(arma_rng::randn()); } @@ -759,7 +791,7 @@ subview_elem1::extract(Mat& actual_out, const subview_elem1& i arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -774,7 +806,7 @@ subview_elem1::extract(Mat& actual_out, const subview_elem1& i if(alias) { arma_extra_debug_print("subview_elem1::extract(): aliasing detected"); } - Mat* tmp_out = alias ? new Mat() : 0; + Mat* tmp_out = alias ? new Mat() : nullptr; Mat& out = alias ? *tmp_out : actual_out; out.set_size(aa_n_elem, 1); @@ -787,7 +819,7 @@ subview_elem1::extract(Mat& actual_out, const subview_elem1& i const uword ii = aa_mem[i]; const uword jj = aa_mem[j]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); out_mem[i] = m_mem[ii]; out_mem[j] = m_mem[jj]; @@ -797,7 +829,7 @@ subview_elem1::extract(Mat& actual_out, const subview_elem1& i { const uword ii = aa_mem[i]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); out_mem[i] = m_mem[ii]; } @@ -825,7 +857,7 @@ subview_elem1::mat_inplace_op(Mat& out, const subview_elem1& in) arma_debug_check ( ( (aa.is_vec() == false) && (aa.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* aa_mem = aa.memptr(); @@ -847,7 +879,7 @@ subview_elem1::mat_inplace_op(Mat& out, const subview_elem1& in) const uword ii = aa_mem[i]; const uword jj = aa_mem[j]; - arma_debug_check( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( ( (ii >= m_n_elem) || (jj >= m_n_elem) ), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { out_mem[i] += m_mem[ii]; out_mem[j] += m_mem[jj]; } if(is_same_type::yes) { out_mem[i] -= m_mem[ii]; out_mem[j] -= m_mem[jj]; } @@ -859,7 +891,7 @@ subview_elem1::mat_inplace_op(Mat& out, const subview_elem1& in) { const uword ii = aa_mem[i]; - arma_debug_check( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (ii >= m_n_elem) , "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { out_mem[i] += m_mem[ii]; } if(is_same_type::yes) { out_mem[i] -= m_mem[ii]; } diff --git a/src/armadillo_bits/subview_elem2_bones.hpp b/src/armadillo_bits/subview_elem2_bones.hpp index 459fb119..d4c4cbe6 100644 --- a/src/armadillo_bits/subview_elem2_bones.hpp +++ b/src/armadillo_bits/subview_elem2_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,16 +22,16 @@ template -class subview_elem2 : public Base > +class subview_elem2 : public Base< eT, subview_elem2 > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; arma_aligned const Mat& m; @@ -48,6 +50,7 @@ class subview_elem2 : public Base > public: inline ~subview_elem2(); + inline subview_elem2() = delete; template inline void inplace_op(const eT val); @@ -55,6 +58,12 @@ class subview_elem2 : public Base > template inline void inplace_op(const Base& x); + inline void replace(const eT old_val, const eT new_val); + + inline void clean(const pod_type threshold); + + inline void clamp(const eT min_val, const eT max_val); + inline void fill(const eT val); inline void zeros(); inline void ones(); @@ -81,6 +90,12 @@ class subview_elem2 : public Base > template inline void operator%= (const Base& x); template inline void operator/= (const Base& x); + template inline void operator= (const SpBase& x); + template inline void operator+= (const SpBase& x); + template inline void operator-= (const SpBase& x); + template inline void operator%= (const SpBase& x); + template inline void operator/= (const SpBase& x); + inline static void extract(Mat& out, const subview_elem2& in); inline static void plus_inplace(Mat& out, const subview_elem2& in); @@ -89,11 +104,7 @@ class subview_elem2 : public Base > inline static void div_inplace(Mat& out, const subview_elem2& in); - - private: - friend class Mat; - subview_elem2(); }; diff --git a/src/armadillo_bits/subview_elem2_meat.hpp b/src/armadillo_bits/subview_elem2_meat.hpp index baee228f..69d5f5d7 100644 --- a/src/armadillo_bits/subview_elem2_meat.hpp +++ b/src/armadillo_bits/subview_elem2_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -71,7 +73,7 @@ subview_elem2::inplace_op(const eT val) arma_debug_check ( ( ((ri.is_vec() == false) && (ri.is_empty() == false)) || ((ci.is_vec() == false) && (ci.is_empty() == false)) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ri_mem = ri.memptr(); @@ -84,13 +86,13 @@ subview_elem2::inplace_op(const eT val) { const uword col = ci_mem[ci_count]; - arma_debug_check( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) { const uword row = ri_mem[ri_count]; - arma_debug_check( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_local.at(row,col) = val; } if(is_same_type::yes) { m_local.at(row,col) += val; } @@ -110,7 +112,7 @@ subview_elem2::inplace_op(const eT val) arma_debug_check ( ( (ci.is_vec() == false) && (ci.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ci_mem = ci.memptr(); @@ -120,7 +122,7 @@ subview_elem2::inplace_op(const eT val) { const uword col = ci_mem[ci_count]; - arma_debug_check( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); eT* colptr = m_local.colptr(col); @@ -141,7 +143,7 @@ subview_elem2::inplace_op(const eT val) arma_debug_check ( ( (ri.is_vec() == false) && (ri.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ri_mem = ri.memptr(); @@ -153,7 +155,7 @@ subview_elem2::inplace_op(const eT val) { const uword row = ri_mem[ri_count]; - arma_debug_check( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_local.at(row,col) = val; } if(is_same_type::yes) { m_local.at(row,col) += val; } @@ -194,7 +196,7 @@ subview_elem2::inplace_op(const Base& x) arma_debug_check ( ( ((ri.is_vec() == false) && (ri.is_empty() == false)) || ((ci.is_vec() == false) && (ci.is_empty() == false)) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ri_mem = ri.memptr(); @@ -209,13 +211,13 @@ subview_elem2::inplace_op(const Base& x) { const uword col = ci_mem[ci_count]; - arma_debug_check( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) { const uword row = ri_mem[ri_count]; - arma_debug_check( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_local.at(row,col) = X.at(ri_count, ci_count); } if(is_same_type::yes) { m_local.at(row,col) += X.at(ri_count, ci_count); } @@ -235,7 +237,7 @@ subview_elem2::inplace_op(const Base& x) arma_debug_check ( ( (ci.is_vec() == false) && (ci.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ci_mem = ci.memptr(); @@ -247,7 +249,7 @@ subview_elem2::inplace_op(const Base& x) { const uword col = ci_mem[ci_count]; - arma_debug_check( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); eT* m_colptr = m_local.colptr(col); const eT* X_colptr = X.colptr(ci_count); @@ -269,7 +271,7 @@ subview_elem2::inplace_op(const Base& x) arma_debug_check ( ( (ri.is_vec() == false) && (ri.is_empty() == false) ), - "Mat::elem(): given object is not a vector" + "Mat::elem(): given object must be a vector" ); const uword* ri_mem = ri.memptr(); @@ -283,7 +285,7 @@ subview_elem2::inplace_op(const Base& x) { const uword row = ri_mem[ri_count]; - arma_debug_check( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); if(is_same_type::yes) { m_local.at(row,col) = X.at(ri_count, col); } if(is_same_type::yes) { m_local.at(row,col) += X.at(ri_count, col); } @@ -302,6 +304,54 @@ subview_elem2::inplace_op(const Base& x) +template +inline +void +subview_elem2::replace(const eT old_val, const eT new_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.replace(old_val, new_val); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem2::clean(const pod_type threshold) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clean(threshold); + + (*this).operator=(tmp); + } + + + +template +inline +void +subview_elem2::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + Mat tmp(*this); + + tmp.clamp(min_val, max_val); + + (*this).operator=(tmp); + } + + + template inline void @@ -553,6 +603,86 @@ subview_elem2::operator/= (const Base& x) +template +template +inline +void +subview_elem2::operator= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator+= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator-= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator%= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +template +template +inline +void +subview_elem2::operator/= (const SpBase& x) + { + arma_extra_debug_sigprint(); + + const Mat tmp(x); + + inplace_op(tmp); + } + + + +// +// + + + template inline void @@ -569,7 +699,7 @@ subview_elem2::extract(Mat& actual_out, const subview_elem2* tmp_out = alias ? new Mat() : 0; + Mat* tmp_out = alias ? new Mat() : nullptr; Mat& out = alias ? *tmp_out : actual_out; if( (in.all_rows == false) && (in.all_cols == false) ) @@ -583,7 +713,7 @@ subview_elem2::extract(Mat& actual_out, const subview_elem2::extract(Mat& actual_out, const subview_elem2= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); for(uword ri_count=0; ri_count < ri_n_elem; ++ri_count) { const uword row = ri_mem[ri_count]; - arma_debug_check( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); out_mem[out_count] = m_local.at(row,col); ++out_count; @@ -624,7 +754,7 @@ subview_elem2::extract(Mat& actual_out, const subview_elem2::extract(Mat& actual_out, const subview_elem2= m_n_cols), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (col >= m_n_cols), "Mat::elem(): index out of bounds" ); arrayops::copy( out.colptr(ci_count), m_local.colptr(col), m_n_rows ); } @@ -651,7 +781,7 @@ subview_elem2::extract(Mat& actual_out, const subview_elem2::extract(Mat& actual_out, const subview_elem2= m_n_rows), "Mat::elem(): index out of bounds" ); + arma_debug_check_bounds( (row >= m_n_rows), "Mat::elem(): index out of bounds" ); out.at(ri_count,col) = m_local.at(row,col); } diff --git a/src/armadillo_bits/subview_field_bones.hpp b/src/armadillo_bits/subview_field_bones.hpp index aca7ca0d..8ea83156 100644 --- a/src/armadillo_bits/subview_field_bones.hpp +++ b/src/armadillo_bits/subview_field_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,7 +21,7 @@ //! Class for storing data required to construct or apply operations to a subfield -//! (i.e. where the subfield starts and ends as well as a reference/pointer to the original field), +//! (ie. where the subfield starts and ends as well as a reference/pointer to the original field), template class subview_field { @@ -48,29 +50,30 @@ class subview_field public: inline ~subview_field(); + inline subview_field() = delete; inline void operator= (const field& x); inline void operator= (const subview_field& x); - arma_inline oT& operator[](const uword i); - arma_inline const oT& operator[](const uword i) const; + arma_warn_unused arma_inline oT& operator[](const uword i); + arma_warn_unused arma_inline const oT& operator[](const uword i) const; - arma_inline oT& operator()(const uword i); - arma_inline const oT& operator()(const uword i) const; + arma_warn_unused arma_inline oT& operator()(const uword i); + arma_warn_unused arma_inline const oT& operator()(const uword i) const; - arma_inline oT& at(const uword row, const uword col); - arma_inline const oT& at(const uword row, const uword col) const; + arma_warn_unused arma_inline oT& at(const uword row, const uword col); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col) const; - arma_inline oT& at(const uword row, const uword col, const uword slice); - arma_inline const oT& at(const uword row, const uword col, const uword slice) const; + arma_warn_unused arma_inline oT& at(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& at(const uword row, const uword col, const uword slice) const; - arma_inline oT& operator()(const uword row, const uword col); - arma_inline const oT& operator()(const uword row, const uword col) const; + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col) const; - arma_inline oT& operator()(const uword row, const uword col, const uword slice); - arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; + arma_warn_unused arma_inline oT& operator()(const uword row, const uword col, const uword slice); + arma_warn_unused arma_inline const oT& operator()(const uword row, const uword col, const uword slice) const; - arma_inline bool is_empty() const; + arma_warn_unused arma_inline bool is_empty() const; inline bool check_overlap(const subview_field& x) const; @@ -85,13 +88,7 @@ class subview_field inline static void extract(field& out, const subview_field& in); - private: - friend class field; - - - subview_field(); - //subview_field(const subview_field&); }; diff --git a/src/armadillo_bits/subview_field_meat.hpp b/src/armadillo_bits/subview_field_meat.hpp index a3ef0f57..dafc4df4 100644 --- a/src/armadillo_bits/subview_field_meat.hpp +++ b/src/armadillo_bits/subview_field_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -198,7 +200,7 @@ arma_inline oT& subview_field::operator()(const uword i) { - arma_debug_check( (i >= n_elem), "subview_field::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "subview_field::operator(): index out of bounds" ); return operator[](i); } @@ -210,7 +212,7 @@ arma_inline const oT& subview_field::operator()(const uword i) const { - arma_debug_check( (i >= n_elem), "subview_field::operator(): index out of bounds" ); + arma_debug_check_bounds( (i >= n_elem), "subview_field::operator(): index out of bounds" ); return operator[](i); } @@ -242,7 +244,7 @@ arma_inline oT& subview_field::operator()(const uword in_row, const uword in_col, const uword in_slice) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; @@ -256,7 +258,7 @@ arma_inline const oT& subview_field::operator()(const uword in_row, const uword in_col, const uword in_slice) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols) || (in_slice >= n_slices)), "subview_field::operator(): index out of bounds" ); const uword index = (in_slice + aux_slice1)*(f.n_rows*f.n_cols) + (in_col + aux_col1)*f.n_rows + aux_row1 + in_row; @@ -512,7 +514,7 @@ subview_field::extract(field& actual_out, const subview_field& in) // const bool alias = (&actual_out == &in.f); - field* tmp = (alias) ? new field : 0; + field* tmp = (alias) ? new field : nullptr; field& out = (alias) ? (*tmp) : actual_out; // @@ -523,7 +525,7 @@ subview_field::extract(field& actual_out, const subview_field& in) out.set_size(n_rows, n_cols, n_slices); - arma_extra_debug_print(arma_str::format("out.n_rows = %d out.n_cols = %d out.n_slices = %d in.m.n_rows = %d in.m.n_cols = %d in.m.n_slices = %d") % out.n_rows % out.n_cols % out.n_slices % in.f.n_rows % in.f.n_cols % in.f.n_slices); + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u out.n_slices = %u in.m.n_rows = %u in.m.n_cols = %u in.m.n_slices = %u") % out.n_rows % out.n_cols % out.n_slices % in.f.n_rows % in.f.n_cols % in.f.n_slices); if(n_slices == 1) { diff --git a/src/armadillo_bits/subview_meat.hpp b/src/armadillo_bits/subview_meat.hpp index c64c98f7..543383db 100644 --- a/src/armadillo_bits/subview_meat.hpp +++ b/src/armadillo_bits/subview_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,40 +24,60 @@ template inline subview::~subview() { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); } + template inline subview::subview(const Mat& in_m, const uword in_row1, const uword in_col1, const uword in_n_rows, const uword in_n_cols) - : m(in_m) - , aux_row1(in_row1) - , aux_col1(in_col1) - , n_rows(in_n_rows) - , n_cols(in_n_cols) - , n_elem(in_n_rows*in_n_cols) + : m (in_m ) + , aux_row1(in_row1 ) + , aux_col1(in_col1 ) + , n_rows (in_n_rows) + , n_cols (in_n_cols) + , n_elem (in_n_rows*in_n_cols) { - arma_extra_debug_sigprint(); + arma_extra_debug_sigprint_this(this); } template inline -void -subview::operator= (const eT val) +subview::subview(const subview& in) + : m (in.m ) + , aux_row1(in.aux_row1) + , aux_col1(in.aux_col1) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) { - arma_extra_debug_sigprint(); - - if(n_elem != 1) - { - arma_debug_assert_same_size(n_rows, n_cols, 1, 1, "copy into submatrix"); - } + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); + } + + + +template +inline +subview::subview(subview&& in) + : m (in.m ) + , aux_row1(in.aux_row1) + , aux_col1(in.aux_col1) + , n_rows (in.n_rows ) + , n_cols (in.n_cols ) + , n_elem (in.n_elem ) + { + arma_extra_debug_sigprint(arma_str::format("this = %x in = %x") % this % &in); - Mat& X = const_cast< Mat& >(m); + // for paranoia - X.at(aux_row1, aux_col1) = val; + access::rw(in.aux_row1) = 0; + access::rw(in.aux_col1) = 0; + access::rw(in.n_rows ) = 0; + access::rw(in.n_cols ) = 0; + access::rw(in.n_elem ) = 0; } @@ -129,7 +151,7 @@ subview::inplace_op(const Base& in, const char* identifier) arma_debug_assert_same_size(s, P, identifier); - const bool use_mp = arma_config::cxx11 && arma_config::openmp && Proxy::use_mp && mp_gate::eval(s.n_elem); + const bool use_mp = arma_config::openmp && Proxy::use_mp && mp_gate::eval(s.n_elem); const bool has_overlap = P.has_overlap(s); if(has_overlap) { arma_extra_debug_print("aliasing or overlap detected"); } @@ -172,9 +194,13 @@ subview::inplace_op(const Base& in, const char* identifier) } else // not a row vector { - if( (is_same_type::yes) && (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) + if((s.aux_row1 == 0) && (s_n_rows == s.m.n_rows)) { - arrayops::copy( s.colptr(0), B.memptr(), s.n_elem ); + if(is_same_type::yes) { arrayops::copy ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_plus ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_minus( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_mul ( s.colptr(0), B.memptr(), s.n_elem ); } + if(is_same_type::yes) { arrayops::inplace_div ( s.colptr(0), B.memptr(), s.n_elem ); } } else { @@ -373,6 +399,25 @@ subview::inplace_op(const subview& x, const char* identifier) +template +inline +void +subview::operator= (const eT val) + { + arma_extra_debug_sigprint(); + + if(n_elem != 1) + { + arma_debug_assert_same_size(n_rows, n_cols, 1, 1, "copy into submatrix"); + } + + Mat& X = const_cast< Mat& >(m); + + X.at(aux_row1, aux_col1) = val; + } + + + template inline void @@ -687,8 +732,8 @@ subview::operator/=(const SpBase& x) // This is probably going to fill your subview with a bunch of NaNs, // so I'm not going to bother to implement it fast. // You can have slow NaNs. They're fine too. - for (uword c = 0; c < n_cols; ++c) - for (uword r = 0; r < n_rows; ++r) + for(uword c = 0; c < n_cols; ++c) + for(uword r = 0; r < n_rows; ++r) { at(r, c) /= p.at(r, c); } @@ -711,6 +756,50 @@ subview::operator= (const Gen& in) +template +inline +void +subview::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + arma_debug_check( (is_vec() == false), "copy into submatrix: size mismatch" ); + + const uword N = uword(list.size()); + + if(n_rows == 1) + { + arma_debug_assert_same_size(1, n_cols, 1, N, "copy into submatrix"); + + auto it = list.begin(); + + for(uword ii=0; ii < N; ++ii) { (*this).at(0,ii) = (*it); ++it; } + } + else + if(n_cols == 1) + { + arma_debug_assert_same_size(n_rows, 1, N, 1, "copy into submatrix"); + + arrayops::copy( (*this).colptr(0), list.begin(), N ); + } + } + + + +template +inline +void +subview::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + const Mat tmp(list); + + (*this).operator=(tmp); + } + + + //! apply a functor to each element template template @@ -940,6 +1029,36 @@ subview::clean(const typename get_pod_type::result threshold) +template +inline +void +subview::clamp(const eT min_val, const eT max_val) + { + arma_extra_debug_sigprint(); + + if(is_cx::no) + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview::clamp(): min_val must be less than max_val" ); + } + else + { + arma_debug_check( (access::tmp_real(min_val) > access::tmp_real(max_val)), "subview::clamp(): real(min_val) must be less than real(max_val)" ); + arma_debug_check( (access::tmp_imag(min_val) > access::tmp_imag(max_val)), "subview::clamp(): imag(min_val) must be less than imag(max_val)" ); + } + + subview& s = *this; + + const uword s_n_cols = s.n_cols; + const uword s_n_rows = s.n_rows; + + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::clamp( s.colptr(ucol), s_n_rows, min_val, max_val ); + } + } + + + template inline void @@ -977,13 +1096,13 @@ subview::fill(const eT val) if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) { arrayops::inplace_set( s.colptr(0), val, s.n_elem ); - - return; } - - for(uword ucol=0; ucol < s_n_cols; ++ucol) + else { - arrayops::inplace_set( s.colptr(ucol), val, s_n_rows ); + for(uword ucol=0; ucol < s_n_cols; ++ucol) + { + arrayops::inplace_set( s.colptr(ucol), val, s_n_rows ); + } } } } @@ -1040,21 +1159,33 @@ subview::randu() { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; + subview& s = (*this); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; - if(local_n_rows == 1) + if(s_n_rows == 1) { - for(uword ii=0; ii < local_n_cols; ++ii) - { - at(0,ii) = eT(arma_rng::randu()); - } + podarray tmp(s_n_cols); + + eT* tmp_mem = tmp.memptr(); + + arma_rng::randu::fill( tmp_mem, s_n_cols ); + + for(uword ii=0; ii < s_n_cols; ++ii) { at(0,ii) = tmp_mem[ii]; } } else { - for(uword ii=0; ii < local_n_cols; ++ii) + if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) + { + arma_rng::randu::fill( s.colptr(0), s.n_elem ); + } + else { - arma_rng::randu::fill( colptr(ii), local_n_rows ); + for(uword ii=0; ii < s_n_cols; ++ii) + { + arma_rng::randu::fill( s.colptr(ii), s_n_rows ); + } } } } @@ -1068,21 +1199,33 @@ subview::randn() { arma_extra_debug_sigprint(); - const uword local_n_rows = n_rows; - const uword local_n_cols = n_cols; + subview& s = (*this); + + const uword s_n_rows = s.n_rows; + const uword s_n_cols = s.n_cols; - if(local_n_rows == 1) + if(s_n_rows == 1) { - for(uword ii=0; ii < local_n_cols; ++ii) - { - at(0,ii) = eT(arma_rng::randn()); - } + podarray tmp(s_n_cols); + + eT* tmp_mem = tmp.memptr(); + + arma_rng::randn::fill( tmp_mem, s_n_cols ); + + for(uword ii=0; ii < s_n_cols; ++ii) { at(0,ii) = tmp_mem[ii]; } } else { - for(uword ii=0; ii < local_n_cols; ++ii) + if( (s.aux_row1 == 0) && (s_n_rows == s.m.n_rows) ) { - arma_rng::randn::fill( colptr(ii), local_n_rows ); + arma_rng::randn::fill( s.colptr(0), s.n_elem ); + } + else + { + for(uword ii=0; ii < s_n_cols; ++ii) + { + arma_rng::randn::fill( s.colptr(ii), s_n_rows ); + } } } } @@ -1134,7 +1277,7 @@ inline eT& subview::operator()(const uword ii) { - arma_debug_check( (ii >= n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= n_elem), "subview::operator(): index out of bounds" ); const uword in_col = ii / n_rows; const uword in_row = ii % n_rows; @@ -1151,7 +1294,7 @@ inline eT subview::operator()(const uword ii) const { - arma_debug_check( (ii >= n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= n_elem), "subview::operator(): index out of bounds" ); const uword in_col = ii / n_rows; const uword in_row = ii % n_rows; @@ -1168,7 +1311,7 @@ inline eT& subview::operator()(const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds" ); const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; @@ -1182,7 +1325,7 @@ inline eT subview::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= n_rows) || (in_col >= n_cols)), "subview::operator(): index out of bounds" ); const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; @@ -1215,6 +1358,60 @@ subview::at(const uword in_row, const uword in_col) const +template +inline +eT& +subview::front() + { + const uword index = aux_col1*m.n_rows + aux_row1; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::front() const + { + const uword index = aux_col1*m.n_rows + aux_row1; + + return m.mem[index]; + } + + + +template +inline +eT& +subview::back() + { + const uword in_row = n_rows - 1; + const uword in_col = n_cols - 1; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return access::rw( (const_cast< Mat& >(m)).mem[index] ); + } + + + +template +inline +eT +subview::back() const + { + const uword in_row = n_rows - 1; + const uword in_col = n_cols - 1; + + const uword index = (in_col + aux_col1)*m.n_rows + aux_row1 + in_row; + + return m.mem[index]; + } + + + template arma_inline eT* @@ -1273,7 +1470,6 @@ subview::check_overlap(const subview& x) const template inline -arma_warn_unused bool subview::is_vec() const { @@ -1284,12 +1480,13 @@ subview::is_vec() const template inline -arma_warn_unused bool subview::is_finite() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "is_finite(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; @@ -1305,12 +1502,33 @@ subview::is_finite() const template inline -arma_warn_unused +bool +subview::is_zero(const typename get_pod_type::result tol) const + { + arma_extra_debug_sigprint(); + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii +inline bool subview::has_inf() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_inf(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; @@ -1326,12 +1544,13 @@ subview::has_inf() const template inline -arma_warn_unused bool subview::has_nan() const { arma_extra_debug_sigprint(); + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nan(): detection of non-finite values is not reliable in fast math mode"); } + const uword local_n_rows = n_rows; const uword local_n_cols = n_cols; @@ -1345,6 +1564,28 @@ subview::has_nan() const +template +inline +bool +subview::has_nonfinite() const + { + arma_extra_debug_sigprint(); + + if(arma_config::fast_math_warn) { arma_debug_warn_level(1, "has_nonfinite(): detection of non-finite values is not reliable in fast math mode"); } + + const uword local_n_rows = n_rows; + const uword local_n_cols = n_cols; + + for(uword ii=0; ii inline @@ -1359,10 +1600,10 @@ subview::extract(Mat& out, const subview& in) const uword n_rows = in.n_rows; // number of rows in the subview const uword n_cols = in.n_cols; // number of columns in the subview - arma_extra_debug_print(arma_str::format("out.n_rows = %d out.n_cols = %d in.m.n_rows = %d in.m.n_cols = %d") % out.n_rows % out.n_cols % in.m.n_rows % in.m.n_cols ); + arma_extra_debug_print(arma_str::format("out.n_rows = %u out.n_cols = %u in.m.n_rows = %u in.m.n_cols = %u") % out.n_rows % out.n_cols % in.m.n_rows % in.m.n_cols ); - if(in.is_vec() == true) + if(in.is_vec()) { if(n_cols == 1) // a column vector { @@ -1371,7 +1612,8 @@ subview::extract(Mat& out, const subview& in) // in.colptr(0) the first column of the subview, taking into account any row offset arrayops::copy( out.memptr(), in.colptr(0), n_rows ); } - else // a row vector (possibly empty) + else + if(n_rows == 1) // a row vector { arma_extra_debug_print("subview::extract(): copying row (going across columns)"); @@ -1405,13 +1647,13 @@ subview::extract(Mat& out, const subview& in) if( (in.aux_row1 == 0) && (n_rows == in.m.n_rows) ) { arrayops::copy( out.memptr(), in.colptr(0), in.n_elem ); - - return; } - - for(uword col=0; col < n_cols; ++col) + else { - arrayops::copy( out.colptr(col), in.colptr(col), n_rows ); + for(uword col=0; col < n_cols; ++col) + { + arrayops::copy( out.colptr(col), in.colptr(col), n_rows ); + } } } } @@ -1618,7 +1860,7 @@ subview::row(const uword row_num) { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "subview::row(): out of bounds" ); + arma_debug_check_bounds( row_num >= n_rows, "subview::row(): out of bounds" ); const uword base_row = aux_row1 + row_num; @@ -1635,7 +1877,7 @@ subview::row(const uword row_num) const { arma_extra_debug_sigprint(); - arma_debug_check( row_num >= n_rows, "subview::row(): out of bounds" ); + arma_debug_check_bounds( row_num >= n_rows, "subview::row(): out of bounds" ); const uword base_row = aux_row1 + row_num; @@ -1662,7 +1904,7 @@ subview::operator()(const uword row_num, const span& col_span) const uword base_col1 = aux_col1 + in_col1; const uword base_row = aux_row1 + row_num; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -1694,7 +1936,7 @@ subview::operator()(const uword row_num, const span& col_span) const const uword base_col1 = aux_col1 + in_col1; const uword base_row = aux_row1 + row_num; - arma_debug_check + arma_debug_check_bounds ( (row_num >= n_rows) || @@ -1716,7 +1958,7 @@ subview::col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "subview::col(): out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "subview::col(): out of bounds" ); const uword base_col = aux_col1 + col_num; @@ -1733,7 +1975,7 @@ subview::col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "subview::col(): out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "subview::col(): out of bounds" ); const uword base_col = aux_col1 + col_num; @@ -1760,7 +2002,7 @@ subview::operator()(const span& row_span, const uword col_num) const uword base_row1 = aux_row1 + in_row1; const uword base_col = aux_col1 + col_num; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -1792,7 +2034,7 @@ subview::operator()(const span& row_span, const uword col_num) const const uword base_row1 = aux_row1 + in_row1; const uword base_col = aux_col1 + col_num; - arma_debug_check + arma_debug_check_bounds ( (col_num >= n_cols) || @@ -1818,7 +2060,7 @@ subview::unsafe_col(const uword col_num) { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "subview::unsafe_col(): out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "subview::unsafe_col(): out of bounds" ); return Col(colptr(col_num), n_rows, false, true); } @@ -1837,7 +2079,7 @@ subview::unsafe_col(const uword col_num) const { arma_extra_debug_sigprint(); - arma_debug_check( col_num >= n_cols, "subview::unsafe_col(): out of bounds"); + arma_debug_check_bounds( col_num >= n_cols, "subview::unsafe_col(): out of bounds" ); return Col(const_cast(colptr(col_num)), n_rows, false, true); } @@ -1852,7 +2094,7 @@ subview::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "subview::rows(): indices out of bounds or incorrectly used" @@ -1874,7 +2116,7 @@ subview::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_row2 >= n_rows), "subview::rows(): indices out of bounds or incorrectly used" @@ -1896,7 +2138,7 @@ subview::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "subview::cols(): indices out of bounds or incorrectly used" @@ -1918,7 +2160,7 @@ subview::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 > in_col2) || (in_col2 >= n_cols), "subview::cols(): indices out of bounds or incorrectly used" @@ -1940,7 +2182,7 @@ subview::submat(const uword in_row1, const uword in_col1, const uword in_row { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "subview::submat(): indices out of bounds or incorrectly used" @@ -1965,7 +2207,7 @@ subview::submat(const uword in_row1, const uword in_col1, const uword in_row { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 > in_row2) || (in_col1 > in_col2) || (in_row2 >= n_rows) || (in_col2 >= n_cols), "subview::submat(): indices out of bounds or incorrectly used" @@ -2004,7 +2246,7 @@ subview::submat(const span& row_span, const span& col_span) const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -2043,7 +2285,7 @@ subview::submat(const span& row_span, const span& col_span) const const uword in_col2 = col_span.b; const uword submat_n_cols = col_all ? local_n_cols : in_col2 - in_col1 + 1; - arma_debug_check + arma_debug_check_bounds ( ( row_all ? false : ((in_row1 > in_row2) || (in_row2 >= local_n_rows)) ) || @@ -2134,135 +2376,131 @@ subview::each_row(const Base& indices) -#if defined(ARMA_USE_CXX11) +//! apply a lambda function to each column, where each column is interpreted as a column vector +template +inline +void +subview::each_col(const std::function< void(Col&) >& F) + { + arma_extra_debug_sigprint(); - //! apply a lambda function to each column, where each column is interpreted as a column vector - template - inline - void - subview::each_col(const std::function< void(Col&) >& F) + for(uword ii=0; ii < n_cols; ++ii) { - arma_extra_debug_sigprint(); - - for(uword ii=0; ii < n_cols; ++ii) - { - Col tmp(colptr(ii), n_rows, false, true); - F(tmp); - } + Col tmp(colptr(ii), n_rows, false, true); + F(tmp); } + } + + + +template +inline +void +subview::each_col(const std::function< void(const Col&) >& F) const + { + arma_extra_debug_sigprint(); - - - template - inline - void - subview::each_col(const std::function< void(const Col&) >& F) const + for(uword ii=0; ii < n_cols; ++ii) { - arma_extra_debug_sigprint(); - - for(uword ii=0; ii < n_cols; ++ii) - { - const Col tmp(colptr(ii), n_rows, false, true); - F(tmp); - } + const Col tmp(colptr(ii), n_rows, false, true); + F(tmp); } - - - - //! apply a lambda function to each row, where each row is interpreted as a row vector - template - inline - void - subview::each_row(const std::function< void(Row&) >& F) - { - arma_extra_debug_sigprint(); - - podarray array1(n_cols); - podarray array2(n_cols); - - Row tmp1( array1.memptr(), n_cols, false, true ); - Row tmp2( array2.memptr(), n_cols, false, true ); - - eT* tmp1_mem = tmp1.memptr(); - eT* tmp2_mem = tmp2.memptr(); - - uword ii, jj; - - for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + } + + + +//! apply a lambda function to each row, where each row is interpreted as a row vector +template +inline +void +subview::each_row(const std::function< void(Row&) >& F) + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) { - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - const eT* col_mem = colptr(col_id); - - tmp1_mem[col_id] = col_mem[ii]; - tmp2_mem[col_id] = col_mem[jj]; - } + const eT* col_mem = colptr(col_id); - F(tmp1); - F(tmp2); - - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - eT* col_mem = colptr(col_id); - - col_mem[ii] = tmp1_mem[col_id]; - col_mem[jj] = tmp2_mem[col_id]; - } + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; } - if(ii < n_rows) + F(tmp1); + F(tmp2); + + for(uword col_id = 0; col_id < n_cols; ++col_id) { - tmp1 = (*this).row(ii); - - F(tmp1); + eT* col_mem = colptr(col_id); - (*this).row(ii) = tmp1; + col_mem[ii] = tmp1_mem[col_id]; + col_mem[jj] = tmp2_mem[col_id]; } } - - - template - inline - void - subview::each_row(const std::function< void(const Row&) >& F) const + if(ii < n_rows) { - arma_extra_debug_sigprint(); - - podarray array1(n_cols); - podarray array2(n_cols); - - Row tmp1( array1.memptr(), n_cols, false, true ); - Row tmp2( array2.memptr(), n_cols, false, true ); + tmp1 = (*this).row(ii); - eT* tmp1_mem = tmp1.memptr(); - eT* tmp2_mem = tmp2.memptr(); + F(tmp1); - uword ii, jj; - - for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + (*this).row(ii) = tmp1; + } + } + + + +template +inline +void +subview::each_row(const std::function< void(const Row&) >& F) const + { + arma_extra_debug_sigprint(); + + podarray array1(n_cols); + podarray array2(n_cols); + + Row tmp1( array1.memptr(), n_cols, false, true ); + Row tmp2( array2.memptr(), n_cols, false, true ); + + eT* tmp1_mem = tmp1.memptr(); + eT* tmp2_mem = tmp2.memptr(); + + uword ii, jj; + + for(ii=0, jj=1; jj < n_rows; ii+=2, jj+=2) + { + for(uword col_id = 0; col_id < n_cols; ++col_id) { - for(uword col_id = 0; col_id < n_cols; ++col_id) - { - const eT* col_mem = colptr(col_id); - - tmp1_mem[col_id] = col_mem[ii]; - tmp2_mem[col_id] = col_mem[jj]; - } + const eT* col_mem = colptr(col_id); - F(tmp1); - F(tmp2); + tmp1_mem[col_id] = col_mem[ii]; + tmp2_mem[col_id] = col_mem[jj]; } - if(ii < n_rows) - { - tmp1 = (*this).row(ii); - - F(tmp1); - } + F(tmp1); + F(tmp2); } -#endif + if(ii < n_rows) + { + tmp1 = (*this).row(ii); + + F(tmp1); + } + } @@ -2277,7 +2515,7 @@ subview::diag(const sword in_id) const uword row_offset = (in_id < 0) ? uword(-in_id) : 0; const uword col_offset = (in_id > 0) ? uword( in_id) : 0; - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "subview::diag(): requested diagonal out of bounds" @@ -2304,7 +2542,7 @@ subview::diag(const sword in_id) const const uword row_offset = uword( (in_id < 0) ? -in_id : 0 ); const uword col_offset = uword( (in_id > 0) ? in_id : 0 ); - arma_debug_check + arma_debug_check_bounds ( ((row_offset > 0) && (row_offset >= n_rows)) || ((col_offset > 0) && (col_offset >= n_cols)), "subview::diag(): requested diagonal out of bounds" @@ -2327,7 +2565,7 @@ subview::swap_rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_row1 >= n_rows) || (in_row2 >= n_rows), "subview::swap_rows(): out of bounds" @@ -2359,7 +2597,7 @@ subview::swap_cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check + arma_debug_check_bounds ( (in_col1 >= n_cols) || (in_col2 >= n_cols), "subview::swap_cols(): out of bounds" @@ -2448,12 +2686,12 @@ subview::cend() const template inline subview::iterator::iterator() - : M (NULL) - , current_ptr(NULL) - , current_row(0 ) - , current_col(0 ) - , aux_row1 (0 ) - , aux_row2_p1(0 ) + : M (nullptr) + , current_ptr(nullptr) + , current_row(0 ) + , current_col(0 ) + , aux_row1 (0 ) + , aux_row2_p1(0 ) { arma_extra_debug_sigprint(); // Technically this iterator is invalid (it does not point to a valid element) @@ -2493,7 +2731,6 @@ subview::iterator::iterator(subview& in_sv, const uword in_row, const uw template inline -arma_warn_unused eT& subview::iterator::operator*() { @@ -2528,7 +2765,6 @@ subview::iterator::operator++() template inline -arma_warn_unused typename subview::iterator subview::iterator::operator++(int) { @@ -2543,7 +2779,6 @@ subview::iterator::operator++(int) template inline -arma_warn_unused bool subview::iterator::operator==(const iterator& rhs) const { @@ -2554,7 +2789,6 @@ subview::iterator::operator==(const iterator& rhs) const template inline -arma_warn_unused bool subview::iterator::operator!=(const iterator& rhs) const { @@ -2565,7 +2799,6 @@ subview::iterator::operator!=(const iterator& rhs) const template inline -arma_warn_unused bool subview::iterator::operator==(const const_iterator& rhs) const { @@ -2576,7 +2809,6 @@ subview::iterator::operator==(const const_iterator& rhs) const template inline -arma_warn_unused bool subview::iterator::operator!=(const const_iterator& rhs) const { @@ -2594,8 +2826,8 @@ subview::iterator::operator!=(const const_iterator& rhs) const template inline subview::const_iterator::const_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) + , current_ptr(nullptr) , current_row(0 ) , current_col(0 ) , aux_row1 (0 ) @@ -2654,7 +2886,6 @@ subview::const_iterator::const_iterator(const subview& in_sv, const uwor template inline -arma_warn_unused const eT& subview::const_iterator::operator*() { @@ -2689,7 +2920,6 @@ subview::const_iterator::operator++() template inline -arma_warn_unused typename subview::const_iterator subview::const_iterator::operator++(int) { @@ -2704,7 +2934,6 @@ subview::const_iterator::operator++(int) template inline -arma_warn_unused bool subview::const_iterator::operator==(const iterator& rhs) const { @@ -2715,7 +2944,6 @@ subview::const_iterator::operator==(const iterator& rhs) const template inline -arma_warn_unused bool subview::const_iterator::operator!=(const iterator& rhs) const { @@ -2726,7 +2954,6 @@ subview::const_iterator::operator!=(const iterator& rhs) const template inline -arma_warn_unused bool subview::const_iterator::operator==(const const_iterator& rhs) const { @@ -2737,7 +2964,6 @@ subview::const_iterator::operator==(const const_iterator& rhs) const template inline -arma_warn_unused bool subview::const_iterator::operator!=(const const_iterator& rhs) const { @@ -2755,8 +2981,7 @@ subview::const_iterator::operator!=(const const_iterator& rhs) const template inline subview::row_iterator::row_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) , current_row(0 ) , current_col(0 ) , aux_col1 (0 ) @@ -2772,7 +2997,6 @@ template inline subview::row_iterator::row_iterator(const row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) , aux_col1 (X.aux_col1 ) @@ -2787,7 +3011,6 @@ template inline subview::row_iterator::row_iterator(subview& in_sv, const uword in_row, const uword in_col) : M (&(const_cast< Mat& >(in_sv.m))) - , current_ptr(&(M->at(in_row,in_col)) ) , current_row(in_row ) , current_col(in_col ) , aux_col1 (in_sv.aux_col1 ) @@ -2800,11 +3023,10 @@ subview::row_iterator::row_iterator(subview& in_sv, const uword in_row, template inline -arma_warn_unused eT& subview::row_iterator::operator*() { - return (*current_ptr); + return M->at(current_row,current_col); } @@ -2820,12 +3042,6 @@ subview::row_iterator::operator++() { current_col = aux_col1; current_row++; - - current_ptr = &( (*M).at(current_row,current_col) ); - } - else - { - current_ptr += (*M).n_rows; } return *this; @@ -2835,7 +3051,6 @@ subview::row_iterator::operator++() template inline -arma_warn_unused typename subview::row_iterator subview::row_iterator::operator++(int) { @@ -2850,44 +3065,40 @@ subview::row_iterator::operator++(int) template inline -arma_warn_unused bool subview::row_iterator::operator==(const row_iterator& rhs) const { - return (current_ptr == rhs.current_ptr); + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); } template inline -arma_warn_unused bool subview::row_iterator::operator!=(const row_iterator& rhs) const { - return (current_ptr != rhs.current_ptr); + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); } template inline -arma_warn_unused bool subview::row_iterator::operator==(const const_row_iterator& rhs) const { - return (current_ptr == rhs.current_ptr); + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); } template inline -arma_warn_unused bool subview::row_iterator::operator!=(const const_row_iterator& rhs) const { - return (current_ptr != rhs.current_ptr); + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); } @@ -2901,8 +3112,7 @@ subview::row_iterator::operator!=(const const_row_iterator& rhs) const template inline subview::const_row_iterator::const_row_iterator() - : M (NULL) - , current_ptr(NULL) + : M (nullptr) , current_row(0 ) , current_col(0 ) , aux_col1 (0 ) @@ -2918,7 +3128,6 @@ template inline subview::const_row_iterator::const_row_iterator(const row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) , aux_col1 (X.aux_col1 ) @@ -2933,7 +3142,6 @@ template inline subview::const_row_iterator::const_row_iterator(const const_row_iterator& X) : M (X.M ) - , current_ptr(X.current_ptr) , current_row(X.current_row) , current_col(X.current_col) , aux_col1 (X.aux_col1 ) @@ -2948,7 +3156,6 @@ template inline subview::const_row_iterator::const_row_iterator(const subview& in_sv, const uword in_row, const uword in_col) : M (&(in_sv.m) ) - , current_ptr(&(M->at(in_row,in_col)) ) , current_row(in_row ) , current_col(in_col ) , aux_col1 (in_sv.aux_col1 ) @@ -2961,11 +3168,10 @@ subview::const_row_iterator::const_row_iterator(const subview& in_sv, co template inline -arma_warn_unused const eT& subview::const_row_iterator::operator*() const { - return (*current_ptr); + return M->at(current_row,current_col); } @@ -2981,12 +3187,6 @@ subview::const_row_iterator::operator++() { current_col = aux_col1; current_row++; - - current_ptr = &( (*M).at(current_row,current_col) ); - } - else - { - current_ptr += (*M).n_rows; } return *this; @@ -2996,7 +3196,6 @@ subview::const_row_iterator::operator++() template inline -arma_warn_unused typename subview::const_row_iterator subview::const_row_iterator::operator++(int) { @@ -3011,44 +3210,40 @@ subview::const_row_iterator::operator++(int) template inline -arma_warn_unused bool subview::const_row_iterator::operator==(const row_iterator& rhs) const { - return (current_ptr == rhs.current_ptr); + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); } template inline -arma_warn_unused bool subview::const_row_iterator::operator!=(const row_iterator& rhs) const { - return (current_ptr != rhs.current_ptr); + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); } template inline -arma_warn_unused bool subview::const_row_iterator::operator==(const const_row_iterator& rhs) const { - return (current_ptr == rhs.current_ptr); + return ( (current_row == rhs.current_row) && (current_col == rhs.current_col) ); } template inline -arma_warn_unused bool subview::const_row_iterator::operator!=(const const_row_iterator& rhs) const { - return (current_ptr != rhs.current_ptr); + return ( (current_row != rhs.current_row) || (current_col != rhs.current_col) ); } @@ -3081,6 +3276,30 @@ subview_col::subview_col(const Mat& in_m, const uword in_col, const uwor +template +inline +subview_col::subview_col(const subview_col& in) + : subview(in) // interprets 'subview_col' as 'subview' + , colmem(in.colmem) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_col::subview_col(subview_col&& in) + : subview(std::move(in)) // interprets 'subview_col' as 'subview' + , colmem(in.colmem) + { + arma_extra_debug_sigprint(); + + access::rw(in.colmem) = nullptr; + } + + + template inline void @@ -3105,6 +3324,22 @@ subview_col::operator=(const subview_col& X) +template +inline +void +subview_col::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_assert_same_size(subview::n_rows, subview::n_cols, N, 1, "copy into submatrix"); + + arrayops::copy( access::rwp(colmem), list.begin(), N ); + } + + + template inline void @@ -3135,6 +3370,19 @@ subview_col::operator=(const Base& X) +template +template +inline +void +subview_col::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + template template inline @@ -3264,7 +3512,7 @@ inline eT& subview_col::operator()(const uword ii) { - arma_debug_check( (ii >= subview::n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); return access::rw( colmem[ii] ); } @@ -3276,7 +3524,7 @@ inline eT subview_col::operator()(const uword ii) const { - arma_debug_check( (ii >= subview::n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); return colmem[ii]; } @@ -3288,7 +3536,7 @@ inline eT& subview_col::operator()(const uword in_row, const uword in_col) { - arma_debug_check( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds" ); return access::rw( colmem[in_row] ); } @@ -3300,7 +3548,7 @@ inline eT subview_col::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col > 0)), "subview::operator(): index out of bounds" ); return colmem[in_row]; } @@ -3352,7 +3600,7 @@ subview_col::rows(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -3370,7 +3618,7 @@ subview_col::rows(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::rows(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -3388,7 +3636,7 @@ subview_col::subvec(const uword in_row1, const uword in_row2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -3406,7 +3654,7 @@ subview_col::subvec(const uword in_row1, const uword in_row2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_row1 > in_row2) || (in_row2 >= subview::n_rows) ), "subview_col::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_rows = in_row2 - in_row1 + 1; @@ -3426,7 +3674,7 @@ subview_col::subvec(const uword start_row, const SizeMat& s) arma_debug_check( (s.n_cols != 1), "subview_col::subvec(): given size does not specify a column vector" ); - arma_debug_check( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); const uword base_row1 = this->aux_row1 + start_row; @@ -3444,7 +3692,7 @@ subview_col::subvec(const uword start_row, const SizeMat& s) const arma_debug_check( (s.n_cols != 1), "subview_col::subvec(): given size does not specify a column vector" ); - arma_debug_check( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_row >= subview::n_rows) || ((start_row + s.n_rows) > subview::n_rows) ), "subview_col::subvec(): size out of bounds" ); const uword base_row1 = this->aux_row1 + start_row; @@ -3460,7 +3708,7 @@ subview_col::head(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_rows), "subview_col::head(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::head(): size out of bounds" ); return subview_col(this->m, this->aux_col1, this->aux_row1, N); } @@ -3474,7 +3722,7 @@ subview_col::head(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_rows), "subview_col::head(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::head(): size out of bounds" ); return subview_col(this->m, this->aux_col1, this->aux_row1, N); } @@ -3488,7 +3736,7 @@ subview_col::tail(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_rows), "subview_col::tail(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::tail(): size out of bounds" ); const uword start_row = subview::aux_row1 + subview::n_rows - N; @@ -3504,7 +3752,7 @@ subview_col::tail(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_rows), "subview_col::tail(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_rows), "subview_col::tail(): size out of bounds" ); const uword start_row = subview::aux_row1 + subview::n_rows - N; @@ -3515,7 +3763,6 @@ subview_col::tail(const uword N) const template inline -arma_warn_unused eT subview_col::min() const { @@ -3535,7 +3782,6 @@ subview_col::min() const template inline -arma_warn_unused eT subview_col::max() const { @@ -3601,7 +3847,6 @@ subview_col::max(uword& index_of_max_val) const template inline -arma_warn_unused uword subview_col::index_min() const { @@ -3625,7 +3870,6 @@ subview_col::index_min() const template inline -arma_warn_unused uword subview_col::index_max() const { @@ -3647,6 +3891,314 @@ subview_col::index_max() const +// +// +// + + +template +inline +subview_cols::subview_cols(const Mat& in_m, const uword in_col1, const uword in_n_cols) + : subview(in_m, 0, in_col1, in_m.n_rows, in_n_cols) + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cols::subview_cols(const subview_cols& in) + : subview(in) // interprets 'subview_cols' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_cols::subview_cols(subview_cols&& in) + : subview(std::move(in)) // interprets 'subview_cols' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +void +subview_cols::operator=(const subview& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); + } + + + +template +inline +void +subview_cols::operator=(const subview_cols& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X); // interprets 'subview_cols' as 'subview' + } + + + +template +inline +void +subview_cols::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + subview::operator=(list); + } + + + +template +inline +void +subview_cols::operator=(const std::initializer_list< std::initializer_list >& list) + { + arma_extra_debug_sigprint(); + + subview::operator=(list); + } + + + +template +inline +void +subview_cols::operator=(const eT val) + { + arma_extra_debug_sigprint(); + + subview::operator=(val); + } + + + +template +template +inline +void +subview_cols::operator=(const Base& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +void +subview_cols::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + +template +template +inline +typename enable_if2< is_same_type::value, void>::result +subview_cols::operator= (const Gen& in) + { + arma_extra_debug_sigprint(); + + subview::operator=(in); + } + + + +template +arma_inline +const Op,op_htrans> +subview_cols::t() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_htrans> +subview_cols::ht() const + { + return Op,op_htrans>(*this); + } + + + +template +arma_inline +const Op,op_strans> +subview_cols::st() const + { + return Op,op_strans>(*this); + } + + + +template +arma_inline +const Op,op_vectorise_col> +subview_cols::as_col() const + { + return Op,op_vectorise_col>(*this); + } + + + +template +inline +eT +subview_cols::at_alt(const uword ii) const + { + return operator[](ii); + } + + + +template +inline +eT& +subview_cols::operator[](const uword ii) + { + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator[](const uword ii) const + { + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::operator()(const uword ii) + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator()(const uword ii) const + { + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); + + const uword index = subview::aux_col1 * subview::m.n_rows + ii; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::operator()(const uword in_row, const uword in_col) + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::operator()(const uword in_row, const uword in_col) const + { + arma_debug_check_bounds( ((in_row >= subview::n_rows) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); + + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return subview::m.mem[index]; + } + + + +template +inline +eT& +subview_cols::at(const uword in_row, const uword in_col) + { + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return access::rw( (const_cast< Mat& >(subview::m)).mem[index] ); + } + + + +template +inline +eT +subview_cols::at(const uword in_row, const uword in_col) const + { + const uword index = (in_col + subview::aux_col1) * subview::m.n_rows + in_row; + + return subview::m.mem[index]; + } + + + +template +arma_inline +eT* +subview_cols::colptr(const uword in_col) + { + return & access::rw((const_cast< Mat& >(subview::m)).mem[ (in_col + subview::aux_col1) * subview::m.n_rows ]); + } + + + +template +arma_inline +const eT* +subview_cols::colptr(const uword in_col) const + { + return & subview::m.mem[ (in_col + subview::aux_col1) * subview::m.n_rows ]; + } + + + // // // @@ -3673,6 +4225,26 @@ subview_row::subview_row(const Mat& in_m, const uword in_row, const uwor +template +inline +subview_row::subview_row(const subview_row& in) + : subview(in) // interprets 'subview_row' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + +template +inline +subview_row::subview_row(subview_row&& in) + : subview(std::move(in)) // interprets 'subview_row' as 'subview' + { + arma_extra_debug_sigprint(); + } + + + template inline void @@ -3709,6 +4281,28 @@ subview_row::operator=(const eT val) +template +inline +void +subview_row::operator=(const std::initializer_list& list) + { + arma_extra_debug_sigprint(); + + const uword N = uword(list.size()); + + arma_debug_assert_same_size(subview::n_rows, subview::n_cols, 1, N, "copy into submatrix"); + + auto it = list.begin(); + + for(uword ii=0; ii < N; ++ii) + { + (*this).operator[](ii) = (*it); + ++it; + } + } + + + template template inline @@ -3722,6 +4316,19 @@ subview_row::operator=(const Base& X) +template +template +inline +void +subview_row::operator=(const SpBase& X) + { + arma_extra_debug_sigprint(); + + subview::operator=(X.get_ref()); + } + + + template template inline @@ -3818,7 +4425,7 @@ inline eT& subview_row::operator()(const uword ii) { - arma_debug_check( (ii >= subview::n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); @@ -3832,7 +4439,7 @@ inline eT subview_row::operator()(const uword ii) const { - arma_debug_check( (ii >= subview::n_elem), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( (ii >= subview::n_elem), "subview::operator(): index out of bounds" ); const uword index = (ii + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); @@ -3846,7 +4453,7 @@ inline eT& subview_row::operator()(const uword in_row, const uword in_col) { - arma_debug_check( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); @@ -3860,7 +4467,7 @@ inline eT subview_row::operator()(const uword in_row, const uword in_col) const { - arma_debug_check( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds"); + arma_debug_check_bounds( ((in_row > 0) || (in_col >= subview::n_cols)), "subview::operator(): index out of bounds" ); const uword index = (in_col + (subview::aux_col1))*(subview::m).n_rows + (subview::aux_row1); @@ -3900,7 +4507,7 @@ subview_row::cols(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used" ); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -3918,7 +4525,7 @@ subview_row::cols(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::cols(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -3936,7 +4543,7 @@ subview_row::subvec(const uword in_col1, const uword in_col2) { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -3954,7 +4561,7 @@ subview_row::subvec(const uword in_col1, const uword in_col2) const { arma_extra_debug_sigprint(); - arma_debug_check( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used"); + arma_debug_check_bounds( ( (in_col1 > in_col2) || (in_col2 >= subview::n_cols) ), "subview_row::subvec(): indices out of bounds or incorrectly used" ); const uword subview_n_cols = in_col2 - in_col1 + 1; @@ -3974,7 +4581,7 @@ subview_row::subvec(const uword start_col, const SizeMat& s) arma_debug_check( (s.n_rows != 1), "subview_row::subvec(): given size does not specify a row vector" ); - arma_debug_check( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); const uword base_col1 = this->aux_col1 + start_col; @@ -3992,7 +4599,7 @@ subview_row::subvec(const uword start_col, const SizeMat& s) const arma_debug_check( (s.n_rows != 1), "subview_row::subvec(): given size does not specify a row vector" ); - arma_debug_check( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); + arma_debug_check_bounds( ( (start_col >= subview::n_cols) || ((start_col + s.n_cols) > subview::n_cols) ), "subview_row::subvec(): size out of bounds" ); const uword base_col1 = this->aux_col1 + start_col; @@ -4008,7 +4615,7 @@ subview_row::head(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_cols), "subview_row::head(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::head(): size out of bounds" ); return subview_row(this->m, this->aux_row1, this->aux_col1, N); } @@ -4022,7 +4629,7 @@ subview_row::head(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_cols), "subview_row::head(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::head(): size out of bounds" ); return subview_row(this->m, this->aux_row1, this->aux_col1, N); } @@ -4036,7 +4643,7 @@ subview_row::tail(const uword N) { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_cols), "subview_row::tail(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::tail(): size out of bounds" ); const uword start_col = subview::aux_col1 + subview::n_cols - N; @@ -4052,7 +4659,7 @@ subview_row::tail(const uword N) const { arma_extra_debug_sigprint(); - arma_debug_check( (N > subview::n_cols), "subview_row::tail(): size out of bounds"); + arma_debug_check_bounds( (N > subview::n_cols), "subview_row::tail(): size out of bounds" ); const uword start_col = subview::aux_col1 + subview::n_cols - N; @@ -4063,7 +4670,6 @@ subview_row::tail(const uword N) const template inline -arma_warn_unused uword subview_row::index_min() const { @@ -4087,7 +4693,6 @@ subview_row::index_min() const template inline -arma_warn_unused uword subview_row::index_max() const { diff --git a/src/armadillo_bits/sym_helper.hpp b/src/armadillo_bits/sym_helper.hpp new file mode 100644 index 00000000..00555c49 --- /dev/null +++ b/src/armadillo_bits/sym_helper.hpp @@ -0,0 +1,485 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +//! \addtogroup sym_helper +//! @{ + + +namespace sym_helper +{ + +// computationally inexpensive algorithm to guess whether a matrix is positive definite: +// (1) ensure the matrix is symmetric/hermitian (within a tolerance) +// (2) ensure the diagonal entries are real and greater than zero +// (3) ensure that the value with largest modulus is on the main diagonal +// (4) ensure rudimentary diagonal dominance: (real(A_ii) + real(A_jj)) > 2*abs(real(A_ij)) +// the above conditions are necessary, but not sufficient; +// doing it properly would be too computationally expensive for our purposes +// more info: +// http://mathworld.wolfram.com/PositiveDefiniteMatrix.html +// http://mathworld.wolfram.com/DiagonallyDominantMatrix.html + +template +inline +typename enable_if2::no, bool>::result +guess_sympd_worker(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming A is square-sized + + const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + eT max_diag = eT(0); + + for(uword j=0; j < N; ++j) + { + const eT A_jj = A_col[j]; + + if(A_jj <= eT(0)) { return false; } + + max_diag = (A_jj > max_diag) ? A_jj : max_diag; + + A_col += N; + } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const eT A_jj = A_col[j]; + + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + for(uword i=jp1; i < N; ++i) + { + const eT A_ij = A_col[i]; + const eT A_ji = (*A_ji_ptr); + + const eT A_ij_abs = (std::abs)(A_ij); + const eT A_ji_abs = (std::abs)(A_ji); + + // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { return false; } + if(A_ij_abs >= max_diag) { return false; } + + const eT A_delta = (std::abs)(A_ij - A_ji); + const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); + + if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { return false; } + + const eT A_ii = (*A_ii_ptr); + + if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { return false; } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + + return true; + } + + + +template +inline +typename enable_if2::yes, bool>::result +guess_sympd_worker(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming A is square-sized + + typedef typename get_pod_type::result T; + + const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + T max_diag = T(0); + + for(uword j=0; j < N; ++j) + { + const eT& A_jj = A_col[j]; + const T A_jj_real = std::real(A_jj); + const T A_jj_imag = std::imag(A_jj); + + if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { return false; } + + max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + + A_col += N; + } + + const T square_max_diag = max_diag * max_diag; + + if(arma_isfinite(square_max_diag) == false) { return false; } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + const T A_jj_real = std::real(A_col[j]); + + for(uword i=jp1; i < N; ++i) + { + const eT& A_ij = A_col[i]; + const T A_ij_real = std::real(A_ij); + const T A_ij_imag = std::imag(A_ij); + + // avoid using std::abs(), as that is time consuming due to division and std::sqrt() + const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); + + if(arma_isfinite(square_A_ij_abs) == false) { return false; } + + if(square_A_ij_abs >= square_max_diag) { return false; } + + const T A_ij_real_abs = (std::abs)(A_ij_real); + const T A_ij_imag_abs = (std::abs)(A_ij_imag); + + + const eT& A_ji = (*A_ji_ptr); + const T A_ji_real = std::real(A_ji); + const T A_ji_imag = std::imag(A_ji); + + const T A_ji_real_abs = (std::abs)(A_ji_real); + const T A_ji_imag_abs = (std::abs)(A_ji_imag); + + const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); + const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); + + if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { return false; } + + + const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate + const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); + + if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { return false; } + + + const T A_ii_real = std::real(*A_ii_ptr); + + if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { return false; } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + + return true; + } + + + +template +inline +bool +guess_sympd(const Mat& A) + { + arma_extra_debug_sigprint(); + + // analyse matrices with size >= 4x4 + + if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) { return false; } + + return guess_sympd_worker(A); + } + + + +template +inline +bool +guess_sympd(const Mat& A, const uword min_n_rows) + { + arma_extra_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < min_n_rows)) { return false; } + + return guess_sympd_worker(A); + } + + + +// + + + +template +inline +typename enable_if2::no, void>::result +analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + is_approx_sym = true; + is_approx_sympd = true; + + const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + eT max_diag = eT(0); + + for(uword j=0; j < N; ++j) + { + const eT A_jj = A_col[j]; + + if(A_jj <= eT(0)) { is_approx_sympd = false; } + + max_diag = (A_jj > max_diag) ? A_jj : max_diag; + + A_col += N; + } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const eT A_jj = A_col[j]; + + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + for(uword i=jp1; i < N; ++i) + { + const eT A_ij = A_col[i]; + const eT A_ji = (*A_ji_ptr); + + const eT A_ij_abs = (std::abs)(A_ij); + const eT A_ji_abs = (std::abs)(A_ji); + + const eT A_delta = (std::abs)(A_ij - A_ji); + const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); + + if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { is_approx_sym = false; return; } + + if(is_approx_sympd) + { + // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { is_approx_sympd = false; } + if(A_ij_abs >= max_diag) { is_approx_sympd = false; } + + const eT A_ii = (*A_ii_ptr); + + if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { is_approx_sympd = false; } + } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + } + + + +template +inline +typename enable_if2::yes, void>::result +analyse_matrix_worker(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + typedef typename get_pod_type::result T; + + is_approx_sym = true; + is_approx_sympd = true; + + const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway + + const uword N = A.n_rows; + + const eT* A_mem = A.memptr(); + const eT* A_col = A_mem; + + T max_diag = T(0); + + for(uword j=0; j < N; ++j) + { + const eT& A_jj = A_col[j]; + const T A_jj_real = std::real(A_jj); + const T A_jj_imag = std::imag(A_jj); + + if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { is_approx_sympd = false; } + + max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; + + A_col += N; + } + + const T square_max_diag = max_diag * max_diag; + + if(arma_isfinite(square_max_diag) == false) { is_approx_sympd = false; } + + A_col = A_mem; + + const uword Nm1 = N-1; + const uword Np1 = N+1; + + for(uword j=0; j < Nm1; ++j) + { + const uword jp1 = j+1; + const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); + const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); + + const T A_jj_real = std::real(A_col[j]); + + for(uword i=jp1; i < N; ++i) + { + const eT& A_ij = A_col[i]; + const T A_ij_real = std::real(A_ij); + const T A_ij_imag = std::imag(A_ij); + + const T A_ij_real_abs = (std::abs)(A_ij_real); + const T A_ij_imag_abs = (std::abs)(A_ij_imag); + + const eT& A_ji = (*A_ji_ptr); + const T A_ji_real = std::real(A_ji); + const T A_ji_imag = std::imag(A_ji); + + const T A_ji_real_abs = (std::abs)(A_ji_real); + const T A_ji_imag_abs = (std::abs)(A_ji_imag); + + const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); + const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); + + if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { is_approx_sym = false; return; } + + const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate + const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); + + if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { is_approx_sym = false; return; } + + if(is_approx_sympd) + { + // avoid using std::abs(), as that is time consuming due to division and std::sqrt() + const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); + + if(arma_isfinite(square_A_ij_abs) == false) + { + is_approx_sympd = false; + } + else + { + const T A_ii_real = std::real(*A_ii_ptr); + + if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { is_approx_sympd = false; } + + if(square_A_ij_abs >= square_max_diag) { is_approx_sympd = false; } + } + } + + A_ji_ptr += N; + A_ii_ptr += Np1; + } + + A_col += N; + } + } + + + +template +inline +void +analyse_matrix(bool& is_approx_sym, bool& is_approx_sympd, const Mat& A) + { + arma_extra_debug_sigprint(); + + if((A.n_rows != A.n_cols) || (A.n_rows < uword(4))) + { + is_approx_sym = false; + is_approx_sympd = false; + return; + } + + analyse_matrix_worker(is_approx_sym, is_approx_sympd, A); + + if(is_approx_sym == false) { is_approx_sympd = false; } + } + + + +template +inline +bool +check_diag_imag(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming matrix A is square-sized + + typedef typename get_pod_type::result T; + + const T tol = T(10000) * std::numeric_limits::epsilon(); // allow some leeway + + const eT* colmem = A.memptr(); + + const uword N = A.n_rows; + + for(uword i=0; i tol) { return false; } + + colmem += N; + } + + return true; + } + + + +} // end of namespace sym_helper + + +//! @} diff --git a/src/armadillo_bits/sympd_helper.hpp b/src/armadillo_bits/sympd_helper.hpp deleted file mode 100644 index 920a75a0..00000000 --- a/src/armadillo_bits/sympd_helper.hpp +++ /dev/null @@ -1,214 +0,0 @@ -// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) -// Copyright 2008-2016 National ICT Australia (NICTA) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ------------------------------------------------------------------------ - - -//! \addtogroup sympd_helper -//! @{ - - -namespace sympd_helper -{ - -// computationally inexpensive algorithm to guess whether a matrix is positive definite: -// (1) ensure the matrix is symmetric/hermitian (within a tolerance) -// (2) ensure the diagonal entries are real and greater than zero -// (3) ensure that the value with largest modulus is on the main diagonal -// (4) ensure rudimentary diagonal dominance: (real(A_ii) + real(A_jj)) > 2*abs(real(A_ij)) -// the above conditions are necessary, but not sufficient; -// doing it properly would be too computationally expensive for our purposes -// more info: -// http://mathworld.wolfram.com/PositiveDefiniteMatrix.html -// http://mathworld.wolfram.com/DiagonallyDominantMatrix.html - -template -inline -typename enable_if2::no, bool>::result -guess_sympd(const Mat& A) - { - arma_extra_debug_sigprint(); - - if((A.n_rows != A.n_cols) || (A.n_rows < 16)) { return false; } - - const eT tol = eT(100) * std::numeric_limits::epsilon(); // allow some leeway - - const uword N = A.n_rows; - - const eT* A_mem = A.memptr(); - const eT* A_col = A_mem; - - eT max_diag = eT(0); - - for(uword j=0; j < N; ++j) - { - const eT A_jj = A_col[j]; - - if(A_jj <= eT(0)) { return false; } - - max_diag = (A_jj > max_diag) ? A_jj : max_diag; - - A_col += N; - } - - A_col = A_mem; - - const uword Nm1 = N-1; - const uword Np1 = N+1; - - for(uword j=0; j < Nm1; ++j) - { - const eT A_jj = A_col[j]; - - const uword jp1 = j+1; - const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); - const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); - - for(uword i=jp1; i < N; ++i) - { - const eT A_ij = A_col[i]; - const eT A_ji = (*A_ji_ptr); - - const eT A_ij_abs = (std::abs)(A_ij); - const eT A_ji_abs = (std::abs)(A_ji); - - // if( (A_ij_abs >= max_diag) || (A_ji_abs >= max_diag) ) { return false; } - if(A_ij_abs >= max_diag) { return false; } - - const eT A_delta = (std::abs)(A_ij - A_ji); - const eT A_abs_max = (std::max)(A_ij_abs, A_ji_abs); - - if( (A_delta > tol) && (A_delta > (A_abs_max*tol)) ) { return false; } - - const eT A_ii = (*A_ii_ptr); - - if( (A_ij_abs + A_ij_abs) >= (A_ii + A_jj) ) { return false; } - - A_ji_ptr += N; - A_ii_ptr += Np1; - } - - A_col += N; - } - - return true; - } - - - -template -inline -typename enable_if2::yes, bool>::result -guess_sympd(const Mat& A) - { - arma_extra_debug_sigprint(); - - typedef typename get_pod_type::result T; - - if((A.n_rows != A.n_cols) || (A.n_rows < 16)) { return false; } - - const T tol = T(100) * std::numeric_limits::epsilon(); // allow some leeway - - const uword N = A.n_rows; - - const eT* A_mem = A.memptr(); - const eT* A_col = A_mem; - - T max_diag = T(0); - - for(uword j=0; j < N; ++j) - { - const eT& A_jj = A_col[j]; - const T A_jj_real = std::real(A_jj); - const T A_jj_imag = std::imag(A_jj); - - if( (A_jj_real <= T(0)) || (std::abs(A_jj_imag) > tol) ) { return false; } - - max_diag = (A_jj_real > max_diag) ? A_jj_real : max_diag; - - A_col += N; - } - - const T square_max_diag = max_diag * max_diag; - - if(arma_isfinite(square_max_diag) == false) { return false; } - - A_col = A_mem; - - const uword Nm1 = N-1; - const uword Np1 = N+1; - - for(uword j=0; j < Nm1; ++j) - { - const uword jp1 = j+1; - const eT* A_ji_ptr = &(A_mem[j + jp1*N]); // &(A.at(j,jp1)); - const eT* A_ii_ptr = &(A_mem[jp1 + jp1*N]); - - const T A_jj_real = std::real(A_col[j]); - - for(uword i=jp1; i < N; ++i) - { - const eT& A_ij = A_col[i]; - const T A_ij_real = std::real(A_ij); - const T A_ij_imag = std::imag(A_ij); - - // avoid using std::abs(), as that is time consuming due to division and std::sqrt() - const T square_A_ij_abs = (A_ij_real * A_ij_real) + (A_ij_imag * A_ij_imag); - - if(arma_isfinite(square_A_ij_abs) == false) { return false; } - - if(square_A_ij_abs >= square_max_diag) { return false; } - - const T A_ij_real_abs = (std::abs)(A_ij_real); - const T A_ij_imag_abs = (std::abs)(A_ij_imag); - - - const eT& A_ji = (*A_ji_ptr); - const T A_ji_real = std::real(A_ji); - const T A_ji_imag = std::imag(A_ji); - - const T A_ji_real_abs = (std::abs)(A_ji_real); - const T A_ji_imag_abs = (std::abs)(A_ji_imag); - - const T A_real_delta = (std::abs)(A_ij_real - A_ji_real); - const T A_real_abs_max = (std::max)(A_ij_real_abs, A_ji_real_abs); - - if( (A_real_delta > tol) && (A_real_delta > (A_real_abs_max*tol)) ) { return false; } - - - const T A_imag_delta = (std::abs)(A_ij_imag + A_ji_imag); // take into account complex conjugate - const T A_imag_abs_max = (std::max)(A_ij_imag_abs, A_ji_imag_abs); - - if( (A_imag_delta > tol) && (A_imag_delta > (A_imag_abs_max*tol)) ) { return false; } - - - const T A_ii_real = std::real(*A_ii_ptr); - - if( (A_ij_real_abs + A_ij_real_abs) >= (A_ii_real + A_jj_real) ) { return false; } - - A_ji_ptr += N; - A_ii_ptr += Np1; - } - - A_col += N; - } - - return true; - } - - - -} // end of namespace sympd_helper - - -//! @} diff --git a/src/armadillo_bits/traits.hpp b/src/armadillo_bits/traits.hpp index 9052bcce..bde22009 100644 --- a/src/armadillo_bits/traits.hpp +++ b/src/armadillo_bits/traits.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -37,7 +39,7 @@ struct is_Mat_fixed_only template static yes& check(typename X::Mat_fixed_type*); template static no& check(...); - static const bool value = ( sizeof(check(0)) == sizeof(yes) ); + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); }; @@ -51,7 +53,7 @@ struct is_Row_fixed_only template static yes& check(typename X::Row_fixed_type*); template static no& check(...); - static const bool value = ( sizeof(check(0)) == sizeof(yes) ); + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); }; @@ -65,165 +67,178 @@ struct is_Col_fixed_only template static yes& check(typename X::Col_fixed_type*); template static no& check(...); - static const bool value = ( sizeof(check(0)) == sizeof(yes) ); + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); }; template struct is_Mat_fixed - { static const bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; + { static constexpr bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; template struct is_Mat_only - { static const bool value = is_Mat_fixed_only::value; }; + { static constexpr bool value = is_Mat_fixed_only::value; }; template struct is_Mat_only< Mat > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat_only< const Mat > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat - { static const bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; + { static constexpr bool value = ( is_Mat_fixed_only::value || is_Row_fixed_only::value || is_Col_fixed_only::value ); }; template struct is_Mat< Mat > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat< const Mat > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat< Row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat< const Row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat< Col > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Mat< const Col > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Row - { static const bool value = is_Row_fixed_only::value; }; + { static constexpr bool value = is_Row_fixed_only::value; }; template struct is_Row< Row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Row< const Row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Col - { static const bool value = is_Col_fixed_only::value; }; + { static constexpr bool value = is_Col_fixed_only::value; }; template struct is_Col< Col > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Col< const Col > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_diagview - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_diagview< diagview > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_diagview< const diagview > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview< subview > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview< const subview > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_row - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_row< subview_row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_row< const subview_row > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_col - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_col< subview_col > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_col< const subview_col > - { static const bool value = true; }; + { static constexpr bool value = true; }; + + +template +struct is_subview_cols + { static constexpr bool value = false; }; + +template +struct is_subview_cols< subview_cols > + { static constexpr bool value = true; }; + +template +struct is_subview_cols< const subview_cols > + { static constexpr bool value = true; }; template struct is_subview_elem1 - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_elem1< subview_elem1 > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_elem1< const subview_elem1 > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_elem2 - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_elem2< subview_elem2 > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_elem2< const subview_elem2 > - { static const bool value = true; }; + { static constexpr bool value = true; }; @@ -235,39 +250,39 @@ struct is_subview_elem2< const subview_elem2 > template struct is_Cube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_Cube< Cube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Cube< const Cube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_cube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_cube< subview_cube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_cube< const subview_cube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_cube_slices - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_subview_cube_slices< subview_cube_slices > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_subview_cube_slices< const subview_cube_slices > - { static const bool value = true; }; + { static constexpr bool value = true; }; // @@ -277,149 +292,175 @@ struct is_subview_cube_slices< const subview_cube_slices > template struct is_Gen - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_Gen< Gen > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Gen< const Gen > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Op - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_Op< Op > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Op< const Op > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_CubeToMatOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_CubeToMatOp< CubeToMatOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_CubeToMatOp< const CubeToMatOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpToDOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpToDOp< SpToDOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpToDOp< const SpToDOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; + + +template +struct is_SpToDGlue + { static constexpr bool value = false; }; + +template +struct is_SpToDGlue< SpToDGlue > + { static constexpr bool value = true; }; + +template +struct is_SpToDGlue< const SpToDGlue > + { static constexpr bool value = true; }; template struct is_eOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_eOp< eOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_eOp< const eOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtOp< mtOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtOp< const mtOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Glue - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_Glue< Glue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_Glue< const Glue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_eGlue - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_eGlue< eGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_eGlue< const eGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtGlue - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtGlue< mtGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtGlue< const mtGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; // // +template +struct is_glue_times + { static constexpr bool value = false; }; + +template +struct is_glue_times< Glue > + { static constexpr bool value = true; }; + +template +struct is_glue_times< const Glue > + { static constexpr bool value = true; }; + + template struct is_glue_times_diag - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_glue_times_diag< Glue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_glue_times_diag< const Glue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_op_diagmat - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_op_diagmat< Op > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_op_diagmat< const Op > - { static const bool value = true; }; + { static constexpr bool value = true; }; // @@ -428,15 +469,15 @@ struct is_op_diagmat< const Op > template struct is_Mat_trans - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_Mat_trans< Op > - { static const bool value = is_Mat::value; }; + { static constexpr bool value = is_Mat::value; }; template struct is_Mat_trans< Op > - { static const bool value = is_Mat::value; }; + { static constexpr bool value = is_Mat::value; }; // @@ -445,65 +486,65 @@ struct is_Mat_trans< Op > template struct is_GenCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_GenCube< GenCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_OpCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_OpCube< OpCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_eOpCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_eOpCube< eOpCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtOpCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtOpCube< mtOpCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_GlueCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_GlueCube< GlueCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_eGlueCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_eGlueCube< eGlueCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtGlueCube - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtGlueCube< mtGlueCube > - { static const bool value = true; }; + { static constexpr bool value = true; }; // @@ -514,12 +555,10 @@ struct is_mtGlueCube< mtGlueCube > template struct is_arma_type2 { - static const bool value + static constexpr bool value = is_Mat::value || is_Gen::value || is_Op::value - || is_CubeToMatOp::value - || is_SpToDOp::value || is_Glue::value || is_eOp::value || is_eGlue::value @@ -529,8 +568,12 @@ struct is_arma_type2 || is_subview::value || is_subview_row::value || is_subview_col::value + || is_subview_cols::value || is_subview_elem1::value || is_subview_elem2::value + || is_CubeToMatOp::value + || is_SpToDOp::value + || is_SpToDGlue::value ; }; @@ -542,7 +585,7 @@ struct is_arma_type2 template struct is_arma_type { - static const bool value = is_arma_type2::value; + static constexpr bool value = is_arma_type2::value; }; @@ -550,7 +593,7 @@ struct is_arma_type template struct is_arma_cube_type { - static const bool value + static constexpr bool value = is_Cube::value || is_GenCube::value || is_OpCube::value @@ -574,103 +617,133 @@ struct is_arma_cube_type template struct is_SpMat - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpMat< SpMat > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpMat< SpCol > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpMat< SpRow > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpRow - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpRow< SpRow > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpCol - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpCol< SpCol > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpSubview - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpSubview< SpSubview > - { static const bool value = true; }; + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_col + { static constexpr bool value = false; }; + +template +struct is_SpSubview_col< SpSubview_col > + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_col_list + { static constexpr bool value = false; }; + +template +struct is_SpSubview_col_list< SpSubview_col_list > + { static constexpr bool value = true; }; + + +template +struct is_SpSubview_row + { static constexpr bool value = false; }; + +template +struct is_SpSubview_row< SpSubview_row > + { static constexpr bool value = true; }; template struct is_spdiagview - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_spdiagview< spdiagview > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpOp< SpOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_SpGlue - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_SpGlue< SpGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtSpOp - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtSpOp< mtSpOp > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_mtSpGlue - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_mtSpGlue< mtSpGlue > - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_arma_sparse_type { - static const bool value + static constexpr bool value = is_SpMat::value || is_SpSubview::value + || is_SpSubview_col::value + || is_SpSubview_col_list::value + || is_SpSubview_row::value || is_spdiagview::value || is_SpOp::value || is_SpGlue::value @@ -689,18 +762,18 @@ struct is_arma_sparse_type template struct is_same_type { - static const bool value = false; - static const bool yes = false; - static const bool no = true; + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; }; template struct is_same_type { - static const bool value = true; - static const bool yes = true; - static const bool no = false; + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; }; @@ -712,196 +785,206 @@ struct is_same_type template struct is_u8 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_u8 - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_s8 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_s8 - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_u16 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_u16 - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_s16 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_s16 - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_u32 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_u32 - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_s32 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_s32 - { static const bool value = true; }; + { static constexpr bool value = true; }; -#if defined(ARMA_USE_U64S64) - template - struct is_u64 - { static const bool value = false; }; +template +struct is_u64 + { static constexpr bool value = false; }; - template<> - struct is_u64 - { static const bool value = true; }; - - - template - struct is_s64 - { static const bool value = false; }; +template<> +struct is_u64 + { static constexpr bool value = true; }; - template<> - struct is_s64 - { static const bool value = true; }; -#endif + +template +struct is_s64 + { static constexpr bool value = false; }; + +template<> +struct is_s64 + { static constexpr bool value = true; }; template struct is_ulng_t - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_ulng_t - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_slng_t - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_slng_t - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_ulng_t_32 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_ulng_t_32 - { static const bool value = (sizeof(ulng_t) == 4); }; + { static constexpr bool value = (sizeof(ulng_t) == 4); }; template struct is_slng_t_32 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_slng_t_32 - { static const bool value = (sizeof(slng_t) == 4); }; + { static constexpr bool value = (sizeof(slng_t) == 4); }; template struct is_ulng_t_64 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_ulng_t_64 - { static const bool value = (sizeof(ulng_t) == 8); }; + { static constexpr bool value = (sizeof(ulng_t) == 8); }; template struct is_slng_t_64 - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_slng_t_64 - { static const bool value = (sizeof(slng_t) == 8); }; + { static constexpr bool value = (sizeof(slng_t) == 8); }; template struct is_uword - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_uword - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_sword - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_sword - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_float - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_float - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_double - { static const bool value = false; }; + { static constexpr bool value = false; }; template<> struct is_double - { static const bool value = true; }; + { static constexpr bool value = true; }; template struct is_real - { static const bool value = false; }; + { + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; + }; template<> struct is_real - { static const bool value = true; }; + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; template<> struct is_real - { static const bool value = true; }; + { + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; + }; @@ -909,18 +992,18 @@ struct is_real template struct is_cx { - static const bool value = false; - static const bool yes = false; - static const bool no = true; + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; }; // template<> template struct is_cx< std::complex > { - static const bool value = true; - static const bool yes = true; - static const bool no = false; + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; }; @@ -928,17 +1011,17 @@ struct is_cx< std::complex > template struct is_cx_float { - static const bool value = false; - static const bool yes = false; - static const bool no = true; + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; }; template<> struct is_cx_float< std::complex > { - static const bool value = true; - static const bool yes = true; - static const bool no = false; + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; }; @@ -946,17 +1029,17 @@ struct is_cx_float< std::complex > template struct is_cx_double { - static const bool value = false; - static const bool yes = false; - static const bool no = true; + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; }; template<> struct is_cx_double< std::complex > { - static const bool value = true; - static const bool yes = true; - static const bool no = false; + static constexpr bool value = true; + static constexpr bool yes = true; + static constexpr bool no = false; }; @@ -964,21 +1047,17 @@ struct is_cx_double< std::complex > template struct is_supported_elem_type { - static const bool value = \ + static constexpr bool value = \ is_u8::value || is_s8::value || is_u16::value || is_s16::value || is_u32::value || is_s32::value || -#if defined(ARMA_USE_U64S64) is_u64::value || is_s64::value || -#endif -#if defined(ARMA_ALLOW_LONG) is_ulng_t::value || is_slng_t::value || -#endif is_float::value || is_double::value || is_cx_float::value || @@ -990,7 +1069,7 @@ struct is_supported_elem_type template struct is_supported_blas_type { - static const bool value = \ + static constexpr bool value = \ is_float::value || is_double::value || is_cx_float::value || @@ -999,35 +1078,43 @@ struct is_supported_blas_type +template +struct has_blas_float_bug + { + #if defined(ARMA_BLAS_FLOAT_BUG) + static constexpr bool value = is_float::result>::value; + #else + static constexpr bool value = false; + #endif + }; + + + template struct is_signed { - static const bool value = true; + static constexpr bool value = true; }; -template<> struct is_signed { static const bool value = false; }; -template<> struct is_signed { static const bool value = false; }; -template<> struct is_signed { static const bool value = false; }; -#if defined(ARMA_USE_U64S64) -template<> struct is_signed { static const bool value = false; }; -#endif -#if defined(ARMA_ALLOW_LONG) -template<> struct is_signed { static const bool value = false; }; -#endif +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; +template<> struct is_signed { static constexpr bool value = false; }; template struct is_non_integral { - static const bool value = false; + static constexpr bool value = false; }; -template<> struct is_non_integral< float > { static const bool value = true; }; -template<> struct is_non_integral< double > { static const bool value = true; }; -template<> struct is_non_integral< std::complex > { static const bool value = true; }; -template<> struct is_non_integral< std::complex > { static const bool value = true; }; +template<> struct is_non_integral< float > { static constexpr bool value = true; }; +template<> struct is_non_integral< double > { static constexpr bool value = true; }; +template<> struct is_non_integral< std::complex > { static constexpr bool value = true; }; +template<> struct is_non_integral< std::complex > { static constexpr bool value = true; }; @@ -1059,17 +1146,17 @@ struct force_different_type template struct resolves_to_vector_default { - static const bool value = false; - static const bool yes = false; - static const bool no = true; + static constexpr bool value = false; + static constexpr bool yes = false; + static constexpr bool no = true; }; template struct resolves_to_vector_test { - static const bool value = (T1::is_col || T1::is_row || T1::is_xvec); - static const bool yes = (T1::is_col || T1::is_row || T1::is_xvec); - static const bool no = ((T1::is_col || T1::is_row || T1::is_xvec) == false); + static constexpr bool value = (T1::is_col || T1::is_row || T1::is_xvec); + static constexpr bool yes = (T1::is_col || T1::is_row || T1::is_xvec); + static constexpr bool no = ((T1::is_col || T1::is_row || T1::is_xvec) == false); }; @@ -1092,10 +1179,10 @@ struct resolves_to_sparse_vector : public resolves_to_vector_redirect -struct resolves_to_rowvector_default { static const bool value = false; }; +struct resolves_to_rowvector_default { static constexpr bool value = false; }; template -struct resolves_to_rowvector_test { static const bool value = T1::is_row; }; +struct resolves_to_rowvector_test { static constexpr bool value = T1::is_row; }; template @@ -1114,10 +1201,10 @@ struct resolves_to_rowvector : public resolves_to_rowvector_redirect -struct resolves_to_colvector_default { static const bool value = false; }; +struct resolves_to_colvector_default { static constexpr bool value = false; }; template -struct resolves_to_colvector_test { static const bool value = T1::is_col; }; +struct resolves_to_colvector_test { static constexpr bool value = T1::is_col; }; template @@ -1137,47 +1224,66 @@ struct resolves_to_colvector : public resolves_to_colvector_redirect struct is_outer_product - { static const bool value = false; }; + { static constexpr bool value = false; }; template struct is_outer_product< Glue > - { static const bool value = (resolves_to_colvector::value && resolves_to_rowvector::value); }; + { static constexpr bool value = (resolves_to_colvector::value && resolves_to_rowvector::value); }; + + + +template +struct has_op_inv_any + { static constexpr bool value = false; }; +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; +template +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; template -struct has_op_inv - { static const bool value = false; }; +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; template -struct has_op_inv< Op > - { static const bool value = true; }; +struct has_op_inv_any< Op > + { static constexpr bool value = true; }; template -struct has_op_inv< Glue, T2, glue_times> > - { static const bool value = true; }; +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; template -struct has_op_inv< Glue, glue_times> > - { static const bool value = true; }; +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; +template +struct has_op_inv_any< Glue, T2, glue_times> > + { static constexpr bool value = true; }; -template -struct has_op_inv_sympd - { static const bool value = false; }; +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; -template -struct has_op_inv_sympd< Op > - { static const bool value = true; }; +template +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; template -struct has_op_inv_sympd< Glue, T2, glue_times> > - { static const bool value = true; }; +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; template -struct has_op_inv_sympd< Glue, glue_times> > - { static const bool value = true; }; +struct has_op_inv_any< Glue, glue_times> > + { static constexpr bool value = true; }; + @@ -1190,7 +1296,7 @@ struct has_nested_op_traits template static yes& check(typename X::template traits*); template static no& check(...); - static const bool value = ( sizeof(check(0)) == sizeof(yes) ); + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); }; template @@ -1202,7 +1308,7 @@ struct has_nested_glue_traits template static yes& check(typename X::template traits*); template static no& check(...); - static const bool value = ( sizeof(check(0)) == sizeof(yes) ); + static constexpr bool value = ( sizeof(check(0)) == sizeof(yes) ); }; diff --git a/src/armadillo_bits/translate_arpack.hpp b/src/armadillo_bits/translate_arpack.hpp index 2eece27c..8482892a 100644 --- a/src/armadillo_bits/translate_arpack.hpp +++ b/src/armadillo_bits/translate_arpack.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,7 +17,7 @@ -#ifdef ARMA_USE_ARPACK +#if defined(ARMA_USE_ARPACK) //! \namespace arpack namespace for ARPACK functions namespace arpack @@ -33,12 +35,12 @@ namespace arpack #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_snaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dnaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info, 1, 1); } - else if (is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1); } else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_znaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info, 1, 1); } #else if( is_float::value) { typedef float T; arma_ignore(rwork); arma_fortran(arma_snaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } else if( is_double::value) { typedef double T; arma_ignore(rwork); arma_fortran(arma_dnaupd)(ido, bmat, n, which, nev, (T*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, info); } - else if (is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } + else if( is_cx_float::value) { typedef cx_float T; typedef float xT; arma_fortran(arma_cnaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } else if(is_cx_double::value) { typedef cx_double T; typedef double xT; arma_fortran(arma_znaupd)(ido, bmat, n, which, nev, (xT*) tol, (T*) resid, ncv, (T*) v, ldv, iparam, ipntr, (T*) workd, (T*) workl, lworkl, (xT*) rwork, info); } #endif } diff --git a/src/armadillo_bits/translate_atlas.hpp b/src/armadillo_bits/translate_atlas.hpp index 8dfac620..95d43d5a 100644 --- a/src/armadillo_bits/translate_atlas.hpp +++ b/src/armadillo_bits/translate_atlas.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -14,10 +16,12 @@ // ------------------------------------------------------------------------ -#ifdef ARMA_USE_ATLAS +#if defined(ARMA_USE_ATLAS) + +// TODO: remove support for ATLAS in next major version -//! \namespace atlas namespace for ATLAS functions (imported from the global namespace) +//! \namespace atlas namespace for ATLAS functions namespace atlas { @@ -138,7 +142,7 @@ namespace atlas void cblas_gemv ( - const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, const int M, const int N, const eT alpha, const eT *A, const int lda, @@ -152,25 +156,25 @@ namespace atlas if(is_float::value) { typedef float T; - arma_wrapper(cblas_sgemv)(Order, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + arma_wrapper(cblas_sgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); } else if(is_double::value) { typedef double T; - arma_wrapper(cblas_dgemv)(Order, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); + arma_wrapper(cblas_dgemv)(layout, TransA, M, N, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)X, incX, (const T)tmp_real(beta), (T*)Y, incY); } else if(is_cx_float::value) { typedef std::complex T; - arma_wrapper(cblas_cgemv)(Order, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + arma_wrapper(cblas_cgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); } else if(is_cx_double::value) { typedef std::complex T; - arma_wrapper(cblas_zgemv)(Order, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); + arma_wrapper(cblas_zgemv)(layout, TransA, M, N, (const T*)&alpha, (const T*)A, lda, (const T*)X, incX, (const T*)&beta, (T*)Y, incY); } } @@ -181,8 +185,8 @@ namespace atlas void cblas_gemm ( - const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_TRANS TransA, + const atlas_CBLAS_TRANS TransB, const int M, const int N, const int K, const eT alpha, const eT *A, const int lda, const eT *B, const int ldb, const eT beta, eT *C, const int ldc @@ -193,25 +197,25 @@ namespace atlas if(is_float::value) { typedef float T; - arma_wrapper(cblas_sgemm)(Order, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + arma_wrapper(cblas_sgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); } else if(is_double::value) { typedef double T; - arma_wrapper(cblas_dgemm)(Order, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); + arma_wrapper(cblas_dgemm)(layout, TransA, TransB, M, N, K, (const T)tmp_real(alpha), (const T*)A, lda, (const T*)B, ldb, (const T)tmp_real(beta), (T*)C, ldc); } else if(is_cx_float::value) { typedef std::complex T; - arma_wrapper(cblas_cgemm)(Order, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + arma_wrapper(cblas_cgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); } else if(is_cx_double::value) { typedef std::complex T; - arma_wrapper(cblas_zgemm)(Order, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); + arma_wrapper(cblas_zgemm)(layout, TransA, TransB, M, N, K, (const T*)&alpha, (const T*)A, lda, (const T*)B, ldb, (const T*)&beta, (T*)C, ldc); } } @@ -222,7 +226,7 @@ namespace atlas void cblas_syrk ( - const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const eT alpha, const eT* A, const int lda, const eT beta, eT* C, const int ldc ) @@ -232,13 +236,13 @@ namespace atlas if(is_float::value) { typedef float T; - arma_wrapper(cblas_ssyrk)(Order, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + arma_wrapper(cblas_ssyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); } else if(is_double::value) { typedef double T; - arma_wrapper(cblas_dsyrk)(Order, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); + arma_wrapper(cblas_dsyrk)(layout, Uplo, Trans, N, K, (const T)alpha, (const T*)A, lda, (const T)beta, (T*)C, ldc); } } @@ -249,7 +253,7 @@ namespace atlas void cblas_herk ( - const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, + const atlas_CBLAS_LAYOUT layout, const atlas_CBLAS_UPLO Uplo, const atlas_CBLAS_TRANS Trans, const int N, const int K, const T alpha, const std::complex* A, const int lda, const T beta, std::complex* C, const int ldc ) @@ -261,7 +265,7 @@ namespace atlas typedef float TT; typedef std::complex cx_TT; - arma_wrapper(cblas_cherk)(Order, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); + arma_wrapper(cblas_cherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); } else if(is_double::value) @@ -269,239 +273,10 @@ namespace atlas typedef double TT; typedef std::complex cx_TT; - arma_wrapper(cblas_zherk)(Order, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); - } - } - - - - template - inline - int - clapack_getrf - ( - const enum CBLAS_ORDER Order, const int M, const int N, - eT *A, const int lda, int *ipiv - ) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_sgetrf)(Order, M, N, (T*)A, lda, ipiv); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dgetrf)(Order, M, N, (T*)A, lda, ipiv); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cgetrf)(Order, M, N, (T*)A, lda, ipiv); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zgetrf)(Order, M, N, (T*)A, lda, ipiv); + arma_wrapper(cblas_zherk)(layout, Uplo, Trans, N, K, (const TT)alpha, (const cx_TT*)A, lda, (const TT)beta, (cx_TT*)C, ldc); } - - return -1; } - - - template - inline - int - clapack_getri - ( - const enum CBLAS_ORDER Order, const int N, eT *A, - const int lda, const int *ipiv - ) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_sgetri)(Order, N, (T*)A, lda, ipiv); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dgetri)(Order, N, (T*)A, lda, ipiv); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cgetri)(Order, N, (T*)A, lda, ipiv); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zgetri)(Order, N, (T*)A, lda, ipiv); - } - - return -1; - } - - - - template - inline - int - clapack_gesv - ( - const enum CBLAS_ORDER Order, - const int N, const int NRHS, - eT* A, const int lda, int* ipiv, - eT* B, const int ldb - ) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_sgesv)(Order, N, NRHS, (T*)A, lda, ipiv, (T*)B, ldb); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dgesv)(Order, N, NRHS, (T*)A, lda, ipiv, (T*)B, ldb); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cgesv)(Order, N, NRHS, (T*)A, lda, ipiv, (T*)B, ldb); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zgesv)(Order, N, NRHS, (T*)A, lda, ipiv, (T*)B, ldb); - } - - return -1; - } - - - - template - inline - int - clapack_potrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, eT *A, const int lda) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_spotrf)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dpotrf)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cpotrf)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zpotrf)(Order, Uplo, N, (T*)A, lda); - } - - return -1; - } - - - - template - inline - int - clapack_potri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, eT *A, const int lda) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_spotri)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dpotri)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cpotri)(Order, Uplo, N, (T*)A, lda); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zpotri)(Order, Uplo, N, (T*)A, lda); - } - - return -1; - } - - - - template - inline - int - clapack_posv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, eT *A, const int lda, eT *B, const int ldb) - { - arma_type_check((is_supported_blas_type::value == false)); - - if(is_float::value) - { - typedef float T; - return arma_wrapper(clapack_sposv)(Order, Uplo, N, NRHS, (T*)A, lda, (T*)B, ldb); - } - else - if(is_double::value) - { - typedef double T; - return arma_wrapper(clapack_dposv)(Order, Uplo, N, NRHS, (T*)A, lda, (T*)B, ldb); - } - else - if(is_cx_float::value) - { - typedef std::complex T; - return arma_wrapper(clapack_cposv)(Order, Uplo, N, NRHS, (T*)A, lda, (T*)B, ldb); - } - else - if(is_cx_double::value) - { - typedef std::complex T; - return arma_wrapper(clapack_zposv)(Order, Uplo, N, NRHS, (T*)A, lda, (T*)B, ldb); - } - - return -1; - } } #endif diff --git a/src/armadillo_bits/translate_blas.hpp b/src/armadillo_bits/translate_blas.hpp index 487752c5..91fb6a2d 100644 --- a/src/armadillo_bits/translate_blas.hpp +++ b/src/armadillo_bits/translate_blas.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,7 +17,7 @@ -#ifdef ARMA_USE_BLAS +#if defined(ARMA_USE_BLAS) //! \namespace blas namespace for BLAS functions @@ -128,7 +130,7 @@ namespace blas if(is_float::value) { - #if defined(ARMA_BLAS_SDOT_BUG) + #if defined(ARMA_BLAS_FLOAT_BUG) { if(n_elem == 0) { return eT(0); } diff --git a/src/armadillo_bits/translate_fftw3.hpp b/src/armadillo_bits/translate_fftw3.hpp new file mode 100644 index 00000000..1edd7276 --- /dev/null +++ b/src/armadillo_bits/translate_fftw3.hpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 +// +// Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) +// Copyright 2008-2016 National ICT Australia (NICTA) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ------------------------------------------------------------------------ + + +#if defined(ARMA_USE_FFTW3) + + +namespace fftw3 + { + template + arma_inline + void_ptr + plan_dft_1d(int N, eT* input, eT* output, int fftw3_sign, unsigned int fftw3_flags) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + return fftwf_plan_dft_1d(N, (cx_float*)input, (cx_float*)output, fftw3_sign, fftw3_flags); + } + else + if(is_cx_double::value) + { + return fftw_plan_dft_1d(N, (cx_double*)input, (cx_double*)output, fftw3_sign, fftw3_flags); + } + + return nullptr; + } + + + + template + arma_inline + void + execute(void_ptr plan) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_execute(plan); + } + else + if(is_cx_double::value) + { + fftw_execute(plan); + } + } + + + + template + arma_inline + void + destroy_plan(void_ptr plan) + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_destroy_plan(plan); + } + else + if(is_cx_double::value) + { + fftw_destroy_plan(plan); + } + } + + + + template + arma_inline + void + cleanup() + { + arma_type_check((is_cx::value == false)); + + if(is_cx_float::value) + { + fftwf_cleanup(); + } + else + if(is_cx_double::value) + { + fftw_cleanup(); + } + } + } + + +#endif diff --git a/src/armadillo_bits/translate_lapack.hpp b/src/armadillo_bits/translate_lapack.hpp index 662eb8b5..7ed4c0ec 100644 --- a/src/armadillo_bits/translate_lapack.hpp +++ b/src/armadillo_bits/translate_lapack.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -15,7 +17,7 @@ -#ifdef ARMA_USE_LAPACK +#if defined(ARMA_USE_LAPACK) //! \namespace lapack namespace for LAPACK functions @@ -408,6 +410,32 @@ namespace lapack + template + inline + void + geqp3(blas_int* m, blas_int* n, eT* a, blas_int* lda, blas_int* jpvt, eT* tau, eT* work, blas_int* lwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_float::value) { typedef float T; arma_fortran(arma_sgeqp3)(m, n, (T*)a, lda, jpvt, (T*)tau, (T*)work, lwork, info); } + else if(is_double::value) { typedef double T; arma_fortran(arma_dgeqp3)(m, n, (T*)a, lda, jpvt, (T*)tau, (T*)work, lwork, info); } + } + + + + template + inline + void + cx_geqp3(blas_int* m, blas_int* n, eT* a, blas_int* lda, blas_int* jpvt, eT* tau, eT* work, blas_int* lwork, typename eT::value_type* rwork, blas_int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if( is_cx_float::value) { typedef float T; typedef blas_cxf cx_T; arma_fortran(arma_cgeqp3)(m, n, (cx_T*)a, lda, jpvt, (cx_T*)tau, (cx_T*)work, lwork, (T*)rwork, info); } + else if(is_cx_double::value) { typedef double T; typedef blas_cxd cx_T; arma_fortran(arma_zgeqp3)(m, n, (cx_T*)a, lda, jpvt, (cx_T*)tau, (cx_T*)work, lwork, (T*)rwork, info); } + } + + + template inline void @@ -418,8 +446,8 @@ namespace lapack if( is_float::value) { typedef float T; arma_fortran(arma_sorgqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } else if(is_double::value) { typedef double T; arma_fortran(arma_dorgqr)(m, n, k, (T*)a, lda, (T*)tau, (T*)work, lwork, info); } } - - + + template inline @@ -1280,26 +1308,36 @@ namespace lapack template inline void - larnv(blas_int* idist, blas_int* iseed, const blas_int* n, eT* x) + gehrd(blas_int* n, blas_int* ilo, blas_int* ihi, eT* a, blas_int* lda, eT* tao, eT* work, blas_int* lwork, blas_int* info) { arma_type_check(( is_supported_blas_type::value == false )); - if( is_float::value) { typedef float T; arma_fortran(arma_slarnv)(idist, iseed, n, (T*)x); } - else if(is_double::value) { typedef double T; arma_fortran(arma_dlarnv)(idist, iseed, n, (T*)x); } + if( is_float::value) { typedef float T; arma_fortran(arma_sgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if( is_double::value) { typedef double T; arma_fortran(arma_dgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } } + template inline void - gehrd(blas_int* n, blas_int* ilo, blas_int* ihi, eT* a, blas_int* lda, eT* tao, eT* work, blas_int* lwork, blas_int* info) + pstrf(const char* uplo, const blas_int* n, eT* a, const blas_int* lda, blas_int* piv, blas_int* rank, const typename get_pod_type::result* tol, const typename get_pod_type::result* work, blas_int* info) { arma_type_check(( is_supported_blas_type::value == false )); - if( is_float::value) { typedef float T; arma_fortran(arma_sgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } - else if( is_double::value) { typedef double T; arma_fortran(arma_dgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } - else if( is_cx_float::value) { typedef blas_cxf T; arma_fortran(arma_cgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } - else if(is_cx_double::value) { typedef blas_cxd T; arma_fortran(arma_zgehrd)(n, ilo, ihi, (T*)a, lda, (T*)tao, (T*)work, lwork, info); } + #if defined(ARMA_USE_FORTRAN_HIDDEN_ARGS) + if( is_float::value) { typedef float pod_T; typedef float T; arma_fortran(arma_spstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if( is_double::value) { typedef double pod_T; typedef double T; arma_fortran(arma_dpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; arma_fortran(arma_cpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; arma_fortran(arma_zpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info, 1); } + #else + if( is_float::value) { typedef float pod_T; typedef float T; arma_fortran(arma_spstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if( is_double::value) { typedef double pod_T; typedef double T; arma_fortran(arma_dpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if( is_cx_float::value) { typedef float pod_T; typedef blas_cxf T; arma_fortran(arma_cpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + else if(is_cx_double::value) { typedef double pod_T; typedef blas_cxd T; arma_fortran(arma_zpstrf)(uplo, n, (T*)a, lda, piv, rank, (const pod_T*)tol, (pod_T*)work, info); } + #endif } diff --git a/src/armadillo_bits/translate_superlu.hpp b/src/armadillo_bits/translate_superlu.hpp index e653a6bb..a04f01e8 100644 --- a/src/armadillo_bits/translate_superlu.hpp +++ b/src/armadillo_bits/translate_superlu.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -97,6 +99,146 @@ namespace superlu + template + inline + void + gstrf(superlu_options_t* options, + SuperMatrix* A, + int relax, + int panel_size, int *etree, + void *work, int lwork, + int* perm_c, int* perm_r, + SuperMatrix* L, SuperMatrix* U, + GlobalLU_t* Glu, SuperLUStat_t* stat, int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + arma_wrapper(sgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_double::value) + { + arma_wrapper(dgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_cx_float::value) + { + arma_wrapper(cgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + else + if(is_cx_double::value) + { + arma_wrapper(zgstrf)(options, A, relax, panel_size, etree, work, lwork, perm_c, perm_r, L, U, Glu, stat, info); + } + } + + + + template + inline + void + gstrs(trans_t trans, + SuperMatrix* L, SuperMatrix* U, + int* perm_c, int* perm_r, + SuperMatrix* B, SuperLUStat_t* stat, int* info + ) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + arma_wrapper(sgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_double::value) + { + arma_wrapper(dgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_cx_float::value) + { + arma_wrapper(cgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + else + if(is_cx_double::value) + { + arma_wrapper(zgstrs)(trans, L, U, perm_c, perm_r, B, stat, info); + } + } + + + + template + inline + typename get_pod_type::result + langs(char* norm, superlu::SuperMatrix* A) + { + arma_type_check(( is_supported_blas_type::value == false )); + + typedef typename get_pod_type::result T; + + if(is_float::value) + { + return arma_wrapper(slangs)(norm, A); + } + else + if(is_double::value) + { + return arma_wrapper(dlangs)(norm, A); + } + else + if(is_cx_float::value) + { + return arma_wrapper(clangs)(norm, A); + } + else + if(is_cx_double::value) + { + return arma_wrapper(zlangs)(norm, A); + } + + return T(0); // to avoid false warnigns from the compiler + } + + + + template + inline + void + gscon(char* norm, superlu::SuperMatrix* L, superlu::SuperMatrix* U, typename get_pod_type::result anorm, typename get_pod_type::result* rcond, superlu::SuperLUStat_t* stat, int* info) + { + arma_type_check(( is_supported_blas_type::value == false )); + + if(is_float::value) + { + typedef float T; + arma_wrapper(sgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_double::value) + { + typedef double T; + arma_wrapper(dgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_cx_float::value) + { + typedef float T; + arma_wrapper(cgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + else + if(is_cx_double::value) + { + typedef double T; + arma_wrapper(zgscon)(norm, L, U, (T)anorm, (T*)rcond, stat, info); + } + } + + + inline void init_stat(SuperLUStat_t* stat) @@ -122,7 +264,33 @@ namespace superlu } - + inline + void + get_permutation_c(int ispec, SuperMatrix* A, int* perm_c) + { + arma_wrapper(get_perm_c)(ispec, A, perm_c); + } + + + + inline + void + sp_preorder_mat(superlu_options_t* opts, SuperMatrix* A, int* perm_c, int* etree, SuperMatrix* AC) + { + arma_wrapper(sp_preorder)(opts, A, perm_c, etree, AC); + } + + + + inline + int + sp_ispec_environ(int ispec) + { + return arma_wrapper(sp_ienv)(ispec); + } + + + inline void destroy_supernode_mat(SuperMatrix* a) @@ -141,6 +309,15 @@ namespace superlu + inline + void + destroy_compcolperm_mat(SuperMatrix* a) + { + arma_wrapper(Destroy_CompCol_Permuted)(a); + } + + + inline void destroy_dense_mat(SuperMatrix* a) diff --git a/src/armadillo_bits/trimat_helper.hpp b/src/armadillo_bits/trimat_helper.hpp index 415b3892..9242083d 100644 --- a/src/armadillo_bits/trimat_helper.hpp +++ b/src/armadillo_bits/trimat_helper.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -37,19 +39,15 @@ is_triu(const Mat& A) if(N < 2) { return false; } - const eT* A_mem = A.memptr(); + const eT* A_col = A.memptr(); const eT eT_zero = eT(0); - // quickly check bottom-left corner - const eT* A_col0 = A_mem; - const eT* A_col1 = A_col0 + N; + // quickly check element at bottom-left - if( (A_col0[N-2] != eT_zero) || (A_col0[Nm1] != eT_zero) || (A_col1[Nm1] != eT_zero) ) { return false; } + if(A_col[Nm1] != eT_zero) { return false; } // if we got to this point, do a thorough check - const eT* A_col = A_mem; - for(uword j=0; j < Nm1; ++j) { for(uword i=(j+1); i < N; ++i) @@ -82,11 +80,11 @@ is_tril(const Mat& A) const eT eT_zero = eT(0); - // quickly check top-right corner - const eT* A_colNm2 = A.colptr(N-2); - const eT* A_colNm1 = A_colNm2 + N; + // quickly check element at top-right + + const eT* A_colNm1 = A.colptr(N-1); - if( (A_colNm2[0] != eT_zero) || (A_colNm1[0] != eT_zero) || (A_colNm1[1] != eT_zero) ) { return false; } + if(A_colNm1[0] != eT_zero) { return false; } // if we got to this point, do a thorough check @@ -109,6 +107,58 @@ is_tril(const Mat& A) +template +inline +bool +has_nonfinite_tril(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const eT* colptr = A.memptr(); + const uword N = A.n_rows; + + for(uword i=0; i +inline +bool +has_nonfinite_triu(const Mat& A) + { + arma_extra_debug_sigprint(); + + // NOTE: assuming that A has a square size + + const eT* colptr = A.memptr(); + const uword N = A.n_rows; + + for(uword i=0; i= 0xffff @@ -60,28 +61,18 @@ #endif -#if defined(ARMA_USE_U64S64) - #if ULLONG_MAX >= 0xffffffffffffffff - typedef unsigned long long u64; - typedef long long s64; - #elif ULONG_MAX >= 0xffffffffffffffff - typedef unsigned long u64; - typedef long s64; - #define ARMA_U64_IS_LONG - #elif defined(UINT64_MAX) - typedef uint64_t u64; - typedef int64_t s64; - #else - #error "don't know how to typedef 'u64' on this system; please disable ARMA_64BIT_WORD" - #endif -#endif - - -#if !defined(ARMA_USE_U64S64) || (defined(ARMA_USE_U64S64) && !defined(ARMA_U64_IS_LONG)) - #define ARMA_ALLOW_LONG +#if ULLONG_MAX >= 0xffffffffffffffff + typedef unsigned long long u64; + typedef long long s64; +#elif defined(UINT64_MAX) + typedef uint64_t u64; + typedef int64_t s64; +#else + #error "don't know how to typedef 'u64' on this system" #endif +// for compatibility with earlier versions of Armadillo typedef unsigned long ulng_t; typedef long slng_t; @@ -131,7 +122,7 @@ typedef void* void_ptr; // -#ifdef ARMA_USE_MKL_TYPES +#if defined(ARMA_USE_MKL_TYPES) // for compatibility with MKL typedef MKL_Complex8 blas_cxf; typedef MKL_Complex16 blas_cxd; diff --git a/src/armadillo_bits/typedef_elem_check.hpp b/src/armadillo_bits/typedef_elem_check.hpp index a52d0006..db462ab1 100644 --- a/src/armadillo_bits/typedef_elem_check.hpp +++ b/src/armadillo_bits/typedef_elem_check.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -22,29 +24,23 @@ namespace junk { struct arma_elem_size_test { + arma_static_check( (sizeof(u8) != 1), "error: type 'u8' has unsupported size" ); + arma_static_check( (sizeof(s8) != 1), "error: type 's8' has unsupported size" ); - // arma_static_check( (sizeof(size_t) < sizeof(uword)), ERROR___TYPE_SIZE_T_IS_SMALLER_THAN_UWORD ); - - arma_static_check( (sizeof(u8) != 1), ERROR___TYPE_U8_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(s8) != 1), ERROR___TYPE_S8_HAS_UNSUPPORTED_SIZE ); - - arma_static_check( (sizeof(u16) != 2), ERROR___TYPE_U16_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(s16) != 2), ERROR___TYPE_S16_HAS_UNSUPPORTED_SIZE ); - - arma_static_check( (sizeof(u32) != 4), ERROR___TYPE_U32_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(s32) != 4), ERROR___TYPE_S32_HAS_UNSUPPORTED_SIZE ); + arma_static_check( (sizeof(u16) != 2), "error: type 'u16' has unsupported size" ); + arma_static_check( (sizeof(s16) != 2), "error: type 's16' has unsupported size" ); - #if defined(ARMA_USE_U64S64) - arma_static_check( (sizeof(u64) != 8), ERROR___TYPE_U64_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(s64) != 8), ERROR___TYPE_S64_HAS_UNSUPPORTED_SIZE ); - #endif + arma_static_check( (sizeof(u32) != 4), "error: type 'u32' has unsupported size" ); + arma_static_check( (sizeof(s32) != 4), "error: type 's32' has unsupported size" ); - arma_static_check( (sizeof(float) != 4), ERROR___TYPE_FLOAT_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(double) != 8), ERROR___TYPE_DOUBLE_HAS_UNSUPPORTED_SIZE ); + arma_static_check( (sizeof(u64) != 8), "error: type 'u64' has unsupported size" ); + arma_static_check( (sizeof(s64) != 8), "error: type 's64' has unsupported size" ); - arma_static_check( (sizeof(std::complex) != 8), ERROR___TYPE_COMPLEX_FLOAT_HAS_UNSUPPORTED_SIZE ); - arma_static_check( (sizeof(std::complex) != 16), ERROR___TYPE_COMPLEX_DOUBLE_HAS_UNSUPPORTED_SIZE ); + arma_static_check( (sizeof(float) != 4), "error: type 'float' has unsupported size" ); + arma_static_check( (sizeof(double) != 8), "error: type 'double' has unsupported size" ); + arma_static_check( (sizeof(std::complex) != 8), "type 'std::complex' has unsupported size" ); + arma_static_check( (sizeof(std::complex) != 16), "type 'std::complex' has unsupported size" ); }; } diff --git a/src/armadillo_bits/typedef_mat.hpp b/src/armadillo_bits/typedef_mat.hpp index cc353171..69a4c90f 100644 --- a/src/armadillo_bits/typedef_mat.hpp +++ b/src/armadillo_bits/typedef_mat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -36,19 +38,17 @@ typedef Col s32_colvec; typedef Row s32_rowvec; typedef Cube s32_cube; -#if defined(ARMA_USE_U64S64) - typedef Mat u64_mat; - typedef Col u64_vec; - typedef Col u64_colvec; - typedef Row u64_rowvec; - typedef Cube u64_cube; - - typedef Mat s64_mat; - typedef Col s64_vec; - typedef Col s64_colvec; - typedef Row s64_rowvec; - typedef Cube s64_cube; -#endif +typedef Mat u64_mat; +typedef Col u64_vec; +typedef Col u64_colvec; +typedef Row u64_rowvec; +typedef Cube u64_cube; + +typedef Mat s64_mat; +typedef Col s64_vec; +typedef Col s64_colvec; +typedef Row s64_rowvec; +typedef Cube s64_cube; typedef Mat umat; typedef Col uvec; diff --git a/src/armadillo_bits/typedef_mat_fixed.hpp b/src/armadillo_bits/typedef_mat_fixed.hpp index 2b155204..bd45615a 100644 --- a/src/armadillo_bits/typedef_mat_fixed.hpp +++ b/src/armadillo_bits/typedef_mat_fixed.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/unwrap.hpp b/src/armadillo_bits/unwrap.hpp index 0ebee98b..4e935067 100644 --- a/src/armadillo_bits/unwrap.hpp +++ b/src/armadillo_bits/unwrap.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -68,11 +70,11 @@ struct unwrap_redirect { typedef unwrap_fixed result; }; template -struct unwrap : public unwrap_redirect::value >::result +struct unwrap : public unwrap_redirect::value>::result { inline unwrap(const T1& A) - : unwrap_redirect< T1, is_Mat_fixed::value >::result(A) + : unwrap_redirect::value>::result(A) { } }; @@ -130,6 +132,40 @@ struct unwrap< Col > +template +struct unwrap< subview_col > + { + typedef Col stored_type; + + inline + unwrap(const subview_col& A) + : M(A.colmem, A.n_rows) + { + arma_extra_debug_sigprint(); + } + + const Col M; + }; + + + +template +struct unwrap< subview_cols > + { + typedef Mat stored_type; + + inline + unwrap(const subview_cols& A) + : M(A.colptr(0), A.n_rows, A.n_cols) + { + arma_extra_debug_sigprint(); + } + + const Mat M; + }; + + + template struct unwrap< mtGlue > { @@ -185,12 +221,12 @@ struct quasi_unwrap_default // NOTE: DO NOT DIRECTLY CHECK FOR ALIASING BY TAKING THE ADDRESS OF THE "M" OBJECT IN ANY quasi_unwrap CLASS !!! Mat M; - static const bool is_const = false; - static const bool has_subview = false; - static const bool has_orig_mem = false; + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -209,9 +245,9 @@ struct quasi_unwrap_fixed const T1& M; - static const bool is_const = true; - static const bool has_subview = false; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } @@ -230,9 +266,9 @@ struct quasi_unwrap_redirect { typedef quasi_unwrap_fixed resul template -struct quasi_unwrap : public quasi_unwrap_redirect::value >::result +struct quasi_unwrap : public quasi_unwrap_redirect::value>::result { - typedef typename quasi_unwrap_redirect::value >::result quasi_unwrap_extra; + typedef typename quasi_unwrap_redirect::value>::result quasi_unwrap_extra; inline quasi_unwrap(const T1& A) @@ -240,9 +276,9 @@ struct quasi_unwrap : public quasi_unwrap_redirect::value > { } - static const bool is_const = quasi_unwrap_extra::is_const; - static const bool has_subview = quasi_unwrap_extra::has_subview; - static const bool has_orig_mem = quasi_unwrap_extra::has_orig_mem; + static constexpr bool is_const = quasi_unwrap_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_extra::has_orig_mem; using quasi_unwrap_extra::M; using quasi_unwrap_extra::is_alias; @@ -262,9 +298,9 @@ struct quasi_unwrap< Mat > const Mat& M; - static const bool is_const = true; - static const bool has_subview = false; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } @@ -285,9 +321,9 @@ struct quasi_unwrap< Row > const Row& M; - static const bool is_const = true; - static const bool has_subview = false; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } @@ -307,9 +343,9 @@ struct quasi_unwrap< Col > const Col& M; - static const bool is_const = true; - static const bool has_subview = false; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&M) == void_ptr(&X)); } @@ -331,9 +367,9 @@ struct quasi_unwrap< subview > const subview& sv; const Mat M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = false; // NOTE: set to false as this is the general case; original memory is only used when the subview is a contiguous chunk + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = false; // NOTE: set to false as this is the general case; original memory is only used when the subview is a contiguous chunk template arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } @@ -353,12 +389,12 @@ struct quasi_unwrap< subview_row > Row M; - static const bool is_const = false; - static const bool has_subview = false; - static const bool has_orig_mem = false; + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -369,7 +405,7 @@ struct quasi_unwrap< subview_col > inline quasi_unwrap(const subview_col& A) : orig( A.m ) - , M ( const_cast( A.colptr(0) ), A.n_rows, false, false ) + , M ( const_cast( A.colmem ), A.n_rows, false, false ) { arma_extra_debug_sigprint(); } @@ -377,9 +413,33 @@ struct quasi_unwrap< subview_col > const Mat& orig; const Col M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + }; + + + +template +struct quasi_unwrap< subview_cols > + { + inline + quasi_unwrap(const subview_cols& A) + : orig( A.m ) + , M ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + const Mat& orig; + const Mat M; + + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -399,12 +459,12 @@ struct quasi_unwrap< mtGlue > Mat M; - static const bool is_const = false; - static const bool has_subview = false; - static const bool has_orig_mem = false; + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -421,12 +481,12 @@ struct quasi_unwrap< mtOp > Mat M; - static const bool is_const = false; - static const bool has_subview = false; - static const bool has_orig_mem = false; + static constexpr bool is_const = false; + static constexpr bool has_subview = false; + static constexpr bool has_orig_mem = false; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -444,15 +504,15 @@ struct quasi_unwrap< Op > arma_extra_debug_sigprint(); } - const unwrap U; - const Mat M; + const quasi_unwrap U; + const Mat M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template - arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&(U.M)) == void_ptr(&X)); } + arma_inline bool is_alias(const Mat& X) const { return U.is_alias(X); } }; @@ -471,9 +531,9 @@ struct quasi_unwrap< Op, op_strans> > const Col& orig; const Row M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -495,9 +555,9 @@ struct quasi_unwrap< Op, op_strans> > const Row& orig; const Col M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -511,7 +571,7 @@ struct quasi_unwrap< Op, op_strans> > inline quasi_unwrap(const Op, op_strans>& A) : orig( A.m.m ) - , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, false, false ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) { arma_extra_debug_sigprint(); } @@ -519,9 +579,9 @@ struct quasi_unwrap< Op, op_strans> > const Mat& orig; const Row M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } @@ -551,9 +611,9 @@ struct quasi_unwrap_Col_htrans< Op, op_htrans> > const Col& orig; const Row M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -583,9 +643,9 @@ struct quasi_unwrap< Op, op_htrans> > { } - static const bool is_const = quasi_unwrap_Col_htrans_extra::is_const; - static const bool has_subview = quasi_unwrap_Col_htrans_extra::has_subview; - static const bool has_orig_mem = quasi_unwrap_Col_htrans_extra::has_orig_mem; + static constexpr bool is_const = quasi_unwrap_Col_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_Col_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_Col_htrans_extra::has_orig_mem; using quasi_unwrap_Col_htrans_extra::M; using quasi_unwrap_Col_htrans_extra::is_alias; @@ -615,9 +675,9 @@ struct quasi_unwrap_Row_htrans< Op, op_htrans> > const Row& orig; const Col M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -647,9 +707,9 @@ struct quasi_unwrap< Op, op_htrans> > { } - static const bool is_const = quasi_unwrap_Row_htrans_extra::is_const; - static const bool has_subview = quasi_unwrap_Row_htrans_extra::has_subview; - static const bool has_orig_mem = quasi_unwrap_Row_htrans_extra::has_orig_mem; + static constexpr bool is_const = quasi_unwrap_Row_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_Row_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_Row_htrans_extra::has_orig_mem; using quasi_unwrap_Row_htrans_extra::M; using quasi_unwrap_Row_htrans_extra::is_alias; @@ -671,7 +731,7 @@ struct quasi_unwrap_subview_col_htrans< Op, op_htrans> > inline quasi_unwrap_subview_col_htrans(const Op, op_htrans>& A) : orig(A.m.m) - , M (const_cast(A.m.colptr(0)), A.m.n_rows, false, false) + , M (const_cast(A.m.colmem), A.m.n_rows, false, false) { arma_extra_debug_sigprint(); } @@ -679,9 +739,9 @@ struct quasi_unwrap_subview_col_htrans< Op, op_htrans> > const Mat& orig; const Row M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } @@ -711,9 +771,9 @@ struct quasi_unwrap< Op, op_htrans> > { } - static const bool is_const = quasi_unwrap_subview_col_htrans_extra::is_const; - static const bool has_subview = quasi_unwrap_subview_col_htrans_extra::has_subview; - static const bool has_orig_mem = quasi_unwrap_subview_col_htrans_extra::has_orig_mem; + static constexpr bool is_const = quasi_unwrap_subview_col_htrans_extra::is_const; + static constexpr bool has_subview = quasi_unwrap_subview_col_htrans_extra::has_subview; + static constexpr bool has_orig_mem = quasi_unwrap_subview_col_htrans_extra::has_orig_mem; using quasi_unwrap_subview_col_htrans_extra::M; using quasi_unwrap_subview_col_htrans_extra::is_alias; @@ -737,12 +797,12 @@ struct quasi_unwrap< CubeToMatOp > const unwrap_cube U; const Mat M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -763,12 +823,12 @@ struct quasi_unwrap< SpToDOp > const unwrap_spmat U; const Mat M; - static const bool is_const = true; - static const bool has_subview = true; - static const bool has_orig_mem = true; + static constexpr bool is_const = true; + static constexpr bool has_subview = true; + static constexpr bool has_orig_mem = true; template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } }; @@ -812,16 +872,16 @@ struct unwrap_check_fixed inline unwrap_check_fixed(const T1& A, const Mat& B) - : M_local( (&A == &B) ? new T1(A) : 0 ) - , M ( (&A == &B) ? *M_local : A ) + : M_local( (&A == &B) ? new T1(A) : nullptr ) + , M ( (&A == &B) ? *M_local : A ) { arma_extra_debug_sigprint(); } inline unwrap_check_fixed(const T1& A, const bool is_alias) - : M_local( is_alias ? new T1(A) : 0 ) - , M ( is_alias ? *M_local : A ) + : M_local( is_alias ? new T1(A) : nullptr ) + , M ( is_alias ? *M_local : A ) { arma_extra_debug_sigprint(); } @@ -853,15 +913,15 @@ struct unwrap_check_redirect { typedef unwrap_check_fixed resul template -struct unwrap_check : public unwrap_check_redirect::value >::result +struct unwrap_check : public unwrap_check_redirect::value>::result { inline unwrap_check(const T1& A, const Mat& B) - : unwrap_check_redirect< T1, is_Mat_fixed::value >::result(A, B) + : unwrap_check_redirect::value>::result(A, B) { } inline unwrap_check(const T1& A, const bool is_alias) - : unwrap_check_redirect< T1, is_Mat_fixed::value >::result(A, is_alias) + : unwrap_check_redirect::value>::result(A, is_alias) { } }; @@ -875,16 +935,16 @@ struct unwrap_check< Mat > inline unwrap_check(const Mat& A, const Mat& B) - : M_local( (&A == &B) ? new Mat(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local( (&A == &B) ? new Mat(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } inline unwrap_check(const Mat& A, const bool is_alias) - : M_local( is_alias ? new Mat(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Mat(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -912,16 +972,16 @@ struct unwrap_check< Row > inline unwrap_check(const Row& A, const Mat& B) - : M_local( (&A == &B) ? new Row(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local( (&A == &B) ? new Row(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } inline unwrap_check(const Row& A, const bool is_alias) - : M_local( is_alias ? new Row(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Row(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -949,16 +1009,16 @@ struct unwrap_check< Col > inline unwrap_check(const Col& A, const Mat& B) - : M_local( (&A == &B) ? new Col(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local( (&A == &B) ? new Col(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } inline unwrap_check(const Col& A, const bool is_alias) - : M_local( is_alias ? new Col(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Col(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1017,8 +1077,8 @@ struct unwrap_check_mixed< Mat > template inline unwrap_check_mixed(const Mat& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Mat(A) : 0 ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Mat(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1026,8 +1086,8 @@ struct unwrap_check_mixed< Mat > //template inline unwrap_check_mixed(const Mat& A, const bool is_alias) - : M_local( is_alias ? new Mat(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Mat(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1054,8 +1114,8 @@ struct unwrap_check_mixed< Row > template inline unwrap_check_mixed(const Row& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Row(A) : 0 ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Row(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1064,8 +1124,8 @@ struct unwrap_check_mixed< Row > //template inline unwrap_check_mixed(const Row& A, const bool is_alias) - : M_local( is_alias ? new Row(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Row(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1092,8 +1152,8 @@ struct unwrap_check_mixed< Col > template inline unwrap_check_mixed(const Col& A, const Mat& B) - : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Col(A) : 0 ) - , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) + : M_local( (void_ptr(&A) == void_ptr(&B)) ? new Col(A) : nullptr ) + , M ( (void_ptr(&A) == void_ptr(&B)) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1101,8 +1161,8 @@ struct unwrap_check_mixed< Col > //template inline unwrap_check_mixed(const Col& A, const bool is_alias) - : M_local( is_alias ? new Col(A) : 0 ) - , M ( is_alias ? (*M_local) : A ) + : M_local( is_alias ? new Col(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -1142,13 +1202,13 @@ struct partial_unwrap_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Mat M; }; @@ -1167,13 +1227,13 @@ struct partial_unwrap_fixed arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const T1& M; }; @@ -1190,11 +1250,11 @@ template struct partial_unwrap_redirect { typedef partial_unwrap_fixed result; }; template -struct partial_unwrap : public partial_unwrap_redirect::value >::result +struct partial_unwrap : public partial_unwrap_redirect::value>::result { inline partial_unwrap(const T1& A) - : partial_unwrap_redirect< T1, is_Mat_fixed::value >::result(A) + : partial_unwrap_redirect< T1, is_Mat_fixed::value>::result(A) { } }; @@ -1213,13 +1273,13 @@ struct partial_unwrap< Mat > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Mat& M; }; @@ -1238,13 +1298,13 @@ struct partial_unwrap< Row > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Row& M; }; @@ -1263,13 +1323,13 @@ struct partial_unwrap< Col > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Col& M; }; @@ -1289,13 +1349,13 @@ struct partial_unwrap< subview > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const subview& sv; const Mat M; @@ -1311,18 +1371,18 @@ struct partial_unwrap< subview_col > inline partial_unwrap(const subview_col& A) : orig( A.m ) - , M ( const_cast( A.colptr(0) ), A.n_rows, false, false ) + , M ( const_cast( A.colmem ), A.n_rows, false, false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Mat& orig; const Col M; @@ -1330,6 +1390,33 @@ struct partial_unwrap< subview_col > +template +struct partial_unwrap< subview_cols > + { + typedef Mat stored_type; + + inline + partial_unwrap(const subview_cols& A) + : orig( A.m ) + , M ( const_cast( A.colptr(0) ), A.n_rows, A.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } + + static constexpr bool do_trans = false; + static constexpr bool do_times = false; + + const Mat& orig; + const Mat M; + }; + + + template struct partial_unwrap< subview_row > { @@ -1342,13 +1429,13 @@ struct partial_unwrap< subview_row > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Row M; }; @@ -1368,13 +1455,13 @@ struct partial_unwrap_htrans_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Mat M; }; @@ -1393,13 +1480,13 @@ struct partial_unwrap_htrans_fixed arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const T1& M; }; @@ -1416,10 +1503,10 @@ template struct partial_unwrap_htrans_redirect { typedef partial_unwrap_htrans_fixed result; }; template -struct partial_unwrap< Op > : public partial_unwrap_htrans_redirect::value >::result +struct partial_unwrap< Op > : public partial_unwrap_htrans_redirect::value>::result { inline partial_unwrap(const Op& A) - : partial_unwrap_htrans_redirect< T1, is_Mat_fixed::value >::result(A) + : partial_unwrap_htrans_redirect::value>::result(A) { } }; @@ -1438,13 +1525,13 @@ struct partial_unwrap< Op< Mat, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Mat& M; }; @@ -1463,13 +1550,13 @@ struct partial_unwrap< Op< Row, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Row& M; }; @@ -1488,19 +1575,73 @@ struct partial_unwrap< Op< Col, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Col& M; }; +template +struct partial_unwrap< Op< subview, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview, op_htrans>& A) + : sv( A.m ) + , M ( A.m, ((A.m.aux_row1 == 0) && (A.m.n_rows == A.m.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const subview& sv; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_cols, op_htrans> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview_cols, op_htrans>& A) + : orig( A.m.m ) + , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, A.m.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + constexpr eT get_val() const { return eT(1); } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = false; + + const Mat& orig; + const Mat M; + }; + + + template struct partial_unwrap< Op< subview_col, op_htrans> > { @@ -1509,18 +1650,18 @@ struct partial_unwrap< Op< subview_col, op_htrans> > inline partial_unwrap(const Op< subview_col, op_htrans>& A) : orig( A.m.m ) - , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, false, false ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Mat& orig; const Col M; @@ -1540,13 +1681,13 @@ struct partial_unwrap< Op< subview_row, op_htrans> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Row M; }; @@ -1570,10 +1711,10 @@ struct partial_unwrap_htrans2_default arma_inline eT get_val() const { return val; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Mat M; @@ -1599,8 +1740,8 @@ struct partial_unwrap_htrans2_fixed template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const T1& M; @@ -1618,10 +1759,10 @@ template struct partial_unwrap_htrans2_redirect { typedef partial_unwrap_htrans2_fixed result; }; template -struct partial_unwrap< Op > : public partial_unwrap_htrans2_redirect::value >::result +struct partial_unwrap< Op > : public partial_unwrap_htrans2_redirect::value>::result { inline partial_unwrap(const Op& A) - : partial_unwrap_htrans2_redirect< T1, is_Mat_fixed::value >::result(A) + : partial_unwrap_htrans2_redirect::value>::result(A) { } }; @@ -1646,8 +1787,8 @@ struct partial_unwrap< Op< Mat, op_htrans2> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Mat& M; @@ -1673,8 +1814,8 @@ struct partial_unwrap< Op< Row, op_htrans2> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Row& M; @@ -1700,8 +1841,8 @@ struct partial_unwrap< Op< Col, op_htrans2> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Col& M; @@ -1709,6 +1850,64 @@ struct partial_unwrap< Op< Col, op_htrans2> > +template +struct partial_unwrap< Op< subview, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview, op_htrans2>& A) + : sv ( A.m ) + , val( A.aux ) + , M ( A.m, ((A.m.aux_row1 == 0) && (A.m.n_rows == A.m.m.n_rows)) ) // reuse memory if the subview is a contiguous chunk + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return ( ((sv.aux_row1 == 0) && (sv.n_rows == sv.m.n_rows)) ? (void_ptr(&(sv.m)) == void_ptr(&X)) : false ); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const subview& sv; + const eT val; + const Mat M; + }; + + + +template +struct partial_unwrap< Op< subview_cols, op_htrans2> > + { + typedef Mat stored_type; + + inline + partial_unwrap(const Op< subview_cols, op_htrans2>& A) + : orig( A.m.m ) + , val ( A.aux ) + , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, A.m.n_cols, false, false ) + { + arma_extra_debug_sigprint(); + } + + inline eT get_val() const { return val; } + + template + arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&orig) == void_ptr(&X)); } + + static constexpr bool do_trans = true; + static constexpr bool do_times = true; + + const Mat& orig; + const eT val; + const Mat M; + }; + + + template struct partial_unwrap< Op< subview_col, op_htrans2> > { @@ -1718,7 +1917,7 @@ struct partial_unwrap< Op< subview_col, op_htrans2> > partial_unwrap(const Op< subview_col, op_htrans2>& A) : orig( A.m.m ) , val ( A.aux ) - , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, false, false ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, false, false ) { arma_extra_debug_sigprint(); } @@ -1728,8 +1927,8 @@ struct partial_unwrap< Op< subview_col, op_htrans2> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const Mat& orig; @@ -1755,10 +1954,10 @@ struct partial_unwrap< Op< subview_row, op_htrans2> > arma_inline eT get_val() const { return val; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Row M; @@ -1783,10 +1982,10 @@ struct partial_unwrap_scalar_times_default arma_inline eT get_val() const { return val; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Mat M; @@ -1813,8 +2012,8 @@ struct partial_unwrap_scalar_times_fixed template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const T1& M; @@ -1833,13 +2032,13 @@ struct partial_unwrap_scalar_times_redirect { typedef partial_unwrap_ template -struct partial_unwrap< eOp > : public partial_unwrap_scalar_times_redirect::value >::result +struct partial_unwrap< eOp > : public partial_unwrap_scalar_times_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap(const eOp& A) - : partial_unwrap_scalar_times_redirect< T1, is_Mat_fixed::value >::result(A) + : partial_unwrap_scalar_times_redirect< T1, is_Mat_fixed::value>::result(A) { } }; @@ -1864,8 +2063,8 @@ struct partial_unwrap< eOp, eop_scalar_times> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Mat& M; @@ -1891,8 +2090,8 @@ struct partial_unwrap< eOp, eop_scalar_times> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Row& M; @@ -1918,8 +2117,8 @@ struct partial_unwrap< eOp, eop_scalar_times> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Col& M; @@ -1936,7 +2135,7 @@ struct partial_unwrap< eOp, eop_scalar_times> > partial_unwrap(const eOp,eop_scalar_times>& A) : orig( A.P.Q.m ) , val ( A.aux ) - , M ( const_cast( A.P.Q.colptr(0) ), A.P.Q.n_rows, false, false ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, false, false ) { arma_extra_debug_sigprint(); } @@ -1946,8 +2145,8 @@ struct partial_unwrap< eOp, eop_scalar_times> > template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat& orig; @@ -1973,10 +2172,10 @@ struct partial_unwrap< eOp, eop_scalar_times> > arma_inline eT get_val() const { return val; } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Row M; @@ -1997,13 +2196,13 @@ struct partial_unwrap_neg_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat M; }; @@ -2023,13 +2222,13 @@ struct partial_unwrap_neg_fixed arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const T1& M; }; @@ -2047,13 +2246,13 @@ struct partial_unwrap_neg_redirect { typedef partial_unwrap_neg_fixed template -struct partial_unwrap< eOp > : public partial_unwrap_neg_redirect::value >::result +struct partial_unwrap< eOp > : public partial_unwrap_neg_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap(const eOp& A) - : partial_unwrap_neg_redirect< T1, is_Mat_fixed::value >::result(A) + : partial_unwrap_neg_redirect< T1, is_Mat_fixed::value>::result(A) { } }; @@ -2072,13 +2271,13 @@ struct partial_unwrap< eOp, eop_neg> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat& M; }; @@ -2097,13 +2296,13 @@ struct partial_unwrap< eOp, eop_neg> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Row& M; }; @@ -2122,13 +2321,13 @@ struct partial_unwrap< eOp, eop_neg> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&M)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Col& M; }; @@ -2143,18 +2342,18 @@ struct partial_unwrap< eOp, eop_neg> > inline partial_unwrap(const eOp,eop_neg>& A) : orig( A.P.Q.m ) - , M ( const_cast( A.P.Q.colptr(0) ), A.P.Q.n_rows, false, false ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, false, false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template arma_inline bool is_alias(const Mat& X) const { return (void_ptr(&X) == void_ptr(&orig)); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat& orig; const Col M; @@ -2174,13 +2373,13 @@ struct partial_unwrap< eOp, eop_neg> > arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } template - arma_inline bool is_alias(const Mat&) const { return false; } + constexpr bool is_alias(const Mat&) const { return false; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Row M; }; @@ -2204,10 +2403,10 @@ struct partial_unwrap_check_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Mat M; }; @@ -2221,8 +2420,8 @@ struct partial_unwrap_check_fixed inline explicit partial_unwrap_check_fixed(const T1& A, const Mat& B) - : M_local( (&A == &B) ? new T1(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local( (&A == &B) ? new T1(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -2235,10 +2434,10 @@ struct partial_unwrap_check_fixed if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const T1* M_local; const T1& M; @@ -2256,12 +2455,12 @@ template struct partial_unwrap_check_redirect { typedef partial_unwrap_check_fixed result; }; template -struct partial_unwrap_check : public partial_unwrap_check_redirect::value >::result +struct partial_unwrap_check : public partial_unwrap_check_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap_check(const T1& A, const Mat& B) - : partial_unwrap_check_redirect< T1, is_Mat_fixed::value >::result(A, B) + : partial_unwrap_check_redirect::value>::result(A, B) { } }; @@ -2275,8 +2474,8 @@ struct partial_unwrap_check< Mat > inline partial_unwrap_check(const Mat& A, const Mat& B) - : M_local ( (&A == &B) ? new Mat(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local ( (&A == &B) ? new Mat(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -2290,10 +2489,10 @@ struct partial_unwrap_check< Mat > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; // the order below is important const Mat* M_local; @@ -2309,8 +2508,8 @@ struct partial_unwrap_check< Row > inline partial_unwrap_check(const Row& A, const Mat& B) - : M_local ( (&A == &B) ? new Row(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local ( (&A == &B) ? new Row(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -2324,10 +2523,10 @@ struct partial_unwrap_check< Row > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; // the order below is important const Row* M_local; @@ -2343,8 +2542,8 @@ struct partial_unwrap_check< Col > inline partial_unwrap_check(const Col& A, const Mat& B) - : M_local ( (&A == &B) ? new Col(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local ( (&A == &B) ? new Col(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -2358,10 +2557,10 @@ struct partial_unwrap_check< Col > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; // the order below is important const Col* M_local; @@ -2379,15 +2578,15 @@ struct partial_unwrap_check< subview_col > inline partial_unwrap_check(const subview_col& A, const Mat& B) - : M ( const_cast( A.colptr(0) ), A.n_rows, (&(A.m) == &B), false ) + : M ( const_cast( A.colmem ), A.n_rows, (&(A.m) == &B), false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = false; - static const bool do_times = false; + static constexpr bool do_trans = false; + static constexpr bool do_times = false; const Col M; }; @@ -2407,10 +2606,10 @@ struct partial_unwrap_check_htrans_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Mat M; }; @@ -2424,8 +2623,8 @@ struct partial_unwrap_check_htrans_fixed inline explicit partial_unwrap_check_htrans_fixed(const Op& A, const Mat& B) - : M_local( (&(A.m) == &B) ? new T1(A.m) : 0 ) - , M ( (&(A.m) == &B) ? (*M_local) : A.m ) + : M_local( (&(A.m) == &B) ? new T1(A.m) : nullptr ) + , M ( (&(A.m) == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2438,10 +2637,10 @@ struct partial_unwrap_check_htrans_fixed if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const T1* M_local; const T1& M; @@ -2460,12 +2659,12 @@ struct partial_unwrap_check_htrans_redirect { typedef partial_unwrap_ template -struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans_redirect::value >::result +struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap_check(const Op& A, const Mat& B) - : partial_unwrap_check_htrans_redirect< T1, is_Mat_fixed::value >::result(A, B) + : partial_unwrap_check_htrans_redirect::value>::result(A, B) { } }; @@ -2479,8 +2678,8 @@ struct partial_unwrap_check< Op< Mat, op_htrans> > inline partial_unwrap_check(const Op< Mat, op_htrans>& A, const Mat& B) - : M_local ( (&A.m == &B) ? new Mat(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + : M_local ( (&A.m == &B) ? new Mat(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2493,10 +2692,10 @@ struct partial_unwrap_check< Op< Mat, op_htrans> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; // the order below is important const Mat* M_local; @@ -2512,8 +2711,8 @@ struct partial_unwrap_check< Op< Row, op_htrans> > inline partial_unwrap_check(const Op< Row, op_htrans>& A, const Mat& B) - : M_local ( (&A.m == &B) ? new Row(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + : M_local ( (&A.m == &B) ? new Row(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2526,10 +2725,10 @@ struct partial_unwrap_check< Op< Row, op_htrans> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; // the order below is important const Row* M_local; @@ -2545,8 +2744,8 @@ struct partial_unwrap_check< Op< Col, op_htrans> > inline partial_unwrap_check(const Op< Col, op_htrans>& A, const Mat& B) - : M_local ( (&A.m == &B) ? new Col(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + : M_local ( (&A.m == &B) ? new Col(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2559,10 +2758,10 @@ struct partial_unwrap_check< Op< Col, op_htrans> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; // the order below is important const Col* M_local; @@ -2580,15 +2779,15 @@ struct partial_unwrap_check< Op< subview_col, op_htrans> > inline partial_unwrap_check(const Op< subview_col, op_htrans>& A, const Mat& B) - : M ( const_cast( A.m.colptr(0) ), A.m.n_rows, (&(A.m.m) == &B), false ) + : M ( const_cast( A.m.colmem ), A.m.n_rows, (&(A.m.m) == &B), false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(1); } + constexpr eT get_val() const { return eT(1); } - static const bool do_trans = true; - static const bool do_times = false; + static constexpr bool do_trans = true; + static constexpr bool do_times = false; const Col M; }; @@ -2611,8 +2810,8 @@ struct partial_unwrap_check_htrans2_default arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Mat M; @@ -2629,8 +2828,8 @@ struct partial_unwrap_check_htrans2_fixed inline explicit partial_unwrap_check_htrans2_fixed(const Op& A, const Mat& B) : val (A.aux) - , M_local( (&(A.m) == &B) ? new T1(A.m) : 0 ) - , M ( (&(A.m) == &B) ? (*M_local) : A.m ) + , M_local( (&(A.m) == &B) ? new T1(A.m) : nullptr ) + , M ( (&(A.m) == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2645,8 +2844,8 @@ struct partial_unwrap_check_htrans2_fixed arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const T1* M_local; @@ -2666,12 +2865,12 @@ struct partial_unwrap_check_htrans2_redirect { typedef partial_unwrap template -struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans2_redirect::value >::result +struct partial_unwrap_check< Op > : public partial_unwrap_check_htrans2_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap_check(const Op& A, const Mat& B) - : partial_unwrap_check_htrans2_redirect< T1, is_Mat_fixed::value >::result(A, B) + : partial_unwrap_check_htrans2_redirect::value>::result(A, B) { } }; @@ -2686,8 +2885,8 @@ struct partial_unwrap_check< Op< Mat, op_htrans2> > inline partial_unwrap_check(const Op< Mat, op_htrans2>& A, const Mat& B) : val (A.aux) - , M_local ( (&A.m == &B) ? new Mat(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + , M_local ( (&A.m == &B) ? new Mat(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2702,8 +2901,8 @@ struct partial_unwrap_check< Op< Mat, op_htrans2> > arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; // the order below is important const eT val; @@ -2721,8 +2920,8 @@ struct partial_unwrap_check< Op< Row, op_htrans2> > inline partial_unwrap_check(const Op< Row, op_htrans2>& A, const Mat& B) : val (A.aux) - , M_local ( (&A.m == &B) ? new Row(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + , M_local ( (&A.m == &B) ? new Row(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2737,8 +2936,8 @@ struct partial_unwrap_check< Op< Row, op_htrans2> > arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; // the order below is important const eT val; @@ -2756,8 +2955,8 @@ struct partial_unwrap_check< Op< Col, op_htrans2> > inline partial_unwrap_check(const Op< Col, op_htrans2>& A, const Mat& B) : val (A.aux) - , M_local ( (&A.m == &B) ? new Col(A.m) : 0 ) - , M ( (&A.m == &B) ? (*M_local) : A.m ) + , M_local ( (&A.m == &B) ? new Col(A.m) : nullptr ) + , M ( (&A.m == &B) ? (*M_local) : A.m ) { arma_extra_debug_sigprint(); } @@ -2772,8 +2971,8 @@ struct partial_unwrap_check< Op< Col, op_htrans2> > arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; // the order below is important const eT val; @@ -2793,15 +2992,15 @@ struct partial_unwrap_check< Op< subview_col, op_htrans2> > inline partial_unwrap_check(const Op< subview_col, op_htrans2>& A, const Mat& B) : val( A.aux ) - , M ( const_cast( A.m.colptr(0) ), A.m.n_rows, (&(A.m.m) == &B), false ) + , M ( const_cast( A.m.colmem ), A.m.n_rows, (&(A.m.m) == &B), false ) { arma_extra_debug_sigprint(); } arma_inline eT get_val() const { return val; } - static const bool do_trans = true; - static const bool do_times = true; + static constexpr bool do_trans = true; + static constexpr bool do_times = true; const eT val; const Col M; @@ -2825,8 +3024,8 @@ struct partial_unwrap_check_scalar_times_default arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Mat M; @@ -2843,8 +3042,8 @@ struct partial_unwrap_check_scalar_times_fixed inline explicit partial_unwrap_check_scalar_times_fixed(const eOp& A, const Mat& B) : val ( A.aux ) - , M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) + , M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -2859,8 +3058,8 @@ struct partial_unwrap_check_scalar_times_fixed arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const T1* M_local; @@ -2880,12 +3079,12 @@ struct partial_unwrap_check_scalar_times_redirect { typedef partial_u template -struct partial_unwrap_check< eOp > : public partial_unwrap_check_scalar_times_redirect::value >::result +struct partial_unwrap_check< eOp > : public partial_unwrap_check_scalar_times_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap_check(const eOp& A, const Mat& B) - : partial_unwrap_check_scalar_times_redirect< T1, is_Mat_fixed::value >::result(A, B) + : partial_unwrap_check_scalar_times_redirect::value>::result(A, B) { } }; @@ -2900,8 +3099,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > inline partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) : val (A.aux) - , M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + , M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -2916,8 +3115,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Mat* M_local; @@ -2934,8 +3133,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > inline partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) : val(A.aux) - , M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + , M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -2950,8 +3149,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Row* M_local; @@ -2968,8 +3167,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > inline partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) : val ( A.aux ) - , M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + , M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -2984,8 +3183,8 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Col* M_local; @@ -3004,15 +3203,15 @@ struct partial_unwrap_check< eOp, eop_scalar_times> > inline partial_unwrap_check(const eOp,eop_scalar_times>& A, const Mat& B) : val( A.aux ) - , M ( const_cast( A.P.Q.colptr(0) ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) + , M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) { arma_extra_debug_sigprint(); } arma_inline eT get_val() const { return val; } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const eT val; const Col M; @@ -3033,10 +3232,10 @@ struct partial_unwrap_check_neg_default arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat M; }; @@ -3051,8 +3250,8 @@ struct partial_unwrap_check_neg_fixed inline explicit partial_unwrap_check_neg_fixed(const eOp& A, const Mat& B) - : M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) + : M_local( (&(A.P.Q) == &B) ? new T1(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? (*M_local) : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -3065,10 +3264,10 @@ struct partial_unwrap_check_neg_fixed if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const T1* M_local; const T1& M; @@ -3087,12 +3286,12 @@ struct partial_unwrap_check_neg_redirect { typedef partial_unwrap_che template -struct partial_unwrap_check< eOp > : public partial_unwrap_check_neg_redirect::value >::result +struct partial_unwrap_check< eOp > : public partial_unwrap_check_neg_redirect::value>::result { typedef typename T1::elem_type eT; inline partial_unwrap_check(const eOp& A, const Mat& B) - : partial_unwrap_check_neg_redirect< T1, is_Mat_fixed::value >::result(A, B) + : partial_unwrap_check_neg_redirect::value>::result(A, B) { } }; @@ -3106,8 +3305,8 @@ struct partial_unwrap_check< eOp, eop_neg> > inline partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) - : M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + : M_local( (&(A.P.Q) == &B) ? new Mat(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -3120,10 +3319,10 @@ struct partial_unwrap_check< eOp, eop_neg> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Mat* M_local; const Mat& M; @@ -3138,8 +3337,8 @@ struct partial_unwrap_check< eOp, eop_neg> > inline partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) - : M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + : M_local( (&(A.P.Q) == &B) ? new Row(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -3152,10 +3351,10 @@ struct partial_unwrap_check< eOp, eop_neg> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Row* M_local; const Row& M; @@ -3170,8 +3369,8 @@ struct partial_unwrap_check< eOp, eop_neg> > inline partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) - : M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : 0 ) - , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) + : M_local( (&(A.P.Q) == &B) ? new Col(A.P.Q) : nullptr ) + , M ( (&(A.P.Q) == &B) ? *M_local : A.P.Q ) { arma_extra_debug_sigprint(); } @@ -3184,10 +3383,10 @@ struct partial_unwrap_check< eOp, eop_neg> > if(M_local) { delete M_local; } } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Col* M_local; const Col& M; @@ -3204,15 +3403,15 @@ struct partial_unwrap_check< eOp, eop_neg> > inline partial_unwrap_check(const eOp,eop_neg>& A, const Mat& B) - : M ( const_cast( A.P.Q.colptr(0) ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) + : M ( const_cast( A.P.Q.colmem ), A.P.Q.n_rows, (&(A.P.Q.m) == &B), false ) { arma_extra_debug_sigprint(); } - arma_inline eT get_val() const { return eT(-1); } + constexpr eT get_val() const { return eT(-1); } - static const bool do_trans = false; - static const bool do_times = true; + static constexpr bool do_trans = false; + static constexpr bool do_times = true; const Col M; }; diff --git a/src/armadillo_bits/unwrap_cube.hpp b/src/armadillo_bits/unwrap_cube.hpp index d5290a23..ca91cfad 100644 --- a/src/armadillo_bits/unwrap_cube.hpp +++ b/src/armadillo_bits/unwrap_cube.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -34,7 +36,7 @@ struct unwrap_cube const Cube M; template - arma_inline bool is_alias(const Cube&) const { return false; } + constexpr bool is_alias(const Cube&) const { return false; } }; @@ -77,6 +79,15 @@ struct unwrap_cube_check arma_type_check(( is_arma_cube_type::value == false )); } + inline + unwrap_cube_check(const T1& A, const bool) + : M(A) + { + arma_extra_debug_sigprint(); + + arma_type_check(( is_arma_cube_type::value == false )); + } + const Cube M; }; @@ -87,8 +98,17 @@ struct unwrap_cube_check< Cube > { inline unwrap_cube_check(const Cube& A, const Cube& B) - : M_local( (&A == &B) ? new Cube(A) : 0 ) - , M ( (&A == &B) ? (*M_local) : A ) + : M_local( (&A == &B) ? new Cube(A) : nullptr ) + , M ( (&A == &B) ? (*M_local) : A ) + { + arma_extra_debug_sigprint(); + } + + + inline + unwrap_cube_check(const Cube& A, const bool is_alias) + : M_local( is_alias ? new Cube(A) : nullptr ) + , M ( is_alias ? (*M_local) : A ) { arma_extra_debug_sigprint(); } @@ -99,10 +119,7 @@ struct unwrap_cube_check< Cube > { arma_extra_debug_sigprint(); - if(M_local) - { - delete M_local; - } + if(M_local) { delete M_local; } } diff --git a/src/armadillo_bits/unwrap_spmat.hpp b/src/armadillo_bits/unwrap_spmat.hpp index a985dbbf..0597aaa2 100644 --- a/src/armadillo_bits/unwrap_spmat.hpp +++ b/src/armadillo_bits/unwrap_spmat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -36,7 +38,7 @@ struct unwrap_spmat const SpMat M; template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; @@ -124,7 +126,7 @@ struct unwrap_spmat< SpOp > const SpMat M; template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; @@ -146,7 +148,7 @@ struct unwrap_spmat< SpGlue > const SpMat M; template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; @@ -166,7 +168,7 @@ struct unwrap_spmat< mtSpOp > const SpMat M; template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; @@ -186,7 +188,7 @@ struct unwrap_spmat< mtSpGlue > const SpMat M; template - arma_inline bool is_alias(const SpMat&) const { return false; } + constexpr bool is_alias(const SpMat&) const { return false; } }; diff --git a/src/armadillo_bits/upgrade_val.hpp b/src/armadillo_bits/upgrade_val.hpp index 115d829d..a5e9da2b 100644 --- a/src/armadillo_bits/upgrade_val.hpp +++ b/src/armadillo_bits/upgrade_val.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -61,7 +63,7 @@ struct upgrade_val //! upgrade a type to allow multiplication with a complex type -//! e.g. the int in "int * complex" is upgraded to a double +//! eg. the int in "int * complex" is upgraded to a double // template<> template struct upgrade_val< std::complex, T2 > diff --git a/src/armadillo_bits/wall_clock_bones.hpp b/src/armadillo_bits/wall_clock_bones.hpp index 964cc9bb..29c30142 100644 --- a/src/armadillo_bits/wall_clock_bones.hpp +++ b/src/armadillo_bits/wall_clock_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -26,22 +28,15 @@ class wall_clock inline wall_clock(); inline ~wall_clock(); - inline void tic(); //!< start the timer - inline double toc(); //!< return the number of seconds since the last call to tic() + inline void tic(); //!< start the timer + arma_warn_unused inline double toc(); //!< return the number of seconds since the last call to tic() private: - bool valid; + bool valid = false; - #if defined(ARMA_USE_CXX11) - std::chrono::steady_clock::time_point chrono_time1; - #elif defined(ARMA_HAVE_GETTIMEOFDAY) - struct timeval posix_time1; - struct timeval posix_time2; - #else - std::clock_t time1; - #endif + std::chrono::steady_clock::time_point chrono_time1; }; diff --git a/src/armadillo_bits/wall_clock_meat.hpp b/src/armadillo_bits/wall_clock_meat.hpp index 50604159..54ed68a7 100644 --- a/src/armadillo_bits/wall_clock_meat.hpp +++ b/src/armadillo_bits/wall_clock_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -20,7 +22,6 @@ inline wall_clock::wall_clock() - : valid(false) { arma_extra_debug_sigprint(); } @@ -41,22 +42,8 @@ wall_clock::tic() { arma_extra_debug_sigprint(); - #if defined(ARMA_USE_CXX11) - { - chrono_time1 = std::chrono::steady_clock::now(); - valid = true; - } - #elif defined(ARMA_HAVE_GETTIMEOFDAY) - { - gettimeofday(&posix_time1, 0); - valid = true; - } - #else - { - time1 = std::clock(); - valid = true; - } - #endif + chrono_time1 = std::chrono::steady_clock::now(); + valid = true; } @@ -69,42 +56,17 @@ wall_clock::toc() if(valid) { - #if defined(ARMA_USE_CXX11) - { - const std::chrono::steady_clock::time_point chrono_time2 = std::chrono::steady_clock::now(); - - typedef std::chrono::duration duration_type; - - const duration_type chrono_span = std::chrono::duration_cast< duration_type >(chrono_time2 - chrono_time1); - - return chrono_span.count(); - } - #elif defined(ARMA_HAVE_GETTIMEOFDAY) - { - gettimeofday(&posix_time2, 0); - - const double tmp_time1 = double(posix_time1.tv_sec) + double(posix_time1.tv_usec) * 1.0e-6; - const double tmp_time2 = double(posix_time2.tv_sec) + double(posix_time2.tv_usec) * 1.0e-6; - - return tmp_time2 - tmp_time1; - } - #else - { - std::clock_t time2 = std::clock(); - - std::clock_t diff = time2 - time1; - - return double(diff) / double(CLOCKS_PER_SEC); - } - #endif - } - else - { - return 0.0; + const std::chrono::steady_clock::time_point chrono_time2 = std::chrono::steady_clock::now(); + + typedef std::chrono::duration duration_type; // TODO: check this + + const duration_type chrono_span = std::chrono::duration_cast< duration_type >(chrono_time2 - chrono_time1); + + return chrono_span.count(); } + + return 0.0; } - //! @} - diff --git a/src/armadillo_bits/xtrans_mat_bones.hpp b/src/armadillo_bits/xtrans_mat_bones.hpp index 269f3436..58756662 100644 --- a/src/armadillo_bits/xtrans_mat_bones.hpp +++ b/src/armadillo_bits/xtrans_mat_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,18 +21,18 @@ template -class xtrans_mat : public Base > +class xtrans_mat : public Base< eT, xtrans_mat > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = false; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = false; - static const bool really_do_conj = (do_conj && is_cx::yes); + static constexpr bool really_do_conj = (do_conj && is_cx::yes); arma_aligned const Mat& X; arma_aligned mutable Mat Y; diff --git a/src/armadillo_bits/xtrans_mat_meat.hpp b/src/armadillo_bits/xtrans_mat_meat.hpp index 29f094f7..1872c303 100644 --- a/src/armadillo_bits/xtrans_mat_meat.hpp +++ b/src/armadillo_bits/xtrans_mat_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/armadillo_bits/xvec_htrans_bones.hpp b/src/armadillo_bits/xvec_htrans_bones.hpp index 93520a50..6eab7101 100644 --- a/src/armadillo_bits/xvec_htrans_bones.hpp +++ b/src/armadillo_bits/xvec_htrans_bones.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // @@ -19,16 +21,16 @@ template -class xvec_htrans : public Base > +class xvec_htrans : public Base< eT, xvec_htrans > { public: typedef eT elem_type; typedef typename get_pod_type::result pod_type; - static const bool is_row = false; - static const bool is_col = false; - static const bool is_xvec = true; + static constexpr bool is_row = false; + static constexpr bool is_col = false; + static constexpr bool is_xvec = true; arma_aligned const eT* const mem; diff --git a/src/armadillo_bits/xvec_htrans_meat.hpp b/src/armadillo_bits/xvec_htrans_meat.hpp index 6f49efc8..b79a7ef1 100644 --- a/src/armadillo_bits/xvec_htrans_meat.hpp +++ b/src/armadillo_bits/xvec_htrans_meat.hpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: Apache-2.0 +// // Copyright 2008-2016 Conrad Sanderson (http://conradsanderson.id.au) // Copyright 2008-2016 National ICT Australia (NICTA) // diff --git a/src/crbfgs.pxd b/src/crbfgs.pxd new file mode 100644 index 00000000..b910daa2 --- /dev/null +++ b/src/crbfgs.pxd @@ -0,0 +1,7 @@ +# cython: language_level=2 +from libcpp cimport bool +cimport cyarma +from cyarma cimport vec + +cdef extern from "rbfgs.h": + vec rlbfgs_optim(vec q1, vec q2, vec time, int maxiter, double lam, int penalty); diff --git a/src/crbfgs.pyx b/src/crbfgs.pyx new file mode 100644 index 00000000..acf92eef --- /dev/null +++ b/src/crbfgs.pyx @@ -0,0 +1,23 @@ +# cython: language_level=2 +import numpy as np +cimport crbfgs +cimport cyarma +cimport numpy as np + +include "cyarma.pyx" + +from libcpp cimport bool + +def rlbfgs(np.ndarray[double, ndim=1, mode="c"] q1, np.ndarray[double, ndim=1, mode="c"] q2, + np.ndarray[double, ndim=1, mode="c"] time, maxiter=30, lam=0.0, penalty=0): + q1 = np.ascontiguousarray(q1) + q2 = np.ascontiguousarray(q2) + time = np.ascontiguousarray(time) + cdef vec aq1 = numpy_to_vec_d(q1) + cdef vec aq2 = numpy_to_vec_d(q2) + cdef vec atime = numpy_to_vec_d(time) + cdef vec out = rlbfgs_optim(aq1, aq2, atime, maxiter, lam, penalty) + + cdef np.ndarray[np.double_t,ndim=1] gam = vec_to_numpy(out, None) + + return(gam) diff --git a/src/rbfgs.cpp b/src/rbfgs.cpp new file mode 100644 index 00000000..f1c3fae5 --- /dev/null +++ b/src/rbfgs.cpp @@ -0,0 +1,565 @@ +#include +#include "armadillo" +//#include "rbfgs.h" + +using namespace arma; +using namespace std; + +class rlbfgs { + private: + vec time; // time + vec q1; // srvf1 + vec q2; // srvf2 + uword T; // size of time + double hCurCost; + vec hCurGradient; + struct options{ + double tolgradnorm; + double maxtime; + int memory; + int ls_max_steps; + int maxiter; + double minstepsize; + }; + + struct stats{ + int iter; + double cost; + double gradnorm; + double stepsize; + bool accepted; + }; + + struct lstats{ + double stepsize; + vec newh; + }; + + public: + vec gammaOpt; + vec q2Opt; + double cost; + // constructor + rlbfgs(vec q1i, vec q2i, vec timei) { + q1 = normalise( q1i, 2 ); + q2 = normalise( q2i, 2 ); + time = timei; + + T = timei.n_elem; + } + + void solve(int maxiter=30, double lam=0.0, int penalty=0){ + // run solver + options option; + // terminates if the norm of the gradient drops below this + option.tolgradnorm = 1e-3; + // terminates if more than seconds elapsed + option.maxtime = datum::inf; + // number of previous iterations the program remembers + option.memory = 30; + option.ls_max_steps = 25; + option.maxiter = maxiter; + // minimum norm of tangent vector that points from current to next + option.minstepsize = 1e-10; + + // Initialization of Variables + vec htilde = ones(T); + vec q2tilde = q2; + + // list to store step vectors + arma::field tmp(option.memory); + arma::field sHistory(option.memory); + arma::field yHistory(option.memory); + vec rhoHistory(option.memory); + vec tmp_vec(option.memory); + + // number of iterations since last restart + int j = 0; + // Total number of BFGS iterations + int k = 0; + // scaling of direction given by getDirection + double alpha = 1; + // scaling of initial matrix, Barzilai-Borwein + double scaleFactor = 1; + // Norm of the step + double stepsize = 1; + + bool accepted = true; + int stop; + + double hCurCost; + vec hCurGradient; + lstats lstat; + vec hNext; + alignment_costgrad(q2tilde, htilde, hCurCost, hCurGradient, lam, penalty); + double hCurGradNorm = norm2(hCurGradient); + + bool ultimatum = false; + + stats stat; + stat.iter = k; + stat.cost = hCurCost; + stat.gradnorm = hCurGradNorm; + stat.stepsize = datum::nan; + stat.accepted = false; + + vec p; + double in_prod; + vec step; + double hNextCost; + vec hNextGradient; + vec sk; + vec yk; + double norm_sk; + double inner_sk_yk; + double inner_sk_sk; + double cap, rhok; + while (true){ + stop = stoppingcriterion(option, stat); + + if (stop == 0){ + if (stat.stepsize < option.minstepsize){ + if (!ultimatum){ + j = 0; + ultimatum = true; + } else { + stop = 1; + } + } else { + ultimatum = false; + } + } + + if (stop > 0){ + break; + } + + // compute BFGS direction + p = getDirection(hCurGradient, sHistory, yHistory, rhoHistory, scaleFactor, min(j, option.memory)); + + // execute line search + in_prod = inner(hCurGradient, p); + lstat = linesearch_hint(p, hCurCost, in_prod, q2tilde, lam, penalty); + + stepsize = lstat.stepsize; + hNext = lstat.newh; + + // iterative update + htilde = group_action_SRVF(htilde, hNext); + q2tilde = group_action_SRVF(q2tilde, hNext); + + // record the BGFS step multiplier + alpha = stepsize / norm2(p); + step = alpha * p; + + // query cost and gradient at the candidate new point + alignment_costgrad(q2tilde, hNext, hNextCost, hNextGradient, lam, penalty); + + // compute sk and yk + sk = step; + yk = hNextGradient - hCurGradient; + + // computation of the BFGS step + norm_sk = norm2(sk); + sk = sk / norm_sk; + yk = yk / norm_sk; + + inner_sk_yk = inner(sk, yk); + inner_sk_sk = pow(norm2(sk),2); // ensures nonnegativity + + // cautious step + cap = strict_inc_func(hCurGradNorm); + if ((inner_sk_sk != 0) & ((inner_sk_yk/inner_sk_sk) >= cap)){ + accepted = true; + + rhok = 1/inner_sk_yk; + + scaleFactor = inner_sk_yk / pow(norm2(yk),2); + + if (j >= option.memory){ + if (option.memory > 1){ + tmp.subfield(0, 0, option.memory-2, 0) = sHistory.subfield(1, 0, option.memory-1, 0); + tmp(option.memory-1) = sHistory(0); + sHistory = tmp; + + tmp.subfield(0, 0, option.memory-2, 0) = yHistory.subfield(1, 0, option.memory-1, 0); + tmp(option.memory-1) = yHistory(0); + yHistory = tmp; + + tmp_vec(arma::span(0,option.memory-1)) = rhoHistory(arma::span(1,option.memory-2)); + tmp_vec(option.memory) = rhoHistory(0); + rhoHistory = tmp_vec; + } + if (option.memory > 0){ + sHistory(option.memory-1) = sk; + yHistory(option.memory-1) = yk; + rhoHistory(option.memory-1) = rhok; + } + } else { + sHistory(j) = sk; + yHistory(j) = yk; + rhoHistory(j) = rhok; + } + + j += 1; + } else { + accepted = false; + } + + hCurGradient = hNextGradient; + hCurGradNorm = norm2(hNextGradient); + hCurCost = hNextCost; + + k += 1; + + stat.iter = k; + stat.cost = hCurCost; + stat.gradnorm = hCurGradNorm; + stat.stepsize = datum::nan; + stat.accepted = accepted; + } + + gammaOpt = cumtrapz(time, pow(htilde,2)); + gammaOpt = (gammaOpt - gammaOpt.min()) / (gammaOpt.max() - gammaOpt.min()); + q2Opt = q2tilde; + cost = hCurCost; + } + + double alignment_cost(vec h, vec q2k, double lam = 0, int penalty = 0){ + vec q2new = group_action_SRVF(q2k, h); + + double pen = 0; + if (penalty == 0){ + vec time1 = arma::linspace(0,1,h.n_elem); + vec b = arma::diff(time1); + double binsize = mean(b); + vec g = gradient(arma::pow(h, 2), binsize); + arma::mat pen1 = arma::trapz(time1, arma::pow(g, 2)); + pen = pen1(0); + } + // l2gam + if (penalty == 1){ + vec tmp = arma::ones(T); + pen = normL2(arma::pow(h,2)-tmp); + pen = pow(pen, 2); + } + // l2psi + if (penalty == 2){ + vec tmp = arma::ones(T); + pen = normL2(h-tmp); + pen = pow(pen, 2); + } + // geodesic + if (penalty == 3){ + vec time1 = arma::linspace(0,1,h.n_elem); + arma::mat pen1 = arma::trapz(time1, h); + double q1dotq2 = pen1(0); + if (q1dotq2 > 1){ + q1dotq2 = 1; + } else if (q1dotq2 < -1) + { + q1dotq2 = -1; + } + pen = pow(real(acos(q1dotq2)),2); + } + + double f = normL2(q1-q2new); + f = pow(f,2) + lam * pen; + + return f; + } + + void alignment_costgrad(vec q2k, vec h, double& f, vec& g, double lam = 0, int penalty = 0){ + // roughness + double pen = 0; + if (penalty == 0){ + vec time1 = arma::linspace(0,1,h.n_elem); + vec b = arma::diff(time1); + double binsize = mean(b); + vec g = gradient(arma::pow(h, 2), binsize); + arma::mat pen1 = arma::trapz(time1, arma::pow(g, 2)); + pen = pen1(0); + } + // l2gam + if (penalty == 1){ + vec tmp = arma::ones(T); + pen = normL2(arma::pow(h,2)-tmp); + pen = pow(pen, 2); + } + // l2psi + if (penalty == 2){ + vec tmp = arma::ones(T); + pen = normL2(h-tmp); + pen = pow(pen, 2); + } + // geodesic + if (penalty == 3){ + vec time1 = arma::linspace(0,1,h.n_elem); + arma::mat pen1 = arma::trapz(time1, h); + double q1dotq2 = pen1(0); + if (q1dotq2 > 1){ + q1dotq2 = 1; + } else if (q1dotq2 < -1) + { + q1dotq2 = -1; + } + pen = pow(real(acos(q1dotq2)),2); + } + + // compute cost + f = normL2(q1-q2k); + f = pow(f,2) + lam * pen; + + // compute cost gradient + double binsize = 1.0/(T-1); + vec q2kdot = gradient(q2k, binsize); + vec dq = q1 - q2k; + vec tmp = dq % q2kdot; + vec tmp1 = dq % q2k; + vec v = 2 * cumtrapz(time, tmp); + v = v - tmp1; + + mat val = arma::trapz(time, v); + g = v - val(0); + + return; + + } + + vec getDirection(vec hCurGradient, arma::field sHistory, arma::field yHistory, vec rhoHistory, double scaleFactor, int j){ + vec q = hCurGradient; + vec inner_s_q = arma::zeros(j); + + for (int i = j; i > 0; i--) { + inner_s_q(i-1) = rhoHistory(i-1) * inner(sHistory(i-1), q); + q = q - inner_s_q(i-1) * yHistory(i-1); + } + + vec r = scaleFactor * q; + + double omega; + for (int i=0; i < j; i++){ + omega = rhoHistory(i) * inner(yHistory(i), r); + r = r + (inner_s_q(i) - omega) * sHistory(i); + } + + vec direction = -1 * r; + + return direction; + } + + lstats linesearch_hint(vec d, double f0, double df0, vec q2k, double lam=0, int penalty=0){ + // Armijo line-search based on the line-search hint in the problem + + double contraction_factor = 0.5; + double suff_decr = 1e-6; + int max_ls_steps = 25; + bool ls_backtrack = true; + bool ls_force_decrease = true; + + double alpha = 1; + vec hid = arma::ones(T); // identity element + + vec newh = exp(hid, d, alpha); + double newf = alignment_cost(newh, q2k, lam, penalty); + int cost_evaluations = 1; + + uvec tst = newh <= 0; + while (ls_backtrack & (newf > (f0 + suff_decr*alpha*df0)) || arma::sum(tst) > 0){ + alpha *= contraction_factor; + + newh = exp(hid, d, alpha); + newf = alignment_cost(newh, q2k, lam, penalty); + cost_evaluations += 1; + tst = newh <= 0; + + if (cost_evaluations >= max_ls_steps){ + break; + } + } + + if (ls_force_decrease & (newf > f0)){ + alpha = 0; + newh = hid; + newf = f0; + } + + double norm_d = norm2(d); + double stepsize = alpha * norm_d; + + lstats lstat; + lstat.stepsize = stepsize; + lstat.newh = newh; + + return lstat; + } + + int stoppingcriterion(options option, stats stat){ + int stop = 0; + if (stat.gradnorm <= option.tolgradnorm){ + stop = 2; + } + + if (stat.iter >= option.maxiter){ + stop = 3; + } + + return stop; + } + + vec group_action_SRVF(vec q, vec h){ + vec gamma = cumtrapz(time, arma::pow(h,2)); + gamma = gamma / gamma.back(); + vec time1 = arma::linspace(0,1,h.n_elem); + vec b = arma::diff(time1); + double binsize = mean(b); + vec h1 = gradient(gamma, binsize); + h1 = sqrt(h1); + vec qnew; + arma::interp1(time, q, gamma, qnew); + qnew = qnew % h1; + + return qnew; + } + + double strict_inc_func(double t){ + // the cautious step needs a real function that has value 0 at t=0 + return 1e-4*t; + } + + double normL2(vec f){ + double val1 = innerProdL2(f, f); + double val = sqrt(val1); + + return val; + } + + double innerProdL2(vec f1, vec f2){ + vec tmp = f1 % f2; + arma::mat tmp1 = arma::trapz(time, tmp); + double val = tmp1(0); + + return val; + } + + double dist(vec f1, vec f2){ + double temp = inner(f1, f2); + double d = real(acos(temp)); + + return d; + } + + double typicaldist(){ + double out = M_PI/2; + + return out; + } + + vec proj(vec f, vec v){ + arma::mat tmp1 = arma::trapz(time, f%v); + vec out = v - f * tmp1(0); + + return out; + } + + vec log(vec f1, vec f2){ + vec v = proj(f1, f2 - f1); + double di = dist(f1, f2); + if (di > 1e-6){ + double nv = norm2(v); + v = v * (di/nv); + } + + return v; + } + + vec exp(vec f1, vec v, double delta=1){ + vec vd = delta * v; + double nrm_vd = norm2(vd); + + vec f2; + if (nrm_vd > 0){ + f2 = f1 * cos(nrm_vd) + vd * (sin(nrm_vd)/nrm_vd); + } else { + f2 = f1; + } + + return f2; + } + + double inner(vec v1, vec v2){ + arma::mat M = arma::trapz(time, v1 % v2); + double val = M(0); + + return val; + } + + double norm2(vec f){ + arma::mat tmp1 = arma::trapz(time, pow(f,2)); + double out = sqrt(tmp1(0)); + + return out; + } + + vec transp(vec f1, vec f2, vec v){ + // isometric vector transport + vec w = log(f1, f2); + double dist_f1f2 = norm2(w); + + vec Tv; + if (dist_f1f2 > 0){ + vec u = w / dist_f1f2; + double utv = inner(u, v); + Tv = v + (cos(dist_f1f2) - 1) * utv * u - sin(dist_f1f2); + } else{ + Tv = v; + } + + return Tv; + } + + vec gradient(vec f, double binsize){ + vec g = arma::zeros(T); + g(0) = (f(1) - f(0)) / binsize; + g(T-1) = (f(T-1) - f(T-2)) / binsize; + + g(arma::span(1, T-2)) = (f(arma::span(2, T-1)) - f(arma::span(0, T-3))) / (2 * binsize); + + return g; + } + + vec cumtrapz(vec x, vec y){ + vec z = arma::zeros(T); + + vec dt = arma::diff(x)/2.0; + vec tmp = dt % (y(arma::span(1, T-1)) + y(arma::span(0,T-2))); + z(arma::span(1,T-1)) = cumsum(tmp); + + return z; + } +}; + +vec rlbfgs_optim(vec q1, vec q2, vec time, int maxiter=30, double lam=0.0, int penalty=0){ + uword T = time.n_elem; + vec time1 = arma::linspace(0, 1, T); + + rlbfgs myObj(q1, q2, time1); + myObj.solve(maxiter, lam, penalty); + + return myObj.gammaOpt; +} + +int main() { + uword T = 101; + vec time = arma::linspace(0, 2*M_PI, T); + vec q1 = sin(time); + vec q2 = cos(time); + vec time1 = arma::linspace(0, 1, T); + + rlbfgs myObj(q1, q2, time1); + myObj.solve(); + + myObj.gammaOpt.print(); + + return 0; +} diff --git a/src/rbfgs.h b/src/rbfgs.h new file mode 100644 index 00000000..e7cc1677 --- /dev/null +++ b/src/rbfgs.h @@ -0,0 +1,11 @@ +#ifndef RBFGS_H +#define RBFGS_H + +#include "armadillo" + +using namespace arma; +using namespace std; + +vec rlbfgs_optim(vec q1, vec q2, vec time, int maxiter, double lam, int penalty); + +#endif // end of RBFGS_H