Skip to content

Commit

Permalink
Merge pull request #53 from CosmoStat/gmca
Browse files Browse the repository at this point in the history
gmca bug correction
  • Loading branch information
jstarck authored Apr 29, 2024
2 parents 3f9d548 + f00319d commit 1ff7b9b
Show file tree
Hide file tree
Showing 13 changed files with 489 additions and 295 deletions.
41 changes: 21 additions & 20 deletions src/mc/mcmain1d/mr1d_gmca.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,16 @@
#include "Array.h"
#include "NR.h"
#include "IM_Obj.h"
#include "IM_IO.h"
#include "GMCA.h"
#include "MR1D1D.h"
#include "IM_Noise.h"
#include "MR1D_NoiseModel.h"
#include "MR1D_Filter.h"

/****************************************************************************/

class GMCA_1D: public MR1D1D, public GMCA
class GMCA_1D: public MR1D1D, public GMCA
{
public:
Bool UseRMSMap;
Expand All @@ -46,7 +47,7 @@ class GMCA_1D: public MR1D1D, public GMCA
void inpainting_run(fltarray &TabCannels, fltarray & InpData);
void transrecons_sources(fltarray &TabVect,fltarray &Recdata);
void transform_sources(fltarray &Data,fltarray &TabVect);
void recons_sources(fltarray &DataIn, fltarray &EstSources); // Applying the matrix on the data
void recons_sources(dblarray &DataIn, dblarray &EstSources); // Applying the matrix on the data
void HT_Sources(fltarray &TabSource, float &KThrd) ;

~GMCA_1D() {} ;
Expand Down Expand Up @@ -170,12 +171,12 @@ void GMCA_1D::transrecons_sources(fltarray &TabVect,fltarray &Recdata)

/****************************************************************************/

void GMCA_1D::recons_sources(fltarray &DataIn, fltarray &EstSources) // Applying the matrix on the data
void GMCA_1D::recons_sources(dblarray &DataIn, dblarray &EstSources) // Applying the matrix on the data
{
int Nx = DataIn.nx();
int Ny = DataIn.ny();
int i,k,l;
fltarray RefData,RefSources;
dblarray RefData,RefSources;
int Deb = 0;

// cout << "NEW recons_sources " << endl;
Expand Down Expand Up @@ -588,7 +589,7 @@ int test_main(int argc, char *argv[])

int main(int argc, char *argv[])
{
fltarray Dat;
dblarray Dat;
/* Get command line arguments, open input file(s) if necessary */
fitsstruct Header;
char Cmd[512];
Expand Down Expand Up @@ -624,7 +625,7 @@ int main(int argc, char *argv[])

if (Verbose == True) cout << "\n Reading the data"<< endl;

fits_read_fltarr(Name_Cube_In, Dat, &Header);
fits_read_dblarr(Name_Cube_In, Dat);

int Nx = Dat.nx();
int Ny = Dat.ny();
Expand Down Expand Up @@ -676,8 +677,8 @@ int main(int argc, char *argv[])

// WT.write(Name_Out);
// Compute the 2D1D transform
fltarray TabVect;
WT.transform_to_vectarray(Dat, TabVect);
dblarray TabVect;
WT.transform_to_vectdblarray(Dat, TabVect);
// fits_write_fltarr ("xx_tabvect.fits", TabVect);

// Initalize the class for GMCA
Expand All @@ -699,7 +700,7 @@ int main(int argc, char *argv[])
WT.GlobThrd = GThrd;
WT.SVConst = UsePCA;
WT.MatNbrScale1D = Nbr_Plan;
fltarray QSVec;
dblarray QSVec;

if (UsePCA == True)
{
Expand All @@ -709,13 +710,13 @@ int main(int argc, char *argv[])

if (UseMask == True)
{
fits_read_fltarr (Name_Mask, WT.Mask);
fits_read_dblarr (Name_Mask, WT.Mask);
}


if (UseKnownColomn == True)
{
fits_read_fltarr (Name_KnowColumn, WT.MatKnownColumn);
fits_read_dblarr (Name_KnowColumn, WT.MatKnownColumn);
if (WT.MatKnownColumn.naxis() == 1) WT.NbrKnownColumn = 1;
else WT.NbrKnownColumn = WT.MatKnownColumn.axis(2);
}
Expand All @@ -726,7 +727,7 @@ int main(int argc, char *argv[])

if (EstimNbSources == False) // THE NUMBER OF SOURCES IS FIXED
{
fltarray TabSource;
dblarray TabSource;
int NbrCoef = TabVect.nx();
TabSource.alloc(NbrCoef,NbrSources);
WT.GMCA::Verbose = Verbose;
Expand All @@ -749,7 +750,7 @@ int main(int argc, char *argv[])
NbrSources++;
if (Verbose == True) cout << "Running GMCA ... Number of Estimated Sources : " << NbrSources << endl;
WT.NbrSources = NbrSources;
fltarray TabSource;
dblarray TabSource;
int NbrCoef = TabVect.nx();
TabSource.alloc(NbrCoef,NbrSources);
WT.GMCA::Verbose = Verbose;
Expand Down Expand Up @@ -782,7 +783,7 @@ int main(int argc, char *argv[])

// Reconstruction :
if (Verbose == True) cout << "Reconstruction ... "<< endl;
fltarray EstSources;
dblarray EstSources;
// cout << "GO REC" << endl;

WT.recons_sources(Dat,EstSources);
Expand All @@ -795,16 +796,16 @@ int main(int argc, char *argv[])
// fits_write_fltarr ("xx_InvMixingMat.fits", WT.InvMixingMat);

// Header.origin = Cmd;
fits_write_fltarr(Name_Out, EstSources);
if (WriteMixing == True) fits_write_fltarr (Name_Out_2, WT.RecMixingMat);
if (WriteChannels == True)
fits_write_dblarr(Name_Out, EstSources);
if (WriteMixing == True) fits_write_dblarr (Name_Out_2, WT.RecMixingMat);
if (WriteChannels == True)
{
fltarray EstChannels, TranspMixingMat;
dblarray EstChannels, TranspMixingMat;
MatOper MAT; // See file $Tools/MatrixOper.cc and .h
MAT.transpose(WT.MixingMat,TranspMixingMat);
WT.apply_mat(EstSources, TranspMixingMat, EstChannels);
// WT.apply_mat(EstSources, WT.MixingMat, EstChannels);
fits_write_fltarr (Name_Out_3, EstChannels);
fits_write_dblarr (Name_Out_3, EstChannels);
}
exit(0);
}
64 changes: 34 additions & 30 deletions src/mc/mcmain2d/mr_gmca.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,20 +229,18 @@ static void usage(char *argv[])
fprintf(OUTMAN, " Apply a l_1 constraint also on the mixing matrix. Default is no. \n");
fprintf(OUTMAN, " [-d]\n");
fprintf(OUTMAN, " Estimate the number of sources. Default is no. \n");
fprintf(OUTMAN, " [-d]\n");
fprintf(OUTMAN, " Estimate the number of sources. Default is no. \n");
fprintf(OUTMAN, " [-m]\n");
fprintf(OUTMAN, " Mad-based stopping criterion when the number of sources is estimated. Default is 5 - default criterion is l2-based. \n");
fprintf(OUTMAN, " [-L]\n");
fprintf(OUTMAN, " L2-based stopping criterion when the number of sources is estimated. Default is 40 (in dB). \n");
fprintf(OUTMAN, " [-D] \n");
fprintf(OUTMAN, " Spectra with disjoint supports for thresholds higher than 7 Mad. \n"); // Should be an option
// fprintf(OUTMAN, " [-D] \n");
// fprintf(OUTMAN, " Spectra with disjoint supports for thresholds higher than 7 Mad. \n"); // Should be an option
fprintf(OUTMAN, " [-K Last K-Mad]\n");
fprintf(OUTMAN, " Last value of K for K-Mad Thresholding. \n");
fprintf(OUTMAN, " [-G Global Thresholding]\n");
fprintf(OUTMAN, " [-O]\n");
fprintf(OUTMAN, " Orthogonalization of the spectra\n");
verbose_usage();
// fprintf(OUTMAN, " [-O]\n");
// fprintf(OUTMAN, " Orthogonalization of the spectra\n");
verbose_usage();
vm_usage();
manline();
exit(-1);
Expand Down Expand Up @@ -435,7 +433,7 @@ static void transinit(int argc, char *argv[])

int main(int argc, char *argv[])
{
fltarray Dat;
dblarray Dat;
/* Get command line arguments, open input file(s) if necessary */
fitsstruct Header;
char Cmd[512];
Expand Down Expand Up @@ -466,38 +464,44 @@ int main(int argc, char *argv[])

if (Verbose == True) cout << "\n Reading the data"<< endl;

io_3d_read_data(Name_Cube_In, Dat, &Header);

// io_3d_read_data(Name_Cube_In, Dat, &Header);
fits_read_dblarr (Name_Cube_In, Dat);

int Nx = Dat.nx();
int Ny = Dat.ny();
int Nz = Dat.nz();
if (Verbose == True) cout << "Nx = " << Dat.nx() << " Ny = " << Dat.ny() << " Nz = " << Dat.nz() << endl;

if (Verbose == True)
cout << "Nx = " << Dat.nx() << " Ny = " << Dat.ny() << " Nz = " << Dat.nz() << endl;
// Dat.info("READ data");


if (Normalize == True)
{
double Mean = Dat.mean();
double Sigma = Dat.sigma();
// cout << "Data mean = " << Mean << " Sigma = " << Sigma << endl;
// printf("Sigmad=f = %f", Sigma);
for (int i=0;i<Nx;i++)
for (int j=0;j<Ny;j++)
for (int k=0;k<Nz;k++) Dat(i,j,k) = (Dat(i,j,k)-Mean)/Sigma;
}

// Dat.info("Normalized data");


// MR2D1D WT;
GMCA_2D WT;

if (Verbose == True) cout << "Alloc ... " << endl;
WT.alloc(Nx, Ny, Nz, Transform, NbrScale2d, 1); // On ne regularise que la MixingMat

if (Verbose == True) cout << "2d1d_trans ... "<< endl;

//WT.transform (Dat); // Pas utile

// WT.write(Name_Out);
// Compute the 2D1D transform
fltarray TabVect;
WT.transform_to_vectarray(Dat, TabVect);
// fits_write_fltarr ("xx_tabvect.fits", TabVect);
dblarray TabVect;
WT.transform_to_vectdblarray(Dat, TabVect);
// fits_write_dblarr ("xx_tabvect.fits", TabVect);

// Initalize the class for GMCA

Expand All @@ -518,7 +522,7 @@ int main(int argc, char *argv[])
WT.GlobThrd = GThrd;
WT.SVConst = UsePCA;
WT.MatNbrScale1D = Nbr_Plan;
fltarray QSVec;
dblarray QSVec;

if (UsePCA == True)
{
Expand All @@ -528,12 +532,12 @@ int main(int argc, char *argv[])

if (UseMask == True)
{
fits_read_fltarr (Name_Mask, WT.Mask);
fits_read_dblarr (Name_Mask, WT.Mask);
}

if (UseKnownColomn == True)
{
fits_read_fltarr (Name_KnowColumn, WT.MatKnownColumn);
fits_read_dblarr (Name_KnowColumn, WT.MatKnownColumn);
if (WT.MatKnownColumn.naxis() == 1) WT.NbrKnownColumn = 1;
else WT.NbrKnownColumn = WT.MatKnownColumn.axis(2);
}
Expand All @@ -543,7 +547,7 @@ int main(int argc, char *argv[])

if (EstimNbSources == False) // THE NUMBER OF SOURCES IS FIXED
{
fltarray TabSource;
dblarray TabSource;
int NbrCoef = TabVect.nx();
TabSource.alloc(NbrCoef,NbrSources);
WT.GMCA::Verbose = Verbose;
Expand All @@ -554,19 +558,19 @@ int main(int argc, char *argv[])
if (EstimNbSources == TRUE) // THE NUMBER OF SOURCES IS ESTIMATED
{
int NbrSourcesMax = NbrSources;
float RelError = 0;
double RelError = 0;
//float SigmaData=TabVect.sigma();
NbrSources = 1;
bool ExitCriterion = False; /// CHANGEDDDDD
float OldRelError=0;
float DiffRelError;
double OldRelError=0;
double DiffRelError;

while (NbrSources <= NbrSourcesMax && ExitCriterion == False)
{
NbrSources++;
if (Verbose == True) cout << "Running GMCA ... Number of Estimated Sources : " << NbrSources << endl;
WT.NbrSources = NbrSources;
fltarray TabSource;
dblarray TabSource;
int NbrCoef = TabVect.nx();
TabSource.alloc(NbrCoef,NbrSources);
WT.GMCA::Verbose = Verbose;
Expand Down Expand Up @@ -605,17 +609,17 @@ int main(int argc, char *argv[])

// Reconstruction :
if (Verbose == True) cout << "Reconstruction ... "<< endl;
fltarray EstSources;
dblarray EstSources;
WT.recons_sources(Dat,EstSources);
// WT.Sort_Sources(EstSources);

if (Verbose == True) cout << "Write results ... "<< endl;
// fits_write_fltarr ("xx_EstSources.fits", EstSources);
fits_write_fltarr ("xx_EstMixmat.fits", WT.MixingMat);
fits_write_fltarr ("xx_InvMixingMat.fits", WT.InvMixingMat);
fits_write_dblarr ("xx_EstMixmat.fits", WT.MixingMat);
fits_write_dblarr ("xx_InvMixingMat.fits", WT.InvMixingMat);

// Header.origin = Cmd;
fits_write_fltarr(Name_Out, EstSources);
fits_write_dblarr(Name_Out, EstSources);

exit(0);
}
Loading

0 comments on commit 1ff7b9b

Please sign in to comment.