Skip to content

Commit

Permalink
first draft for Alex, need to think of how to include Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
xji3 committed Nov 15, 2024
1 parent ef8fbc5 commit c58707a
Showing 1 changed file with 134 additions and 18 deletions.
152 changes: 134 additions & 18 deletions src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,64 +27,136 @@

package dr.evomodel.treedatalikelihood;

import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.*;
import dr.inference.operators.hmc.NumericalHessianFromGradient;
import dr.math.MultivariateFunction;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import dr.xml.*;
import dr.math.NumericalDerivative;

import static dr.math.matrixAlgebra.ReadableVector.Utils.setParameter;

/**
* @author Alexander Fisher
*/

public class ApproximateTreeDataLikelihood {
public class ApproximateTreeDataLikelihood extends AbstractModelLikelihood {
private double marginalLikelihood;
private double[] parameterMAP;
private MaximizerWrtParameter maximizer;
private double[] numericalHessian;
private Parameter parameter;
private Likelihood likelihood;
private boolean likelihoodKnown = false;
private final HessianWrtParameterProvider hessianWrtParameterProvider;


// begin parser stuff
public static final String APPROXIMATE_LIKELIHOOD = "approximateTreeDataLikelihood";

// end parser stuff
public ApproximateTreeDataLikelihood(MaximizerWrtParameter maximizer) {
super(APPROXIMATE_LIKELIHOOD);

this.maximizer = maximizer;
this.likelihood = maximizer.getLikelihood();
this.parameter = maximizer.getGradient().getParameter();
this.numericalHessian = new double[parameter.getDimension()];
// todo: get Numerical Hessian.
// NumericalDerivative.getNumericalHessian();
this.marginalLikelihoodConst = Math.log(2) - parameter.getDimension() / 2 *Math.log(Math.PI);
// todo: get Numerical Hessian.
if (likelihood instanceof HessianWrtParameterProvider) {
this.hessianWrtParameterProvider = (HessianWrtParameterProvider) likelihood;
} else if (likelihood instanceof GradientWrtParameterProvider) {
this.hessianWrtParameterProvider = new NumericalHessianFromGradient((GradientWrtParameterProvider) likelihood);
} else {
this.hessianWrtParameterProvider = constructHessian();
}
updateParameterMAP();
updateMarginalLikelihood();
addVariable(parameter);
}

private HessianWrtParameterProvider constructHessian() {

final MultivariateFunction function = new MultivariateFunction() {
@Override
public double evaluate(double[] argument) {

setParameter(new WrappedVector.Raw(argument), parameter);
return getLogLikelihood();
}

@Override
public int getNumArguments() {
return parameter.getDimension();
}

@Override
public double getLowerBound(int n) {
return Double.NEGATIVE_INFINITY;
}

@Override
public double getUpperBound(int n) {
return Double.POSITIVE_INFINITY;
}
};

return new HessianWrtParameterProvider() {

@Override
public Likelihood getLikelihood() {
return likelihood;
}

@Override
public Parameter getParameter() {
return parameter;
}

@Override
public int getDimension() {
return parameter.getDimension();
}

@Override
public double[] getGradientLogDensity() {
return getGradientLogDensity();
}

@Override
public double[] getDiagonalHessianLogDensity() {
return NumericalDerivative.diagonalHessian(function, parameter.getParameterValues());
}

@Override
public double[][] getHessianLogDensity() {
return NumericalDerivative.getNumericalHessian(function, parameter.getParameterValues());
}
};
}

private void updateMarginalLikelihood() {
double[] diagonalHessian = hessianWrtParameterProvider.getDiagonalHessianLogDensity();
double diagonalDeterminant = 1;
for (int i = 0; i < parameter.getDimension(); i++) {
diagonalDeterminant *= numericalHessian[i];
diagonalDeterminant *= Math.abs(diagonalHessian[i]);
}
// 2pi^{-k/2} * det(Sigma)^{-1/2} * likelihood(map) * prior(map)
// todo: eval posterior(map)
// todo: log likelihood
this.marginalLikelihood = 2 / (Math.pow(Math.PI, -1 * parameter.getDimension() / 2) * Math.sqrt(diagonalDeterminant));
this.marginalLikelihood = marginalLikelihoodConst + 0.5 * Math.log(diagonalDeterminant) + likelihood.getLogLikelihood();
likelihoodKnown = true;
}

private void updateParameterMAP() {
this.parameterMAP = maximizer.getMinimumPoint(true);
}
private final double marginalLikelihoodConst;

public double getMarginalLikelihood() {
return marginalLikelihood;
private void updateParameterMAP() {
maximizer.maximize();
}

public double[] getParameterMAP() {
return parameterMAP;
}

// **************************************************************
// XMLObjectParser
Expand Down Expand Up @@ -127,4 +199,48 @@ public Class getReturnType() {
new ElementRule(MaximizerWrtParameter.class)
};
};

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) {

}

@Override
protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) {

}

@Override
protected void storeState() {

}

@Override
protected void restoreState() {

}

@Override
protected void acceptState() {

}

@Override
public Model getModel() {
return null;
}

@Override
public double getLogLikelihood() {
if (!likelihoodKnown) {
updateParameterMAP();
updateMarginalLikelihood();
}
return marginalLikelihood;
}

@Override
public void makeDirty() {
likelihoodKnown = false;
}
}

0 comments on commit c58707a

Please sign in to comment.