diff --git a/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java b/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java index c15a8d1fd6..6d71b4adab 100644 --- a/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java +++ b/src/dr/evomodel/treedatalikelihood/ApproximateTreeDataLikelihood.java @@ -27,24 +27,31 @@ package dr.evomodel.treedatalikelihood; +import dr.evolution.tree.Tree; import dr.evomodel.treedatalikelihood.discrete.MaximizerWrtParameter; import dr.inference.hmc.GradientWrtParameterProvider; -import dr.inference.model.Likelihood; -import dr.inference.model.Parameter; +import dr.inference.hmc.HessianWrtParameterProvider; +import dr.inference.model.*; +import dr.inference.operators.hmc.NumericalHessianFromGradient; +import dr.math.MultivariateFunction; +import dr.math.matrixAlgebra.WrappedVector; import dr.util.Transform; import dr.xml.*; import dr.math.NumericalDerivative; +import static dr.math.matrixAlgebra.ReadableVector.Utils.setParameter; + /** * @author Alexander Fisher */ -public class ApproximateTreeDataLikelihood { +public class ApproximateTreeDataLikelihood extends AbstractModelLikelihood { private double marginalLikelihood; - private double[] parameterMAP; private MaximizerWrtParameter maximizer; - private double[] numericalHessian; private Parameter parameter; + private Likelihood likelihood; + private boolean likelihoodKnown = false; + private final HessianWrtParameterProvider hessianWrtParameterProvider; // begin parser stuff @@ -52,39 +59,104 @@ public class ApproximateTreeDataLikelihood { // end parser stuff public ApproximateTreeDataLikelihood(MaximizerWrtParameter maximizer) { + super(APPROXIMATE_LIKELIHOOD); this.maximizer = maximizer; + this.likelihood = maximizer.getLikelihood(); this.parameter = maximizer.getGradient().getParameter(); - this.numericalHessian = new double[parameter.getDimension()]; - // todo: get Numerical Hessian. -// NumericalDerivative.getNumericalHessian(); + this.marginalLikelihoodConst = Math.log(2) - parameter.getDimension() / 2 *Math.log(Math.PI); + // todo: get Numerical Hessian. + if (likelihood instanceof HessianWrtParameterProvider) { + this.hessianWrtParameterProvider = (HessianWrtParameterProvider) likelihood; + } else if (likelihood instanceof GradientWrtParameterProvider) { + this.hessianWrtParameterProvider = new NumericalHessianFromGradient((GradientWrtParameterProvider) likelihood); + } else { + this.hessianWrtParameterProvider = constructHessian(); + } updateParameterMAP(); updateMarginalLikelihood(); + addVariable(parameter); + } + + private HessianWrtParameterProvider constructHessian() { + + final MultivariateFunction function = new MultivariateFunction() { + @Override + public double evaluate(double[] argument) { + + setParameter(new WrappedVector.Raw(argument), parameter); + return getLogLikelihood(); + } + + @Override + public int getNumArguments() { + return parameter.getDimension(); + } + + @Override + public double getLowerBound(int n) { + return Double.NEGATIVE_INFINITY; + } + + @Override + public double getUpperBound(int n) { + return Double.POSITIVE_INFINITY; + } + }; + return new HessianWrtParameterProvider() { + + @Override + public Likelihood getLikelihood() { + return likelihood; + } + + @Override + public Parameter getParameter() { + return parameter; + } + + @Override + public int getDimension() { + return parameter.getDimension(); + } + + @Override + public double[] getGradientLogDensity() { + return getGradientLogDensity(); + } + + @Override + public double[] getDiagonalHessianLogDensity() { + return NumericalDerivative.diagonalHessian(function, parameter.getParameterValues()); + } + + @Override + public double[][] getHessianLogDensity() { + return NumericalDerivative.getNumericalHessian(function, parameter.getParameterValues()); + } + }; } private void updateMarginalLikelihood() { + double[] diagonalHessian = hessianWrtParameterProvider.getDiagonalHessianLogDensity(); double diagonalDeterminant = 1; for (int i = 0; i < parameter.getDimension(); i++) { - diagonalDeterminant *= numericalHessian[i]; + diagonalDeterminant *= Math.abs(diagonalHessian[i]); } // 2pi^{-k/2} * det(Sigma)^{-1/2} * likelihood(map) * prior(map) // todo: eval posterior(map) // todo: log likelihood - this.marginalLikelihood = 2 / (Math.pow(Math.PI, -1 * parameter.getDimension() / 2) * Math.sqrt(diagonalDeterminant)); + this.marginalLikelihood = marginalLikelihoodConst + 0.5 * Math.log(diagonalDeterminant) + likelihood.getLogLikelihood(); + likelihoodKnown = true; } - private void updateParameterMAP() { - this.parameterMAP = maximizer.getMinimumPoint(true); - } + private final double marginalLikelihoodConst; - public double getMarginalLikelihood() { - return marginalLikelihood; + private void updateParameterMAP() { + maximizer.maximize(); } - public double[] getParameterMAP() { - return parameterMAP; - } // ************************************************************** // XMLObjectParser @@ -127,4 +199,48 @@ public Class getReturnType() { new ElementRule(MaximizerWrtParameter.class) }; }; + + @Override + protected void handleModelChangedEvent(Model model, Object object, int index) { + + } + + @Override + protected void handleVariableChangedEvent(Variable variable, int index, Parameter.ChangeType type) { + + } + + @Override + protected void storeState() { + + } + + @Override + protected void restoreState() { + + } + + @Override + protected void acceptState() { + + } + + @Override + public Model getModel() { + return null; + } + + @Override + public double getLogLikelihood() { + if (!likelihoodKnown) { + updateParameterMAP(); + updateMarginalLikelihood(); + } + return marginalLikelihood; + } + + @Override + public void makeDirty() { + likelihoodKnown = false; + } } \ No newline at end of file