Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
rambaut committed Nov 9, 2024
2 parents f49f0ca + 3a76b7d commit 9ebcf1c
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

package dr.evomodel.branchmodel.lineagespecific;

import dr.math.matrixAlgebra.WrappedVector;
import org.apache.commons.math.MathException;

import dr.inference.model.CompoundLikelihood;
Expand Down Expand Up @@ -244,7 +245,7 @@ private void doOperate() throws MathException {

if (DEBUG) {
System.out.println("N[-index]: ");
dr.app.bss.Utils.printArray(occupancy);
System.out.println(new WrappedVector.Raw(occupancy));
}

Likelihood clusterLikelihood = (Likelihood) likelihood.getLikelihood(index);
Expand Down Expand Up @@ -288,11 +289,11 @@ private void doOperate() throws MathException {
clusterProbs[i] = logprob;
}// END: i loop

dr.app.bss.Utils.exponentiate(clusterProbs);
exponentiate(clusterProbs);

if (DEBUG) {
System.out.println("P(z[index] | z[-index]): ");
dr.app.bss.Utils.printArray(clusterProbs);
System.out.println(new WrappedVector.Raw(clusterProbs));
}

// sample
Expand All @@ -308,6 +309,12 @@ private void doOperate() throws MathException {

}// END: doOperate

/**
 * Replaces each entry of {@code array} with {@code Math.exp} of that entry,
 * in place (used to turn log-probabilities into probabilities).
 *
 * @param array log-scale values; overwritten with their exponentials
 */
public static void exponentiate(double[] array) {
    int i = 0;
    while (i < array.length) {
        array[i] = Math.exp(array[i]);
        ++i;
    }
}// END: exponentiate

@Override
public String getOperatorName() {
return DirichletProcessOperatorParser.DIRICHLET_PROCESS_OPERATOR;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,12 @@
import java.util.ArrayList;
import java.util.List;

import dr.app.bss.Utils;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import dr.math.distributions.NormalDistribution;

public class DirichletProcessPriorLogger implements Loggable {
Expand Down Expand Up @@ -112,7 +111,7 @@ private void getNew() {

this.categoryProbabilities = getCategoryProbs();

this.newCategoryIndex = Utils.sample(categoryProbabilities);
this.newCategoryIndex = MathUtils.randomChoicePDF(categoryProbabilities);
this.meanForCategory = uniquelyRealizedParameters
.getParameterValue(newCategoryIndex);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodel.treelikelihood.BeagleTreeLikelihood;
import dr.evomodel.treelikelihood.PartialsRescalingScheme;
import dr.app.beagle.tools.BeagleSequenceSimulator;
import dr.app.beagle.tools.Partition;
//import dr.app.beagle.tools.BeagleSequenceSimulator;
//import dr.app.beagle.tools.Partition;
import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.ConvertAlignment;
import dr.evolution.datatype.Codons;
Expand Down Expand Up @@ -204,111 +204,111 @@ protected void acceptState() {
//
}// END: acceptState

/**
 * Smoke-test entry point: simulates a short codon alignment on a fixed
 * three-taxon tree and prints the log likelihood of a
 * LineageSpecificBranchModel ("likelihood") next to that of a homogeneous
 * branch model over the same data ("likelihood (gold)").
 * NOTE(review): depends on BeagleSequenceSimulator/Partition from
 * dr.app.beagle.tools; any exception is caught and printed, so the method
 * never throws.
 */
public static void main(String[] args) {

try {

// the seed of the BEAST
// fixed seed so the simulated alignment (and both likelihoods) are reproducible
MathUtils.setSeed(666);

// create tree
NewickImporter importer = new NewickImporter(
"(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);");
TreeModel tree = new DefaultTreeModel(importer.importTree(null));

// create site model
GammaSiteRateModel siteRateModel = new GammaSiteRateModel(
"siteModel");

// create branch rate model
BranchRateModel branchRateModel = new DefaultBranchRateModel();

int sequenceLength = 10;
ArrayList<Partition> partitionsList = new ArrayList<Partition>();

// create Frequency Model
// 61 codon frequencies (universal code); NOTE(review): the first entry
// (0.0163936) differs from the remaining 0.01639344 values — presumably a
// rounding tweak so the vector sums to 1; confirm.
Parameter freqs = new Parameter.Default(new double[]{
0.0163936, //
0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 //
});
FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL,
freqs);

// create substitution model
// NOTE(review): variable is typed MG94HKYCodonModel but constructed as
// MG94K80CodonModel — confirm the latter is a subtype of the former.
Parameter alpha = new Parameter.Default(1, 10);
Parameter beta = new Parameter.Default(1, 5);
MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions());

HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94);

// create partition
Partition partition1 = new Partition(tree, //
substitutionModel,//
siteRateModel, //
branchRateModel, //
freqModel, //
0, // from
sequenceLength - 1, // to
1 // every
);

partitionsList.add(partition1);

// feed to sequence simulator and generate data
BeagleSequenceSimulator simulator = new BeagleSequenceSimulator(
partitionsList);

Alignment alignment = simulator.simulate(false, false);

// re-interpret the simulated nucleotides as universal-code codons
ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE,
GeneticCode.UNIVERSAL, alignment);


// both categories deliberately share the same mg94 instance, so the
// lineage-specific model should reproduce the homogeneous likelihood
List<SubstitutionModel> substModels = new ArrayList<SubstitutionModel>();
for (int i = 0; i < 2; i++) {
// alpha = new Parameter.Default(1, 10 );
// beta = new Parameter.Default(1, 5 );
// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta,
// freqModel);
substModels.add(mg94);
}
Parameter uCategories = new Parameter.Default(2, 0);
// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories);
LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider,
uCategories);

// likelihood under the lineage-specific branch model
BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, //
tree, //
branchSpecific, //
siteRateModel, //
branchRateModel, //
null, //
false, //
PartialsRescalingScheme.DEFAULT, true);

// reference ("gold") likelihood under the plain homogeneous model
BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, //
tree, //
substitutionModel, //
siteRateModel, //
branchRateModel, //
null, //
false, //
PartialsRescalingScheme.DEFAULT, true);
System.out.println("likelihood (gold) = " + gold.getLogLikelihood());
System.out.println("likelihood = " + like.getLogLikelihood());
} catch (Exception e) {
// best-effort demo: report and fall through rather than propagate
e.printStackTrace();
}

}// END: main
// public static void main(String[] args) {
//
// try {
//
// // the seed of the BEAST
// MathUtils.setSeed(666);
//
// // create tree
// NewickImporter importer = new NewickImporter(
// "(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);");
// TreeModel tree = new DefaultTreeModel(importer.importTree(null));
//
// // create site model
// GammaSiteRateModel siteRateModel = new GammaSiteRateModel(
// "siteModel");
//
// // create branch rate model
// BranchRateModel branchRateModel = new DefaultBranchRateModel();
//
// int sequenceLength = 10;
// ArrayList<Partition> partitionsList = new ArrayList<Partition>();
//
// // create Frequency Model
// Parameter freqs = new Parameter.Default(new double[]{
// 0.0163936, //
// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, //
// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 //
// });
// FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL,
// freqs);
//
// // create substitution model
// Parameter alpha = new Parameter.Default(1, 10);
// Parameter beta = new Parameter.Default(1, 5);
// MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions());
//
// HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94);
//
// // create partition
// Partition partition1 = new Partition(tree, //
// substitutionModel,//
// siteRateModel, //
// branchRateModel, //
// freqModel, //
// 0, // from
// sequenceLength - 1, // to
// 1 // every
// );
//
// partitionsList.add(partition1);
//
// // feed to sequence simulator and generate data
// BeagleSequenceSimulator simulator = new BeagleSequenceSimulator(
// partitionsList);
//
// Alignment alignment = simulator.simulate(false, false);
//
// ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE,
// GeneticCode.UNIVERSAL, alignment);
//
//
// List<SubstitutionModel> substModels = new ArrayList<SubstitutionModel>();
// for (int i = 0; i < 2; i++) {
//// alpha = new Parameter.Default(1, 10 );
//// beta = new Parameter.Default(1, 5 );
//// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta,
//// freqModel);
// substModels.add(mg94);
// }
//
// Parameter uCategories = new Parameter.Default(2, 0);
//// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories);
//
// LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider,
// uCategories);
//
// BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, //
// tree, //
// branchSpecific, //
// siteRateModel, //
// branchRateModel, //
// null, //
// false, //
// PartialsRescalingScheme.DEFAULT, true);
//
// BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, //
// tree, //
// substitutionModel, //
// siteRateModel, //
// branchRateModel, //
// null, //
// false, //
// PartialsRescalingScheme.DEFAULT, true);
//
// System.out.println("likelihood (gold) = " + gold.getLogLikelihood());
// System.out.println("likelihood = " + like.getLogLikelihood());
//
// } catch (Exception e) {
// e.printStackTrace();
// }
//
// }// END: main

@Override
public Citation.Category getCategory() {
Expand Down
10 changes: 3 additions & 7 deletions src/dr/inference/model/BoundedSpace.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,9 @@

package dr.inference.model;

import dr.app.bss.Utils;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.math.MathUtils;
import dr.math.matrixAlgebra.EJMLUtils;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.*;
import org.ejml.data.Complex64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
Expand Down Expand Up @@ -192,7 +188,7 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) {
System.out.println("Raw matrix to decompose: ");
System.out.println(CinvV);
System.out.print("Raw eigenvalues: ");
Utils.printArray(values);
System.out.println(new WrappedVector.Raw(values));
}
for (int i = 0; i < values.length; i++) {
values[i] = 1 / values[i];
Expand Down Expand Up @@ -282,7 +278,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc
SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim);
SymmetricMatrix X = compoundSymmetricMatrix(0.0, direction, dim);
System.out.print("Eigenvalues: ");
Utils.printArray(values);
System.out.println(new WrappedVector.Raw(values));

Matrix S = new SymmetricMatrix(dim, dim);
Matrix T = new SymmetricMatrix(dim, dim);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

package dr.inference.operators.hmc;

import dr.app.bss.Utils;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.*;
import dr.inference.operators.AdaptationMode;
Expand Down Expand Up @@ -340,15 +339,15 @@ public boolean doReflection(double[] position, WrappedVector momentum) {
if (DEBUG) {
System.out.println("time: " + eventTime);
System.out.print("start: ");
Utils.printArray(position);
System.out.println(new WrappedVector.Raw(position));
System.out.println(momentum);
}

type.doReflection(position, preconditioning, momentum, eventLocation, indices, normalVector, eventTime, remainingTime);

if (DEBUG) {
System.out.print("end: ");
Utils.printArray(position);
System.out.println(new WrappedVector.Raw(position));
System.out.println(momentum);
}

Expand Down
12 changes: 12 additions & 0 deletions src/dr/math/matrixAlgebra/WrappedVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ public Raw(double[] buffer) {
this(buffer, 0, buffer.length);
}

/**
 * Wraps an integer array by widening its values into a fresh double
 * buffer covering the full array (offset 0, length {@code in.length}).
 * The input array is copied, not referenced, so later changes to
 * {@code in} are not reflected in this vector.
 *
 * @param in the integer values to wrap
 */
public Raw(int[] in) {
this(convert(in), 0, in.length);
}

/**
 * Widens an int array into a newly allocated double array of the same
 * length, preserving element order.
 *
 * @param in source integer values
 * @return a fresh double[] containing the widened copies
 */
private static double[] convert(int[] in) {
    final int length = in.length;
    double[] widened = new double[length];
    int j = 0;
    while (j < length) {
        widened[j] = in[j];
        ++j;
    }
    return widened;
}

@Override
final public double get(final int i) {
return buffer[offset + i];
Expand Down

0 comments on commit 9ebcf1c

Please sign in to comment.