
Add a logabsdet_jacobian method #3

Open
torfjelde opened this issue Nov 8, 2021 · 19 comments

@torfjelde

Should we also have a method logabsdet_jacobian which is equivalent to Base.Fix2(getindex, 2) ∘ with_logabsdet_jacobian, i.e. it only computes the logabsdet-jacobian term?

@oschulz
Collaborator

oschulz commented Nov 8, 2021

You mean

logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2]

as a convenience? Hm, I'm not so sure - writing with_logabsdet_jacobian(f, x)[2] is not much longer, and we have lots of functions of the style "return primary value plus something additional" in the ecosystem, so people are used to it.

Is there a use case beyond convenience?

@torfjelde
Author

torfjelde commented Nov 8, 2021

It creates overhead when broadcasting:

julia> f(x) = (2x, 3x)
f (generic function with 1 method)

julia> f1(x) = 2x
f1 (generic function with 1 method)

julia> f2(x) = 3x
f2 (generic function with 1 method)

julia> xs = randn(100_000);

julia> @benchmark $f1.($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   70.233 μs …   1.699 ms  ┊ GC (min … max):  0.00% … 79.00%
 Time  (median):      97.174 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   117.961 μs ± 108.078 μs  ┊ GC (mean ± σ):  10.99% ± 10.97%

  ▄██▆▅▃▂▂▁▁▁                                                   ▂
  ██████████████▇▇▆▅▄▃▄▃▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▅▆▇▆▆▆▆ █
  70.2 μs       Histogram: log(frequency) by time        812 μs <

 Memory estimate: 781.33 KiB, allocs estimate: 2.

julia> @benchmark $f2.($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   63.100 μs …   1.900 ms  ┊ GC (min … max):  0.00% … 64.76%
 Time  (median):      96.844 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   117.607 μs ± 108.086 μs  ┊ GC (mean ± σ):  11.08% ± 11.01%

   ▆█▆▅▄▃▂▁▁▁                                                   ▂
  ▇████████████▇▆▆▆▅▃▃▅▄▁▁▁▁▁▁▁▁▁▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▆▆▆▆▅▆▆ █
  63.1 μs       Histogram: log(frequency) by time        821 μs <

 Memory estimate: 781.33 KiB, allocs estimate: 2.

julia> @benchmark $(Base.Fix2(getindex, 1) ∘ f).($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   86.023 μs …   1.877 ms  ┊ GC (min … max): 0.00% … 81.82%
 Time  (median):     128.063 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   155.207 μs ± 121.756 μs  ┊ GC (mean ± σ):  9.08% ± 10.73%

   ▄█▄                                                           
  ▃████▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂ ▃
  86 μs            Histogram: frequency by time          889 μs <

 Memory estimate: 781.33 KiB, allocs estimate: 2.

julia> @benchmark $(Base.Fix2(getindex, 2) ∘ f).($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):   86.414 μs …   1.591 ms  ┊ GC (min … max): 0.00% … 84.73%
 Time  (median):     120.063 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   142.219 μs ± 110.132 μs  ┊ GC (mean ± σ):  9.23% ± 10.71%

  ▃▇█▇▅▄▃▃▂▂▁                                                   ▂
  █████████████▇▇█▆▆▅▅▅▃▃▁▄▁▄▁▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▄▆▆▇▇▆▇ █
  86.4 μs       Histogram: log(frequency) by time        838 μs <

 Memory estimate: 781.33 KiB, allocs estimate: 2.

This might also be worse when combined with AD, etc.

It came to mind because in Bijectors.jl we also want to define multivariate versions of many of the transformations, e.g. an elementwise log, in which case broadcasting is convenient and thus a Tuple as a result is not 😕

EDIT: Also, in general, you might want to just compute the logabsdet-jacobian term, in which case you'll end up doing unnecessary computation in with_logabsdet_jacobian. This is not much of an issue for log, etc., but it is for more involved transformations.
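
To illustrate that last point, here is a purely hypothetical sketch (AffineMap and the logabsdet_jacobian name are made up, not existing API): for an affine map y = A * x the ladj does not depend on x at all, so an ladj-only method could skip the matrix-vector product that last(with_logabsdet_jacobian(f, x)) would pay for.

using LinearAlgebra
import ChangesOfVariables: with_logabsdet_jacobian

struct AffineMap{M<:AbstractMatrix}  # hypothetical transformation type
    A::M
end

(f::AffineMap)(x) = f.A * x

# Forward value and ladj together: pays for the matrix-vector product.
with_logabsdet_jacobian(f::AffineMap, x) = (f(x), first(logabsdet(f.A)))

# A hypothetical ladj-only method would not need to touch x at all.
logabsdet_jacobian(f::AffineMap, x) = first(logabsdet(f.A))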

@oschulz
Collaborator

oschulz commented Nov 8, 2021

Also, in general, you might want to just compute the logabsdet-jacobian term

I would have thought that scenario to be very uncommon: needing only the ladj of a trafo but not the result. But if there are use cases, it would make sense to support it directly. Would you have an example or two (we could then also add them to the docs, maybe)?

@oschulz
Collaborator

oschulz commented Nov 8, 2021

Regarding broadcasting overhead and autodiff, that's an interesting question. We do have broadcasting support in ChangesOfVariables, but we haven't really benchmarked it so far. Let's use a broadcasted log as a test case, with a simple "loss function" (just dot):

using ChangesOfVariables, LinearAlgebra, Zygote, BenchmarkTools

logabsdet_jacobian(f, x) = with_logabsdet_jacobian(f, x)[2]


function foo(xs)
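    # ys and the per-element ladj terms are computed in two separate broadcasts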
    ys = log.(xs)
    ladj = sum(logabsdet_jacobian.(log, xs))
    dot(ys, ys) + ladj
end

grad_foo(xs) = Zygote.gradient(foo, xs)


function bar(xs)
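    # ys and the summed ladj come from a single with_logabsdet_jacobian call on the broadcasted log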
    ys, ladj = with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs)
    dot(ys, ys) + ladj
end

grad_bar(xs) = Zygote.gradient(bar, xs)

Benchmarking-wise, I get

julia> xs = rand(10^3);

julia> @benchmark foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  11.530 μs …  1.776 ms  ┊ GC (min … max): 0.00% … 97.74%
 Time  (median):     13.678 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.012 μs ± 17.694 μs  ┊ GC (mean ± σ):  1.24% ±  0.98%

   ▁         ▂▂ ▁▃▂▄▅▆▄██▅▄█▇▃▂▂▂▁                            ▂
  ██▇▇▆▄▅▅▄▁▆██████████████████████▆▁▃▁▁▁▃▁▁▄▄▄▅▃▅▃▅▅▃▅▃▄▃▃▁▃ █
  11.5 μs      Histogram: log(frequency) by time      17.2 μs <

 Memory estimate: 15.88 KiB, allocs estimate: 2.

julia> @benchmark bar($xs)
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
 Range (min … max):  7.894 μs … 393.329 μs  ┊ GC (min … max): 0.00% … 95.37%
 Time  (median):     8.907 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.813 μs ±  16.138 μs  ┊ GC (mean ± σ):  8.48% ±  5.03%

           ▁▃▆██▆▄                                             
  ▁▁▁▂▂▃▄▆▇████████▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  7.89 μs         Histogram: frequency by time        12.6 μs <

 Memory estimate: 31.62 KiB, allocs estimate: 3.

julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  32.637 μs …   2.148 ms  ┊ GC (min … max):  0.00% … 95.02%
 Time  (median):     37.748 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   44.188 μs ± 102.325 μs  ┊ GC (mean ± σ):  12.09% ±  5.11%

      ▂▄▆██▇▇▆▄▄▃▂▁▁▁▁           ▁                             ▂
  ▄▄▅▇██████████████████▇▆▇▇▆▆▄▆▇████▆▆▄▃▄▄▆▄▄▅▄▃▃▃▃▄▁▁▃▄▁▄▁▁▄ █
  32.6 μs       Histogram: log(frequency) by time      67.3 μs <

 Memory estimate: 129.48 KiB, allocs estimate: 83.

julia> @benchmark grad_bar($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  41.127 μs …   2.765 ms  ┊ GC (min … max):  0.00% … 95.62%
 Time  (median):     47.307 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   61.313 μs ± 140.068 μs  ┊ GC (mean ± σ):  16.32% ±  6.97%

   ▃▆██▆▄▄▃▃▂▂▂▂▁▁                                             ▂
  ▇█████████████████▇▇▇▇▇▆▇▆▆▆▇▆▅▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▄▅▆▆▅▄▅▄▅▃▄▄▂▅ █
  41.1 μs       Histogram: log(frequency) by time       126 μs <

 Memory estimate: 248.20 KiB, allocs estimate: 99.

So using with_logabsdet_jacobian(Base.Fix1(broadcast, log), xs) is significantly faster than calculating ys and ladj separately here. Not surprising in this case, since calculating both at the same time is very efficient for log. With Zygote, on the other hand, the path via logabsdet_jacobian is a bit faster here. Hard to draw general conclusions for other trafos from this, of course.

@oschulz
Collaborator

oschulz commented Nov 8, 2021

Turns out that with a custom pullback for the internal ChangesOfVariables._with_ladj_on_mapped

using ChainRulesCore

function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
    ys, ladj = ChainRulesCore.unthunk(thunked_ΔΩ)
    NoTangent(), NoTangent(), broadcast(x -> (x, ladj), ys)
end

function ChainRulesCore.rrule(::typeof(ChangesOfVariables._with_ladj_on_mapped), map_or_bc::Function, y_with_ladj)
    return ChangesOfVariables._with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
end

we can make AD on with_logabsdet_jacobian(Base.Fix1(broadcast, f), x) significantly faster:

julia> @benchmark foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  11.493 μs …  2.092 ms  ┊ GC (min … max): 0.00% … 98.32%
 Time  (median):     13.088 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   14.712 μs ± 21.011 μs  ┊ GC (mean ± σ):  1.40% ±  0.98%

  ▅▂▃▃▅▇█▃▂     ▂▄▅▆▅▅▅▄▃▂                                    ▂
  █████████▅▄▃▄████████████▇▅▆▆▅▆▆▇▆████▇█▇▇▅▅▅▃▃▃▁▃▃▃▃▁▁▁▅▃▅ █
  11.5 μs      Histogram: log(frequency) by time      25.9 μs <

 Memory estimate: 15.88 KiB, allocs estimate: 2.

julia> @benchmark bar($xs)
BenchmarkTools.Trial: 10000 samples with 4 evaluations.
 Range (min … max):  7.452 μs … 449.738 μs  ┊ GC (min … max): 0.00% … 95.81%
 Time  (median):     8.578 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   9.610 μs ±  15.170 μs  ┊ GC (mean ± σ):  8.11% ±  5.05%

   ▂▅▆▇▇██▇▆▄▂▁                          ▁                    ▂
  ▆█████████████▇▇▅▅▄▅▅▁▄▄▆▆▇█▇▇▇▅▅▅▅▃▃▇███▇▆▆████▆▃▅▅▁▅▆▃▅▄▅ █
  7.45 μs      Histogram: log(frequency) by time      17.1 μs <

 Memory estimate: 31.62 KiB, allocs estimate: 3.

julia> @benchmark grad_foo($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  29.964 μs …   2.371 ms  ┊ GC (min … max):  0.00% … 93.51%
 Time  (median):     37.055 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   43.938 μs ± 103.986 μs  ┊ GC (mean ± σ):  12.38% ±  5.15%

      ▂▄▇███▇▆▄▄▃▃▂▁▁       ▁                                  ▂
  ▄▁▁▆█████████████████▇▇▇████▇█▇▇██▇▆▆▇▆▇▇▆▅▃▃▅▃▆▆▅▄▅▆▃▅▆▄▆▅▆ █
  30 μs         Histogram: log(frequency) by time      77.2 μs <

 Memory estimate: 129.48 KiB, allocs estimate: 83.

julia> @benchmark grad_bar($xs)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  22.926 μs …  2.078 ms  ┊ GC (min … max):  0.00% … 97.35%
 Time  (median):     27.223 μs              ┊ GC (median):     0.00%
 Time  (mean ± σ):   33.296 μs ± 90.284 μs  ┊ GC (mean ± σ):  14.57% ±  5.30%

    ▃▆██▇▅▄▂▂▂▁                                               ▂
  ▃▇█████████████▇▆▇▇▄▆▇██▇▆▆▆▅▄▅▄▄▃▃▄▄▄▃▂▄▄▄▄▅▅▅▃▄▄▃▄▄▄▃▃▄▃▅ █
  22.9 μs      Histogram: log(frequency) by time      68.8 μs <

 Memory estimate: 136.44 KiB, allocs estimate: 50.

@devmotion, this would have quite an impact (almost a factor-of-two speedup in my simple example), but it would mean adding ChainRulesCore as a dependency. What do you think? It would make ChangesOfVariables itself less lightweight - but on the other hand, maybe pretty much every package that we'd hope would support ChangesOfVariables already depends on ChainRulesCore anyway (directly or indirectly)?

@devmotion
Member

It came to mind because in Bijectors.jl we also want to define multivariate versions of the many transformations, e.g. elementwise log, in which case broadcasting is convenient and thus Tuple as a result is not

It can be inconvenient to work with tuples, but in fact it is very common in the AD/ChainRules setting, and also e.g. in Functors or ParameterHandling. It is common, for instance, to just collect the outputs separately from the resulting array of tuples with map(first, ...) and map(last, ...) (one example is the rrule for sum(f, xs): https://github.com/JuliaDiff/ChainRules.jl/blob/edf3a1f48fb5c9af01820aeca6ced94d4f97fa1a/src/rulesets/Base/mapreduce.jl#L66). BTW I think generally one should prefer the more idiomatic first and last here, and not use getindex.
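
For instance, a minimal sketch of that pattern (relying on with_logabsdet_jacobian(log, x) being defined for scalars, as in the benchmarks above):

using ChangesOfVariables

xs = rand(5)
pairs = with_logabsdet_jacobian.(log, xs)  # Vector of (y, ladj) tuples
ys    = map(first, pairs)                  # transformed values
ladjs = map(last, pairs)                   # per-element ladj terms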

I am not sure the benchmark example is completely representative. Clearly, if the ladj and the output can be computed separately, one can save computations by only computing the ladj if the output is irrelevant (it seems, though, that even in the simple 2x, 3x example it does not reduce the computation time by 50%, as one might expect). However, I wonder how common it is that one is only interested in the ladj and not the transformed values as well. And how independent the computations of both components have to be for it to be worth computing them separately, given the additional code complexity and having to maintain a separate implementation.

Maybe an API such as the experimental derivatives_given_output (https://github.com/JuliaDiff/ChainRulesCore.jl/blob/99d56b145bb4829931c542e720a015d938efeee4/src/rule_definition_tools.jl#L158) could help to split up the computation of the transformed value and the ladj. But, of course, one would still have to compute the transformed value. And in cases where intermediate results in the computation of the transformed value could be reused in the computation of the ladj, such a separation would result in a less efficient implementation. Thus I don't think it should be part of the API, which in turn makes it difficult to use it in downstream packages (one could add a fallback to last(with_log..., neglecting the transformed value, but this seems even less efficient than the current state since one might compute the transformed value twice).

If possible, an alternative for mapping/broadcasting could be to use StructArrays.
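
A rough, untested sketch of that alternative (assuming the StructArrays.components accessor):

using ChangesOfVariables, StructArrays

xs = rand(5)
sa = StructArray(with_logabsdet_jacobian.(log, xs))  # stores the ys and ladjs as two contiguous arrays
ys, ladjs = StructArrays.components(sa)              # component arrays instead of per-element tuples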

@oschulz
Collaborator

oschulz commented Nov 8, 2021

I have to admit I'm also not convinced that there will be many use cases that don't require the result of the transformation at all.

@devmotion, how would you feel about a custom pullback for the internal _with_ladj_on_mapped though? The gain there is real and substantial, and shouldn't depend on the trafo (we simply save a lot of useless AD work). I tried to spin this function several ways, but didn't find a variant that could come even close to a custom pullback. There's just a limit to how smart AD can be, I guess, at least the one we have. :-)

@torfjelde
Author

torfjelde commented Nov 8, 2021

I have to admit I'm also not convinced that there will be many use cases that don't require the result of the transformation at all.

The logabsdet_jacobian implementation is usually different for a function f and its inverse; sometimes one direction is more efficient than the other, sometimes it's more numerically stable, etc., so it can be ambiguous which way you want to do things.

Sometimes the API is just in such a way that you don't need the transformed variable, e.g.:

https://github.com/TuringLang/Bijectors.jl/blob/31b1c387ed6a243e02fd906cb615b61de47b935f/src/Bijectors.jl#L152-L158

And as mentioned, sometimes you don't actually need anything from the "forward" evaluation, i.e. it's a completely separate computation.

Even if we don't encourage people to implement this, IMO we should at least have a default implementation, i.e. just logabsdet_jacobian(f, x) = last(with_logabsdet_jacobian(f, x)).

Or let me put it like this: why shouldn't we have a logabsdet_jacobian?:) Even if the use-cases aren't "many", I think if there are a few, and adding it has zero cost, we should just add it, no?

@oschulz
Collaborator

oschulz commented Nov 8, 2021

Let's discuss the custom rrule cost/benefit (not directly related to the question of adding logabsdet_jacobian) here: #4

@oschulz
Collaborator

oschulz commented Nov 8, 2021

The logabsdet_jacobian implementation is usually different for a function f and its inverse, sometimes one direction is more efficient than the other, sometimes it's more numerically stable

Oh, sure! But the implementation of with_logabsdet_jacobian always has both x and y available, so the idea would be that whoever codes it up chooses the optimal ladj calculation (in terms of speed and stability). I expect that quite often that will mean that both the function and the inverse use similar ladj code. But in at least one direction, that code will often be able to profit from (intermediate) results of the primary calculation - and sometimes in both directions, I expect.

But you do have a point - there are cases where the ladj calculation does not share any code with the function and its inverse, and ladj is easier to calculate in one direction than the other. So there, people would anyway write something like

_logabsdet_jacobian_only(::typeof(myfunc), x) = ...

with_logabsdet_jacobian(::typeof(myfunc), x) = myfunc(x), _logabsdet_jacobian_only(myfunc, x)

function with_logabsdet_jacobian(::typeof(inv_myfunc), y)
    x = inv_myfunc(y)
    return x, - _logabsdet_jacobian_only(myfunc, x)
end

I know I have code like that in some places. Might as well give _logabsdet_jacobian_only an official name, and logabsdet_jacobian makes sense.

Sometimes the API is just in such a way that you don't need the transformed variable, e.g.:
https://github.com/TuringLang/Bijectors.jl/blob/31b1c387ed6a243e02fd906cb615b61de47b935f/src/Bijectors.jl#L152-L158

I'm not sure if I understand that code correctly - is the transformation not applied at all in that use case (even before logpdf_with_trans is called)?

why shouldn't we have a logabsdet_jacobian

Hm, you do have a point there. I'm still not sure about the use case on the end-user side - but then, our AD frameworks all offer gradient even though they always calculate the result as well. :-) And it makes sense on the ladj-rule-implementor's side in some cases (see above).

So I'm not against adding

logabsdet_jacobian(f, x) = last(with_logabsdet_jacobian(f, x))

If we do though, it should be clearly documented that people always have to specialize with_logabsdet_jacobian and that specializing logabsdet_jacobian should be reserved for special cases.

@devmotion
Member

Wouldn't it be sufficient to add implementations of

function (::ComposedFunction{typeof(last),typeof(with_logabsdet_jacobian)})(f::MyFunction, x)
...
end

if last ∘ with_logabsdet_jacobian can be optimized for MyFunction? And I guess one could maybe even define in ChangesOfVariables

(::ComposedFunction{typeof(first),typeof(with_logabsdet_jacobian)})(f, x) = f(x)

Then one could exploit optimizations of first ∘ with_logabsdet_jacobian and last ∘ with_logabsdet_jacobian, if one is only interested in one of the outputs, without having to add a special API.
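
A minimal sketch of what such a specialization could look like (MyFunction is a hypothetical stand-in whose ladj at x is simply x):

using ChangesOfVariables

struct MyFunction end  # hypothetical
(::MyFunction)(x) = exp(x)

ChangesOfVariables.with_logabsdet_jacobian(::MyFunction, x) = (exp(x), x)

# Specialize the composition so (last ∘ with_logabsdet_jacobian)(f, x) skips the forward pass:
(::Base.ComposedFunction{typeof(last),typeof(with_logabsdet_jacobian)})(::MyFunction, x) = x

(last ∘ with_logabsdet_jacobian)(MyFunction(), 0.3)  # returns 0.3 via the specialization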

@devmotion
Member

I think even if a logabsdet_jacobian is added it could be useful to keep the additional structure of ComposedFunction and define it as

const logabsdet_jacobian = last ∘ with_logabsdet_jacobian

@oschulz
Collaborator

oschulz commented Nov 8, 2021

Wouldn't it be sufficient to add implementations of ... function (::ComposedFunction{typeof(last),

I think it would work, but it's not exactly very readable or convenient. What would be the advantage of taking the const logabsdet_jacobian = ... approach?

It would not play well with use cases like

logabsdet_jacobian(::typeof(myfunc), x) = ...

with_logabsdet_jacobian(::typeof(myfunc), x) = myfunc(x), logabsdet_jacobian(myfunc, x)

function with_logabsdet_jacobian(::typeof(inv_myfunc), y)
    x = inv_myfunc(y)
    return x, - logabsdet_jacobian(myfunc, x)
end

at least in the typical use case where ladj is better calculated in one direction and there's no advantage in calculating myfunc and its ladj together (no shared code / synergy).

@devmotion
Member

What would be the advantage of taking the const logabsdet_jacobian = ... approach?

That the optimized implementation would also be used if someone calls (last ∘ with_logabsdet_jacobian)(f, x). It does not enforce any specific implementation of logabsdet_jacobian and only adds the natural default definition. Specializations as in your example would still be possible.

@oschulz
Collaborator

oschulz commented Nov 8, 2021

That the optimized implementation would also be used if someone calls

Sure, but only in that case. Most people would probably use an anonymous function; I don't think we'd get a lot of opportunistic optimization that way.

Specializations as in your example would still be possible.

They are, but especially if one wants to implement a specialized logabsdet_jacobian with a custom rrule, it'll be quite awkward and the code will be harder to read.

@devmotion
Member

Even if most people use logabsdet_jacobian directly, isn't it the natural definition? It's just a composition of with_logabsdet_jacobian and last. You don't have to know this definition and can just write

function ChainRulesCore.rrule(::typeof(logabsdet_jacobian), ::MyFunction, x)
...
end

In fact, this is what should be recommended to users anyway. It's merely an implementation detail.
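
For example, a hedged sketch for a hypothetical scaling transform Scale (made up for illustration), assuming the proposed default logabsdet_jacobian(f, x) = last(with_logabsdet_jacobian(f, x)):

using ChainRulesCore

struct Scale  # hypothetical transform y = a * x
    a::Float64
end
(f::Scale)(x::Real) = f.a * x

logabsdet_jacobian(f::Scale, x::Real) = log(abs(f.a))

function ChainRulesCore.rrule(::typeof(logabsdet_jacobian), f::Scale, x::Real)
    ladj = log(abs(f.a))
    # ladj is independent of x, so the cotangent w.r.t. x is zero; d(log|a|)/da = 1/a.
    ladj_pullback(Δ) = (NoTangent(), Tangent{Scale}(a = Δ / f.a), ZeroTangent())
    return ladj, ladj_pullback
end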

@oschulz
Collaborator

oschulz commented Nov 9, 2021

You don't have to know this definition and can just write ... ChainRulesCore.rrule ...

You're right, for the rrule it's fine. But for the use case I mentioned, using logabsdet_jacobian as part of defining with_logabsdet_jacobian, a const wouldn't work; it would have to be an actual function. That use case is actually what would make adding logabsdet_jacobian seem more attractive to me in the first place - it makes such code neater.

@torfjelde
Author

Though I agree that it's possible to overload just the composition, I'm not a big fan due to readability 🤷

@oschulz
Collaborator

oschulz commented Dec 14, 2021

I think I found a way of doing this so that users can either define with_logabsdet_jacobian or logabsdet_jacobian or both, while still preventing endless recursion (resp. a stack overflow).

Related to TuringLang/Bijectors.jl#212.
