From 3a76b7d87a99d688fb14c3f2c93473dfa795c35a Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Fri, 8 Nov 2024 14:41:35 -0800 Subject: [PATCH] clean messy dependence between unrelated packages --- .../DirichletProcessOperator.java | 13 +- .../DirichletProcessPriorLogger.java | 5 +- .../LineageSpecificBranchModel.java | 214 +++++++++--------- src/dr/inference/model/BoundedSpace.java | 10 +- ...flectiveHamiltonianMonteCarloOperator.java | 5 +- src/dr/math/matrixAlgebra/WrappedVector.java | 12 + 6 files changed, 136 insertions(+), 123 deletions(-) diff --git a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java index 332fd226a1..154477a095 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessOperator.java @@ -27,6 +27,7 @@ package dr.evomodel.branchmodel.lineagespecific; +import dr.math.matrixAlgebra.WrappedVector; import org.apache.commons.math.MathException; import dr.inference.model.CompoundLikelihood; @@ -244,7 +245,7 @@ private void doOperate() throws MathException { if (DEBUG) { System.out.println("N[-index]: "); - dr.app.bss.Utils.printArray(occupancy); + System.out.println(new WrappedVector.Raw(occupancy)); } Likelihood clusterLikelihood = (Likelihood) likelihood.getLikelihood(index); @@ -288,11 +289,11 @@ private void doOperate() throws MathException { clusterProbs[i] = logprob; }// END: i loop - dr.app.bss.Utils.exponentiate(clusterProbs); + exponentiate(clusterProbs); if (DEBUG) { System.out.println("P(z[index] | z[-index]): "); - dr.app.bss.Utils.printArray(clusterProbs); + System.out.println(new WrappedVector.Raw(clusterProbs)); } // sample @@ -308,6 +309,12 @@ private void doOperate() throws MathException { }// END: doOperate + public static void exponentiate(double[] array) { + for (int i = 0; i < array.length; i++) { + array[i] = Math.exp(array[i]); + } + }// END: exponentiate + @Override public String getOperatorName() { return DirichletProcessOperatorParser.DIRICHLET_PROCESS_OPERATOR; diff --git a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java index def82debb5..d0c2537c26 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/DirichletProcessPriorLogger.java @@ -30,13 +30,12 @@ import java.util.ArrayList; import java.util.List; -import dr.app.bss.Utils; -import dr.inference.distribution.ParametricMultivariateDistributionModel; import dr.inference.loggers.LogColumn; import dr.inference.loggers.Loggable; import dr.inference.loggers.NumberColumn; import dr.inference.model.CompoundParameter; import dr.inference.model.Parameter; +import dr.math.MathUtils; import dr.math.distributions.NormalDistribution; public class DirichletProcessPriorLogger implements Loggable { @@ -112,7 +111,7 @@ private void getNew() { this.categoryProbabilities = getCategoryProbs(); - this.newCategoryIndex = Utils.sample(categoryProbabilities); + this.newCategoryIndex = MathUtils.randomChoicePDF(categoryProbabilities); this.meanForCategory = uniquelyRealizedParameters .getParameterValue(newCategoryIndex); diff --git a/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java b/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java index 811f68f469..0edca6ea3c 100644 --- a/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java +++ b/src/dr/evomodel/branchmodel/lineagespecific/LineageSpecificBranchModel.java @@ -40,8 +40,8 @@ import dr.evomodel.tree.DefaultTreeModel; import dr.evomodel.treelikelihood.BeagleTreeLikelihood; import dr.evomodel.treelikelihood.PartialsRescalingScheme; -import dr.app.beagle.tools.BeagleSequenceSimulator; -import dr.app.beagle.tools.Partition; +//import dr.app.beagle.tools.BeagleSequenceSimulator; +//import dr.app.beagle.tools.Partition; import dr.evolution.alignment.Alignment; import dr.evolution.alignment.ConvertAlignment; import dr.evolution.datatype.Codons; @@ -204,111 +204,111 @@ protected void acceptState() { // }// END: acceptState - public static void main(String[] args) { - - try { - - // the seed of the BEAST - MathUtils.setSeed(666); - - // create tree - NewickImporter importer = new NewickImporter( - "(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);"); - TreeModel tree = new DefaultTreeModel(importer.importTree(null)); - - // create site model - GammaSiteRateModel siteRateModel = new GammaSiteRateModel( - "siteModel"); - - // create branch rate model - BranchRateModel branchRateModel = new DefaultBranchRateModel(); - - int sequenceLength = 10; - ArrayList partitionsList = new ArrayList(); - - // create Frequency Model - Parameter freqs = new Parameter.Default(new double[]{ - 0.0163936, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // - 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 // - }); - FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL, - freqs); - - // create substitution model - Parameter alpha = new Parameter.Default(1, 10); - Parameter beta = new Parameter.Default(1, 5); - MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions()); - - HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94); - - // create partition - Partition partition1 = new Partition(tree, // - substitutionModel,// - siteRateModel, // - branchRateModel, // - freqModel, // - 0, // from - sequenceLength - 1, // to - 1 // every - ); - - partitionsList.add(partition1); - - // feed to sequence simulator and generate data - BeagleSequenceSimulator simulator = new BeagleSequenceSimulator( - partitionsList); - - Alignment alignment = simulator.simulate(false, false); - - ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE, - GeneticCode.UNIVERSAL, alignment); - - - List substModels = new ArrayList(); - for (int i = 0; i < 2; i++) { -// alpha = new Parameter.Default(1, 10 ); -// beta = new Parameter.Default(1, 5 ); -// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta, -// freqModel); - substModels.add(mg94); - } - - Parameter uCategories = new Parameter.Default(2, 0); -// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories); - - LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider, - uCategories); - - BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, // - tree, // - branchSpecific, // - siteRateModel, // - branchRateModel, // - null, // - false, // - PartialsRescalingScheme.DEFAULT, true); - - BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, // - tree, // - substitutionModel, // - siteRateModel, // - branchRateModel, // - null, // - false, // - PartialsRescalingScheme.DEFAULT, true); - - System.out.println("likelihood (gold) = " + gold.getLogLikelihood()); - System.out.println("likelihood = " + like.getLogLikelihood()); - - } catch (Exception e) { - e.printStackTrace(); - } - - }// END: main +// public static void main(String[] args) { +// +// try { +// +// // the seed of the BEAST +// MathUtils.setSeed(666); +// +// // create tree +// NewickImporter importer = new NewickImporter( +// "(SimSeq1:73.7468,(SimSeq2:25.256989999999995,SimSeq3:45.256989999999995):18.48981);"); +// TreeModel tree = new DefaultTreeModel(importer.importTree(null)); +// +// // create site model +// GammaSiteRateModel siteRateModel = new GammaSiteRateModel( +// "siteModel"); +// +// // create branch rate model +// BranchRateModel branchRateModel = new DefaultBranchRateModel(); +// +// int sequenceLength = 10; +// ArrayList partitionsList = new ArrayList(); +// +// // create Frequency Model +// Parameter freqs = new Parameter.Default(new double[]{ +// 0.0163936, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, // +// 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344 // +// }); +// FrequencyModel freqModel = new FrequencyModel(Codons.UNIVERSAL, +// freqs); +// +// // create substitution model +// Parameter alpha = new Parameter.Default(1, 10); +// Parameter beta = new Parameter.Default(1, 5); +// MG94HKYCodonModel mg94 = new MG94K80CodonModel(Codons.UNIVERSAL, alpha, beta, freqModel, new CodonOptions()); +// +// HomogeneousBranchModel substitutionModel = new HomogeneousBranchModel(mg94); +// +// // create partition +// Partition partition1 = new Partition(tree, // +// substitutionModel,// +// siteRateModel, // +// branchRateModel, // +// freqModel, // +// 0, // from +// sequenceLength - 1, // to +// 1 // every +// ); +// +// partitionsList.add(partition1); +// +// // feed to sequence simulator and generate data +// BeagleSequenceSimulator simulator = new BeagleSequenceSimulator( +// partitionsList); +// +// Alignment alignment = simulator.simulate(false, false); +// +// ConvertAlignment convert = new ConvertAlignment(Nucleotides.INSTANCE, +// GeneticCode.UNIVERSAL, alignment); +// +// +// List substModels = new ArrayList(); +// for (int i = 0; i < 2; i++) { +//// alpha = new Parameter.Default(1, 10 ); +//// beta = new Parameter.Default(1, 5 ); +//// mg94 = new MG94HKYCodonModel(Codons.UNIVERSAL, alpha, beta, +//// freqModel); +// substModels.add(mg94); +// } +// +// Parameter uCategories = new Parameter.Default(2, 0); +//// CountableBranchCategoryProvider provider = new CountableBranchCategoryProvider.IndependentBranchCategoryModel(tree, uCategories); +// +// LineageSpecificBranchModel branchSpecific = new LineageSpecificBranchModel(tree, freqModel, substModels, //provider, +// uCategories); +// +// BeagleTreeLikelihood like = new BeagleTreeLikelihood(convert, // +// tree, // +// branchSpecific, // +// siteRateModel, // +// branchRateModel, // +// null, // +// false, // +// PartialsRescalingScheme.DEFAULT, true); +// +// BeagleTreeLikelihood gold = new BeagleTreeLikelihood(convert, // +// tree, // +// substitutionModel, // +// siteRateModel, // +// branchRateModel, // +// null, // +// false, // +// PartialsRescalingScheme.DEFAULT, true); +// +// System.out.println("likelihood (gold) = " + gold.getLogLikelihood()); +// System.out.println("likelihood = " + like.getLogLikelihood()); +// +// } catch (Exception e) { +// e.printStackTrace(); +// } +// +// }// END: main @Override public Citation.Category getCategory() { diff --git a/src/dr/inference/model/BoundedSpace.java b/src/dr/inference/model/BoundedSpace.java index 458dfd3143..22dcc609ff 100644 --- a/src/dr/inference/model/BoundedSpace.java +++ b/src/dr/inference/model/BoundedSpace.java @@ -27,13 +27,9 @@ package dr.inference.model; -import dr.app.bss.Utils; import dr.inference.operators.hmc.HamiltonianMonteCarloOperator; import dr.math.MathUtils; -import dr.math.matrixAlgebra.EJMLUtils; -import dr.math.matrixAlgebra.IllegalDimension; -import dr.math.matrixAlgebra.Matrix; -import dr.math.matrixAlgebra.SymmetricMatrix; +import dr.math.matrixAlgebra.*; import org.ejml.data.Complex64F; import org.ejml.data.DenseMatrix64F; import org.ejml.factory.DecompositionFactory; @@ -192,7 +188,7 @@ private double[] trajectoryEigenvalues(double[] origin, double[] direction) { System.out.println("Raw matrix to decompose: "); System.out.println(CinvV); System.out.print("Raw eigenvalues: "); - Utils.printArray(values); + System.out.println(new WrappedVector.Raw(values)); } for (int i = 0; i < values.length; i++) { values[i] = 1 / values[i]; @@ -282,7 +278,7 @@ public IntersectionDistances distancesToBoundary(double[] origin, double[] direc SymmetricMatrix Y = compoundCorrelationSymmetricMatrix(origin, dim); SymmetricMatrix X = compoundSymmetricMatrix(0.0, direction, dim); System.out.print("Eigenvalues: "); - Utils.printArray(values); + System.out.println(new WrappedVector.Raw(values)); Matrix S = new SymmetricMatrix(dim, dim); Matrix T = new SymmetricMatrix(dim, dim); diff --git a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java index 052d6e3a33..a86a13e9c3 100644 --- a/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java +++ b/src/dr/inference/operators/hmc/ReflectiveHamiltonianMonteCarloOperator.java @@ -28,7 +28,6 @@ package dr.inference.operators.hmc; -import dr.app.bss.Utils; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.*; import dr.inference.operators.AdaptationMode; @@ -340,7 +339,7 @@ public boolean doReflection(double[] position, WrappedVector momentum) { if (DEBUG) { System.out.println("time: " + eventTime); System.out.print("start: "); - Utils.printArray(position); + System.out.println(new WrappedVector.Raw(position)); System.out.println(momentum); } @@ -348,7 +347,7 @@ public boolean doReflection(double[] position, WrappedVector momentum) { if (DEBUG) { System.out.print("end: "); - Utils.printArray(position); + System.out.println(new WrappedVector.Raw(position)); System.out.println(momentum); } diff --git a/src/dr/math/matrixAlgebra/WrappedVector.java b/src/dr/math/matrixAlgebra/WrappedVector.java index 1a4923ad0f..4655b8233e 100644 --- a/src/dr/math/matrixAlgebra/WrappedVector.java +++ b/src/dr/math/matrixAlgebra/WrappedVector.java @@ -86,6 +86,18 @@ public Raw(double[] buffer) { this(buffer, 0, buffer.length); } + public Raw(int[] in) { + this(convert(in), 0, in.length); + } + + private static double[] convert(int[] in) { + double[] buffer = new double[in.length]; + for (int i = 0; i < in.length; ++i) { + buffer[i] = in[i]; + } + return buffer; + } + @Override final public double get(final int i) { return buffer[offset + i];