Skip to content

Commit

Permalink
updating indexes in tests flagged by ci
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Nov 23, 2024
1 parent 07cb1fb commit 9a48919
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 23 deletions.
12 changes: 6 additions & 6 deletions src/dr/evolution/coalescent/CoalescentGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ private double[] getGradientLogDensityWrtNodeHeights() {
}

IntervalList intervals = likelihood.getIntervalList();
BigFastTreeIntervals bigFastTreeIntervals = (BigFastTreeIntervals) intervals;
TreeIntervalList bigFastTreeIntervals = (TreeIntervalList) intervals; // TODO should not be BFT specific

DemographicFunction demographicFunction = likelihood.getDemoModel().getDemographicFunction();

int numSameHeightNodes = 1;
double thisGradient = 0;
for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); i++) {
if (bigFastTreeIntervals.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = bigFastTreeIntervals.getIntervalTime(i + 1);
for (int i = 0; i < bigFastTreeIntervals.getIntervalCount(); i++) {
if (bigFastTreeIntervals.getIntervalType(i) == IntervalType.COALESCENT) {
final double endTime = bigFastTreeIntervals.getEventTime(i + 1);
final int lineageCount = bigFastTreeIntervals.getLineageCount(i);
final double kChoose2 = Binomial.choose2(lineageCount);
final double intensityGradient = demographicFunction.getIntensityGradient(time);
thisGradient += demographicFunction.getLogDemographicGradient(time);
final double intensityGradient = demographicFunction.getIntensityGradient(endTime);
thisGradient += demographicFunction.getLogDemographicGradient(endTime);

if (bigFastTreeIntervals.getInterval(i) != 0) {
thisGradient -= kChoose2 * intensityGradient;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

package dr.evomodel.coalescent;

import dr.evolution.coalescent.IntervalEventList;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervals;
Expand Down Expand Up @@ -879,7 +880,7 @@ public Tree getTree(int nt) {
return treeList.get(nt);
}

public IntervalList getTreeIntervals(int nt) {
public IntervalEventList getTreeIntervals(int nt) {
return intervalsList.get(nt);
}

Expand Down
3 changes: 2 additions & 1 deletion src/dr/evomodel/coalescent/UnifiedGMRFLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.util.List;

import dr.evolution.coalescent.IntervalEventList;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.TreeIntervalList;
import dr.evolution.tree.Tree;
Expand Down Expand Up @@ -76,7 +77,7 @@ interface SkyGrid extends Skyride{
double[] getNumCoalEvents();
Tree getTree(int nt);

IntervalList getTreeIntervals(int nt);
IntervalEventList getTreeIntervals(int nt);

double getPopulationFactor(int nt);

Expand Down
32 changes: 20 additions & 12 deletions src/dr/evomodel/coalescent/hmc/GMRFGradient.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package dr.evomodel.coalescent.hmc;

import dr.evolution.coalescent.IntervalEventList;
import dr.evolution.coalescent.IntervalList;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
Expand Down Expand Up @@ -291,24 +292,31 @@ private double[] getGradientWrtNodeHeights(UnifiedGMRFLikelihood.SkyGrid likelih

final IntervalList intervals = likelihood.getTreeIntervals(0);

int[] intervalIndices = new int[tree.getInternalNodeCount()];
int[] eventIndices = new int[tree.getInternalNodeCount()];
int[] gridIndices = new int[tree.getInternalNodeCount()];

getGridIndexForInternalNodes(likelihood, 0, intervalIndices, gridIndices);
// event index of a node is the interval it starts
// for this it should be the interval it ends.
getGridIndexForInternalNodes(likelihood, 0, eventIndices, gridIndices);

for (int i = 0; i < tree.getInternalNodeCount(); i++) {
NodeRef node = tree.getNode(i + tree.getExternalNodeCount());

final int nodeIndex = getNodeHeightParameterIndex(node, tree);

final int numLineage = intervals.getLineageCount(intervalIndices[i]);
final int precedingIntervalIndex = eventIndices[i] - 1;
final int numIncomingLineages;
if(precedingIntervalIndex == -1) {
numIncomingLineages =0 ;
} else {
numIncomingLineages = intervals.getLineageCount(precedingIntervalIndex);
}

final double currentPopSize = Math.exp(-currentGamma[gridIndices[nodeIndex]]);

gradient[nodeIndex] += -currentPopSize * numLineage * (numLineage - 1);
gradient[nodeIndex] += -currentPopSize * numIncomingLineages * (numIncomingLineages - 1);

if (!tree.isRoot(node)) {
final int nextNumLineage = intervals.getLineageCount(intervalIndices[i] + 1);
final int nextNumLineage = intervals.getLineageCount(precedingIntervalIndex + 1);
gradient[nodeIndex] -= -currentPopSize * nextNumLineage * (nextNumLineage - 1);
}
}
Expand All @@ -326,7 +334,7 @@ private int getNodeHeightParameterIndex(NodeRef node, Tree tree) {
}

private void getGridIndexForInternalNodes(UnifiedGMRFLikelihood.SkyGrid likelihood, int treeIndex,
int[] intervalIndices, int[] gridIndices) {
int[] eventIndices, int[] gridIndices) {
Tree tree = likelihood.getTree(treeIndex);
double[] sortedValues = new double[tree.getInternalNodeCount()];
double[] nodeHeights = new double[tree.getInternalNodeCount()];
Expand All @@ -335,18 +343,18 @@ private void getGridIndexForInternalNodes(UnifiedGMRFLikelihood.SkyGrid likeliho

int gridIndex = 0;
double[] gridPoints = likelihood.getGridPoints();
int intervalIndex = 0;
final IntervalList intervals = likelihood.getTreeIntervals(treeIndex);
int eventIndex = 0;
final IntervalEventList intervals = likelihood.getTreeIntervals(treeIndex);
for (int i = 0; i < tree.getInternalNodeCount(); i++) {
while(gridIndex < gridPoints.length && gridPoints[gridIndex] < sortedValues[i]) {
gridIndex++;
}
gridIndices[nodeIndices[i]] = gridIndex;

while(intervalIndex < intervals.getIntervalCount() - 1 && intervals.getIntervalTime(intervalIndex) < sortedValues[i]) {
intervalIndex++;
while(eventIndex < intervals.getEventCount() - 1 && intervals.getEventTime(eventIndex) < sortedValues[i]) {
eventIndex++;
}
intervalIndices[nodeIndices[i]] = intervalIndex;
eventIndices[nodeIndices[i]] = eventIndex;
}
}

Expand Down
6 changes: 3 additions & 3 deletions src/dr/evomodel/coalescent/smooth/SkyGlideLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ void updateSingleTreePopulationInverseGradientWrtNodeHeight(SkyGlideLikelihood l

for (int i = 0; i < interval.getIntervalCount(); i++) {
if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = interval.getIntervalTime(i + 1);
final double time = interval.getEventTime(i + 1); // end time of ith interval
final int nodeIndex = interval.getNodeNumbersForInterval(i)[1];
currentGridIndex = likelihood.getGridIndex(time, currentGridIndex);
final double slope = likelihood.getGridSlope(currentGridIndex);
Expand Down Expand Up @@ -422,7 +422,7 @@ private double getSingleTreePopulationInverseLogLikelihood(int index) {

for (int i = 0; i < interval.getIntervalCount(); i++) {
if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = interval.getIntervalTime(i + 1);
final double time = interval.getEventTime(i + 1); // end time of ith interval
currentGridIndex = getGridIndex(time, currentGridIndex);
lnL -= getLogPopulationSize(time, currentGridIndex);
}
Expand All @@ -437,7 +437,7 @@ private void updateSingleTreePopulationInverseGradientWrtLogPopSize(int index, d

for (int i = 0; i < interval.getIntervalCount(); i++) {
if (interval.getIntervalType(i) == IntervalType.COALESCENT) {
final double time = interval.getIntervalTime(i + 1);
final double time = interval.getEventTime(i + 1); // end time of ith interval
currentGridIndex = getGridIndex(time, currentGridIndex);
updateLogPopSizeDerivative(time, currentGridIndex, gradient);
}
Expand Down

0 comments on commit 9a48919

Please sign in to comment.