From 9a489197b74a11c7b383f036200ea4a640359760 Mon Sep 17 00:00:00 2001 From: jtmccr1 Date: Sat, 23 Nov 2024 11:58:55 -0600 Subject: [PATCH] updating indexes in tests flagged by ci --- .../coalescent/CoalescentGradient.java | 12 +++---- .../GMRFMultilocusSkyrideLikelihood.java | 3 +- .../coalescent/UnifiedGMRFLikelihood.java | 3 +- .../evomodel/coalescent/hmc/GMRFGradient.java | 32 ++++++++++++------- .../coalescent/smooth/SkyGlideLikelihood.java | 6 ++-- 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/src/dr/evolution/coalescent/CoalescentGradient.java b/src/dr/evolution/coalescent/CoalescentGradient.java index 6784e3320e..088e258758 100644 --- a/src/dr/evolution/coalescent/CoalescentGradient.java +++ b/src/dr/evolution/coalescent/CoalescentGradient.java @@ -123,19 +123,19 @@ private double[] getGradientLogDensityWrtNodeHeights() { } IntervalList intervals = likelihood.getIntervalList(); - BigFastTreeIntervals bigFastTreeIntervals = (BigFastTreeIntervals) intervals; + TreeIntervalList bigFastTreeIntervals = (TreeIntervalList) intervals; // TODO should not be BFT specific DemographicFunction demographicFunction = likelihood.getDemoModel().getDemographicFunction(); int numSameHeightNodes = 1; double thisGradient = 0; - for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); i++) { - if (bigFastTreeIntervals.getIntervalType(i) == IntervalType.COALESCENT) { - final double time = bigFastTreeIntervals.getIntervalTime(i + 1); + for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); i++) { + if (bigFastTreeIntervals.getIntervalType(i) == IntervalType.COALESCENT) { + final double endTime = bigFastTreeIntervals.getEventTime(i + 1); final int lineageCount = bigFastTreeIntervals.getLineageCount(i); final double kChoose2 = Binomial.choose2(lineageCount); - final double intensityGradient = demographicFunction.getIntensityGradient(time); - thisGradient += demographicFunction.getLogDemographicGradient(time); + final double intensityGradient = demographicFunction.getIntensityGradient(endTime); + thisGradient += demographicFunction.getLogDemographicGradient(endTime); if (bigFastTreeIntervals.getInterval(i) != 0) { thisGradient -= kChoose2 * intensityGradient; diff --git a/src/dr/evomodel/coalescent/GMRFMultilocusSkyrideLikelihood.java b/src/dr/evomodel/coalescent/GMRFMultilocusSkyrideLikelihood.java index 1df250a3e6..0f3fdb6de9 100644 --- a/src/dr/evomodel/coalescent/GMRFMultilocusSkyrideLikelihood.java +++ b/src/dr/evomodel/coalescent/GMRFMultilocusSkyrideLikelihood.java @@ -25,6 +25,7 @@ package dr.evomodel.coalescent; +import dr.evolution.coalescent.IntervalEventList; import dr.evolution.coalescent.IntervalList; import dr.evolution.coalescent.IntervalType; import dr.evolution.coalescent.TreeIntervals; @@ -879,7 +880,7 @@ public Tree getTree(int nt) { return treeList.get(nt); } - public IntervalList getTreeIntervals(int nt) { + public IntervalEventList getTreeIntervals(int nt) { return intervalsList.get(nt); } diff --git a/src/dr/evomodel/coalescent/UnifiedGMRFLikelihood.java b/src/dr/evomodel/coalescent/UnifiedGMRFLikelihood.java index bb748352cc..a818bab7ae 100644 --- a/src/dr/evomodel/coalescent/UnifiedGMRFLikelihood.java +++ b/src/dr/evomodel/coalescent/UnifiedGMRFLikelihood.java @@ -2,6 +2,7 @@ import java.util.List; +import dr.evolution.coalescent.IntervalEventList; import dr.evolution.coalescent.IntervalList; import dr.evolution.coalescent.TreeIntervalList; import dr.evolution.tree.Tree; @@ -76,7 +77,7 @@ interface SkyGrid extends Skyride{ double[] getNumCoalEvents(); Tree getTree(int nt); - IntervalList getTreeIntervals(int nt); + IntervalEventList getTreeIntervals(int nt); double getPopulationFactor(int nt); diff --git a/src/dr/evomodel/coalescent/hmc/GMRFGradient.java b/src/dr/evomodel/coalescent/hmc/GMRFGradient.java index c3c2e7b3d2..b8a7a1338b 100644 --- a/src/dr/evomodel/coalescent/hmc/GMRFGradient.java +++ b/src/dr/evomodel/coalescent/hmc/GMRFGradient.java @@ -1,5 +1,6 @@ package dr.evomodel.coalescent.hmc; +import dr.evolution.coalescent.IntervalEventList; import dr.evolution.coalescent.IntervalList; import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; @@ -291,24 +292,31 @@ private double[] getGradientWrtNodeHeights(UnifiedGMRFLikelihood.SkyGrid likelih final IntervalList intervals = likelihood.getTreeIntervals(0); - int[] intervalIndices = new int[tree.getInternalNodeCount()]; + int[] eventIndices = new int[tree.getInternalNodeCount()]; int[] gridIndices = new int[tree.getInternalNodeCount()]; - getGridIndexForInternalNodes(likelihood, 0, intervalIndices, gridIndices); + // event index of a node is the interval it starts + // for this it should be the interval it ends. + getGridIndexForInternalNodes(likelihood, 0, eventIndices, gridIndices); for (int i = 0; i < tree.getInternalNodeCount(); i++) { NodeRef node = tree.getNode(i + tree.getExternalNodeCount()); final int nodeIndex = getNodeHeightParameterIndex(node, tree); - - final int numLineage = intervals.getLineageCount(intervalIndices[i]); + final int precedingIntervalIndex = eventIndices[i] - 1; + final int numIncomingLineages; + if(precedingIntervalIndex == -1) { + numIncomingLineages =0 ; + } else { + numIncomingLineages = intervals.getLineageCount(precedingIntervalIndex); + } final double currentPopSize = Math.exp(-currentGamma[gridIndices[nodeIndex]]); - gradient[nodeIndex] += -currentPopSize * numLineage * (numLineage - 1); + gradient[nodeIndex] += -currentPopSize * numIncomingLineages * (numIncomingLineages - 1); if (!tree.isRoot(node)) { - final int nextNumLineage = intervals.getLineageCount(intervalIndices[i] + 1); + final int nextNumLineage = intervals.getLineageCount(precedingIntervalIndex + 1); gradient[nodeIndex] -= -currentPopSize * nextNumLineage * (nextNumLineage - 1); } } @@ -326,7 +334,7 @@ private int getNodeHeightParameterIndex(NodeRef node, Tree tree) { } private void getGridIndexForInternalNodes(UnifiedGMRFLikelihood.SkyGrid likelihood, int treeIndex, - int[] intervalIndices, int[] gridIndices) { + int[] eventIndices, int[] gridIndices) { Tree tree = likelihood.getTree(treeIndex); double[] sortedValues = new double[tree.getInternalNodeCount()]; double[] nodeHeights = new double[tree.getInternalNodeCount()]; @@ -335,18 +343,18 @@ private void getGridIndexForInternalNodes(UnifiedGMRFLikelihood.SkyGrid likeliho int gridIndex = 0; double[] gridPoints = likelihood.getGridPoints(); - int intervalIndex = 0; - final IntervalList intervals = likelihood.getTreeIntervals(treeIndex); + int eventIndex = 0; + final IntervalEventList intervals = likelihood.getTreeIntervals(treeIndex); for (int i = 0; i < tree.getInternalNodeCount(); i++) { while(gridIndex < gridPoints.length && gridPoints[gridIndex] < sortedValues[i]) { gridIndex++; } gridIndices[nodeIndices[i]] = gridIndex; - while(intervalIndex < intervals.getIntervalCount() - 1 && intervals.getIntervalTime(intervalIndex) < sortedValues[i]) { - intervalIndex++; + while(eventIndex < intervals.getEventCount() - 1 && intervals.getEventTime(eventIndex) < sortedValues[i]) { + eventIndex++; } - intervalIndices[nodeIndices[i]] = intervalIndex; + eventIndices[nodeIndices[i]] = eventIndex; } } diff --git a/src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java b/src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java index 98362c5d91..4cef785fd6 100644 --- a/src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java +++ b/src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java @@ -275,7 +275,7 @@ void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood l for (int i = 0; i < interval.getIntervalCount(); i++) { if (interval.getIntervalType(i) == IntervalType.COALESCENT) { - final double time = interval.getIntervalTime(i + 1); + final double time = interval.getEventTime(i + 1); // end time of ith interval final int nodeIndex = interval.getNodeNumbersForInterval(i)[1]; currentGridIndex = likelihood.getGridIndex(time, currentGridIndex); final double slope = likelihood.getGridSlope(currentGridIndex); @@ -422,7 +422,7 @@ private double getSingleTreePopulationInverseLogLikelihood(int index) { for (int i = 0; i < interval.getIntervalCount(); i++) { if (interval.getIntervalType(i) == IntervalType.COALESCENT) { - final double time = interval.getIntervalTime(i + 1); + final double time = interval.getEventTime(i + 1); // end time of ith interval currentGridIndex = getGridIndex(time, currentGridIndex); lnL -= getLogPopulationSize(time, currentGridIndex); } @@ -437,7 +437,7 @@ private void updateSingleTreePopulationInverseGradientWrtLogPopSize(int index, d for (int i = 0; i < interval.getIntervalCount(); i++) { if (interval.getIntervalType(i) == IntervalType.COALESCENT) { - final double time = interval.getIntervalTime(i + 1); + final double time = interval.getEventTime(i + 1); // end time of ith interval currentGridIndex = getGridIndex(time, currentGridIndex); updateLogPopSizeDerivative(time, currentGridIndex, gradient); }