Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AD testing utilities #799

Open
wants to merge 2 commits into
base: release-0.35
Choose a base branch
from
Open

Add AD testing utilities #799

wants to merge 2 commits into from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Feb 5, 2025

Overview

This is a, perhaps somewhat overdue, PR to add the functionality which I first wrote in https://github.com/penelopeysm/ModelTests.jl.

It provides two main functions:

  • DynamicPPL.TestUtils.AD.ad_ldp(::Model, ::Vector{<:Real}, ::AbstractADType, ::AbstractVarInfo)
  • and DynamicPPL.TestUtils.AD.ad_di (same signature)

which calculate the logdensity and its gradient of a given model at the specified parameters.

The former uses LogDensityProblemsAD.jl; the latter circumvents this and goes straight to DifferentiationInterface.jl. (The varinfo argument is used only to specify the type of varinfo used during the evaluation, its contents are ignored. I wish that there was a cleaner way to specify this, but as far as I can tell it's not possible, especially with SimpleVarInfo which often requires parameters to be initialised inside it.)

There are three auxiliary functions:

  • DynamicPPL.TestUtils.AD.make_function and DynamicPPL.TestUtils.AD.make_params generate a function f and an argument x, such that f(x) evaluates the logdensity of a model at the point x. These can, in theory, be passed to any autodiff library, even those which do not have integrations with LogDensityProblemsAD, DifferentiationInterface, or ADTypes.
  • DynamicPPL.TestUtils.AD.test_correctness provides a quick and easy wrapper to test a model plus a given set of AD backends (using the default VarInfo) for correctness.

Testing

Unfortunately, I didn't manage to make much use of test_correctness in the current DynamicPPL test suite. The main reason is because we are testing all the demo models with pretty much all possible variations of VarInfo.

I have made sure to not change the tests, but I'm not entirely convinced that we need to test AD with different combinations of VarInfo. The reason is because AD is used primarily during sampling, and there isn't really any way to actually call AbstractMCMC.sample on a model (cf. #606) with anything but the default VarInfo.

The use of non-default varinfos is, as far as I can tell, restricted to fairly small sections of the codebase (e.g. the loglikelihood / logjoint / logprior functions), and it's not clear to me that AD is used in any part of that. So, it seems to me that these are orthogonal concerns.

I've left it versatile for now to be on the safe side, but if people agree then I would be very happy to remove the varinfo argument from the functions above.

Miscellaneous bits

The names of the functions can be changed, I'm not super happy with them, but also I've stared at this code for too long so I'm not the best person to suggest names 😉

@penelopeysm penelopeysm changed the base branch from master to release-0.35 February 5, 2025 16:02
Copy link

codecov bot commented Feb 5, 2025

Codecov Report

Attention: Patch coverage is 14.28571% with 18 lines in your changes missing coverage. Please review.

Project coverage is 85.43%. Comparing base (1366440) to head (32ee4bb).

Files with missing lines Patch % Lines
src/test_utils/ad.jl 14.28% 18 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff                @@
##           release-0.35     #799      +/-   ##
================================================
- Coverage         85.78%   85.43%   -0.36%     
================================================
  Files                36       37       +1     
  Lines              4207     4228      +21     
================================================
+ Hits               3609     3612       +3     
- Misses              598      616      +18     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coveralls
Copy link

coveralls commented Feb 5, 2025

Pull Request Test Coverage Report for Build 13162839206

Details

  • 0 of 21 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.4%) to 85.511%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/test_utils/ad.jl 0 21 0.0%
Totals Coverage Status
Change from base Build 13156283797: -0.4%
Covered Lines: 3612
Relevant Lines: 4224

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants