Skip to content

Commit

Permalink
update c++ sources
Browse files Browse the repository at this point in the history
  • Loading branch information
yomichi committed Sep 19, 2021
1 parent 9c38915 commit a5e0822
Show file tree
Hide file tree
Showing 18 changed files with 644 additions and 252 deletions.
3 changes: 1 addition & 2 deletions c++/include/Gf.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#ifndef _GF_H
#define _GF_H

Expand Down Expand Up @@ -39,7 +38,7 @@ class Gf{
double rho(const double omega);

private:
int N;
int N; // == beta/dt
double beta;
double tail;
statistics stat;
Expand Down
8 changes: 4 additions & 4 deletions c++/include/SVD_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ class SVD_matrix {
void rearrange_col(std::vector<int> &); // rearrange column of A (actaully VT)
void print_basis(std::string _file_U, std::string _file_V, int);

CPPL::dcovector transform_x2sv(const CPPL::dcovector &); // V^t * x (to original basis)
CPPL::dcovector transform_y2sv(const CPPL::dcovector &); // U^t * y (to original basis)
CPPL::dcovector transform_sv2x(const CPPL::dcovector &); // V * x' (to SV basis)
CPPL::dcovector transform_sv2y(const CPPL::dcovector &); // U * y' (to SV basis)
CPPL::dcovector transform_x2sv(const CPPL::dcovector &); // V^t * x (to SV basis)
CPPL::dcovector transform_y2sv(const CPPL::dcovector &); // U^t * y (to SV basis)
CPPL::dcovector transform_sv2x(const CPPL::dcovector &); // V * x' (to original basis)
CPPL::dcovector transform_sv2y(const CPPL::dcovector &); // U * y' (to original basis)
void OutputSVD(std::string _file_SVD);
private:
CPPL::dcovector S_temp;
Expand Down
29 changes: 24 additions & 5 deletions c++/include/admm_svd.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ class SVD_matrix;
struct admm_result{
// w/ sv : results in SV basis
// w/o sv : results in omega-tau basis
std::vector<double> x, xsv, z1, z1sv, z2, z2sv;
std::vector<double> x, xsv, z1, z1sv, z2, z2sv, z3, z3sv;
std::vector<double> y, ysv, y_recovered_x, y_recovered_z1, ysv_recovered_x, ysv_recovered_z1;
};
struct admm_info{
double res1_pri, res1_dual, res2_pri, res2_dual; // residual errors
double res3_pri, res3_dual; // spmpade
double mse, mse_full;
int l0_norm;
double l1_norm, sum_x_calc, negative_weight;
int iter;
};
Expand All @@ -52,13 +54,16 @@ class admm_svd{

// [optional]
void set_sumrule(double sum_x);
void set_rho_ref(std::vector<double>const &rho_ref, std::vector<double> const & rho_ref_weight);
void set_omega_coeff(std::vector<double> const& omega_coeff);
void set_nonnegative(bool _flag);
void set_fileout_iter(const std::string filename); // output convergence in a file. unset if filename=""
void set_print_level(int); // 0: none, 1: results, 2: verbose

// [required]
void set_coef(double lambda, double penalty1=1.0, double penalty2=1.0, bool flag_penalty_auto=false);
void set_coef(double lambda, double penalty1=1.0, double penalty2=1.0, double penalty3=1.0, bool flag_penalty_auto=false);
void set_y(const std::vector<double> &y);
int get_svd_num(){return svd_num;}

// [optional]
void clear_x();
Expand All @@ -77,28 +82,42 @@ class admm_svd{

bool flag_nonnegative;
bool flag_sumrule;
bool flag_rho_ref;
double sum_x;
std::vector<double> rho_ref;
std::vector<double> rho_ref_weight;

CPPL::dcovector x, Vx, z1, u1, z2, u2; // 1 for L1 norm, 2 for non-negativity
CPPL::dcovector x, Vx, WVx, z1, u1, z2, u2; // 1 for L1 norm, 2 for non-negativity
CPPL::dcovector z3, u3; // SpM-Pade
CPPL::dcovector y, y_sv;

double regulariz, penalty1, penalty2;
double penalty3; // SpMPade
bool flag_penalty_auto;
static const int PENALTY_UPDATE_INTERVAL;

std::vector<double> omega_coeff; // omega

int print_level;
bool flag_fileout_iter;
std::string file_iter;

int svd_num;
void pre_update(); // This function must be called before calling update_x, and must be recalled when one of values of lambda, penalty1, penalty2 are changed
void update_x();

CPPL::dcovector transform_x2sv(CPPL::dcovector const& v);
CPPL::dcovector transform_sv2x(CPPL::dcovector const& v);

// quantities used in functions update_x and set_y (set in function pre_update)
struct quantities_for_update{
CPPL::dgbmatrix Y, B; // diagonal matrix
CPPL::dgematrix C;
CPPL::dgematrix Y, B; // diagonal matrix
CPPL::dgbmatrix Pade_B3;
CPPL::dgematrix Pade_B0, C;
CPPL::dcovector Yy, w;
CPPL::dcovector Pade_y;
CPPL::drovector v_row;
CPPL::dcovector rho_w;
double sum_Vw;
} pre;
};
Expand Down
6 changes: 6 additions & 0 deletions c++/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ CPPL::dcovector vec2cppl_col(const std::vector<double> &); // vector -> CPPL
std::vector<double> cppl2vec(const CPPL::dcovector &); // CPPL -> vector
std::vector<double> cppl2vec(const CPPL::drovector &); // CPPL -> vector

int norm_l0(const CPPL::dcovector &); // L0 norm
double norm_l1(const CPPL::dcovector &); // L1 norm
double norm_l2(const CPPL::dcovector &); // L2 norm
double norm_l2_sq(const CPPL::dcovector &); // square of L2 norm
Expand All @@ -39,6 +40,11 @@ CPPL::drovector drovector_all1(int n); // return a rovector with elements all b
CPPL::dcovector positive(const CPPL::dcovector &); // set negative elements at zero
CPPL::dgematrix positive(const CPPL::dgematrix &);

// return (v + coeff*ref) / (1+coeff)
CPPL::dcovector project_ref(const CPPL::dcovector &v,
const std::vector<double> &ref,
const std::vector<double> &coeff);

// return shrinked matrix
CPPL::dgematrix low_rank_matrix(CPPL::dgematrix &mat, int m, int n);

Expand Down
6 changes: 3 additions & 3 deletions c++/include/fft.h
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@


#ifndef _FFT_H
#define _FFT_H

#include <vector>
#include <complex>

inline bool isodd(int n) { return (n/2)*2 != n; }
inline bool iseven(int n){ return !isodd(n); }

void fft_fermion_tau2iw(const std::vector<double> &G_tau, std::vector< std::complex<double> > &G_iw, const double beta, const double tail=1.);

void fft_fermion_iw2tau(std::vector<double> &G_tau, const std::vector< std::complex<double> > &G_iw, const double beta, const double tail=1.);


void fft_boson_tau2iw(const std::vector<double> &G_tau, std::vector< std::complex<double> > &G_iw, const double beta, const double tail=0.);

void fft_boson_iw2tau(std::vector<double> &G_tau, const std::vector< std::complex<double> > &G_iw, const double beta, const double tail=0.);
Expand Down
6 changes: 3 additions & 3 deletions c++/include/set_initial.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ class SetInitial {
};
struct FileInfo {
std::string filein_G;
std::string filein_Gsigma;
int col;
int colsigma;
int output_interval;
std::string fileout_spec;
int print_level;
Expand All @@ -43,6 +45,7 @@ class SetInitial {
struct CalcInfo {
std::string statistics;
double beta;
double sigma;
};

int argc;
Expand All @@ -53,8 +56,6 @@ class SetInitial {
std::map<std::string, std::string> mapForKeyWordToValue;
std::map<std::string, bool> mapForKeyWordToRead;

template<class T>
std::string argv_or_defaultvalue(int n, T value);
bool SetDefaultValue();
void SetInputValue();
void RegisterMap(std::string _keyword, std::string _value);
Expand All @@ -79,7 +80,6 @@ class SetInitial {
CalcInfo calcInfo;
bool AddDefaulValueMap(char* _filename);
bool ReadParam(char* _filename);
bool InputFromArgs(int argc, char *argv[]);
void PrintInfo();
};

Expand Down
17 changes: 12 additions & 5 deletions c++/include/spm_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class SPM_Core {
std::vector<double> omega;
std::vector<admm_result> result;
std::vector<admm_info> info;
std::string StatisticsType;
double beta;

double integrate(std::vector<double> &y, double width);

Expand Down Expand Up @@ -66,23 +68,28 @@ class SPM_Core {

void GetSpectrum(std::vector<double> &_spectrum);

void GetResults(std::vector<double> &_vmse, std::vector<double> &_vmse_full, std::vector<double> &_vl1_norm,
void GetResults(std::vector<double> &_vmse, std::vector<double> &_vmse_full,
std::vector<int> &_vl0_norm, std::vector<double> &_vl1_norm,
std::vector<double> &_valid);

int SolveEquation(
std::string _StatisticsType,
double _Beta,
std::vector<std::vector<double> > &_AIn,
std::vector<double> &_Gtau,
std::vector<double> &_lambda,
std::vector<double> &_omega);
std::vector<double> &_Gtau_error,
std::vector<double> &_lambda,
std::vector<double> &_omega);

int SolveEquationCore(
std::vector<std::vector<double> > &_AIn,
std::vector<double> &_Gtau,
std::vector<double> &_omega,
std::vector<double> &_lambda,
const double _sum_G
std::vector<double> &_omega_coeff,
std::vector<double> &_lambda,
const double _sum_G,
std::vector<double> &_ref_rho,
std::vector<double> &_ref_coeff
);
};

Expand Down
93 changes: 52 additions & 41 deletions c++/include/spm_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,64 @@
#ifndef _SPM_PARAM_HEADER
#define _SPM_PARAM_HEADER

class SPM_Param{
private:
struct Lambda{
int Nl;
double lbegin;
double lend;
int lvalid;
double dlambda;
Lambda(){
Nl=1;
lbegin = 1e-1;
lend = 1e+0;
lvalid=0;
dlambda=-1;
}
};
#include <string>

struct Admm{
double penalty;
double tolerance;
int max_iter;
bool flag_penalty_auto;
Admm(){
penalty=10.;
tolerance = 1e-6;
max_iter=1000;
flag_penalty_auto=false;
}
};
class SPM_Param {
private:
struct Lambda {
int Nl;
double lbegin;
double lend;
int lvalid;
double dlambda;

struct SVD{
double sv_min;
SVD(){
sv_min=0;
}
};

public:
Lambda() {
Nl = 1;
lbegin = 1e-1;
lend = 1e+0;
lvalid = 0;
dlambda = -1;
}
};

struct Admm {
double penalty;
double tolerance;
int max_iter;
bool flag_penalty_auto;

Admm() {
penalty = 10.;
tolerance = 1e-6;
max_iter = 1000;
flag_penalty_auto = false;
}
};

struct SVD {
double sv_min;
SVD() {
sv_min = 0;
}
};

struct Pade {
std::string filename;
int nsample;
double eta;
};

public:
Lambda lambda;
Admm admm;
SVD svd;

Pade pade;
};

struct SPM_Flags{
bool validation;
bool nonnegative;
struct SPM_Flags {
bool validation;
bool nonnegative;
bool refrho;
};

#endif
Loading

0 comments on commit a5e0822

Please sign in to comment.