Taggable functors #41

Open
darsnack opened this issue May 4, 2022 · 4 comments

Comments

@darsnack

darsnack commented May 4, 2022

In Flux, we have trainable to designate the subset of a layer's children that should be walked when updating parameters for training. In FluxPrune.jl, I defined pruneable to designate the subset of children that are eligible for pruning (note that these are not necessarily the same as the trainable ones).

Right now this creates an unfortunate circumstance, as discussed in FluxML/Flux.jl#1946. Users need to @functor their types and remember to define trainable if necessary. To use FluxPrune.jl, they may also need to remember to define pruneable. On the developer side of things, we can use the walk keyword of fmap to walk the differently labeled leaf nodes, but this usually requires defining a separate walk function for each subset that you are hoping to target.
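To make the current situation concrete, here is a minimal, self-contained sketch. It does not use Functors.jl or Flux; `functor`, `trainable`, `pruneable`, the `Affine` type, and `map_subset` are local stand-ins written for illustration, and the "every child by default" fallback for `pruneable` is an assumption rather than something FluxPrune.jl is known to do.

```julia
# Local stand-ins only; nothing here is imported from Functors.jl or Flux.
struct Affine
    W::Vector{Float64}
    b::Vector{Float64}
    scale::Float64
end

# Roughly what `@functor Affine` provides: the children plus a rebuild function.
functor(x::Affine) = ((W = x.W, b = x.b, scale = x.scale),
                      y -> Affine(y.W, y.b, y.scale))

# Fallbacks: every child belongs to the subset unless a method says otherwise.
trainable(x) = functor(x)[1]
pruneable(x) = functor(x)[1]   # assumed default, for illustration

# The extra per-type definitions that users have to remember to write.
trainable(x::Affine) = (W = x.W, b = x.b)
pruneable(x::Affine) = (W = x.W,)

# Developer side: apply f only to the children selected by `subset`.
function map_subset(f, x, subset)
    cs, re = functor(x)
    chosen = subset(x)
    vals = map(k -> haskey(chosen, k) ? f(cs[k]) : cs[k], keys(cs))
    return re(NamedTuple{keys(cs)}(vals))
end

a = Affine([1.0, 2.0], [3.0], 0.5)
a_pruned = map_subset(zero, a, pruneable)       # only W is touched
a_scaled = map_subset(v -> 2 .* v, a, trainable) # W and b are touched
```

Each new kind of subset needs its own function (`trainable`, `pruneable`, ...) and its own per-type methods, which is the duplication the proposal aims to remove.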

An alternative would be to build this information directly into what @functor defines. Right now, each child of a functor has a name and a value. I propose adding "tags" which would be a tuple of symbols. Then we could do something like

@functor Conv trainable=(weight, bias) pruneable=(weight,)

Ideally, this mechanism should be dynamic, meaning that if Flux.jl already defines the trainable leaves of a type, then another package like FluxPrune.jl should be able to add a pruneable tag on top of that.

My hope is that we make it easier on users by only having one line for making your type Flux-compatible. And we make it easier on developers by making it easy to filter nodes when walking by tag. I haven't spent a lot of time on the implementation aspect, but I just wanted to float the notion of tags first and get some feedback.

@ToucheSir

Definitely thought about this before. The part I got stuck on was where/how to store this tag metadata. Do you have any proposals there?

@darsnack

I was thinking that functor should be broken up. Instead we can have children(x::T), rebuilder(::Type{T}), and tags(::Type{T}). Of course, the convenience macro, @functor, would define all three. I haven't thought too deeply about a simple convenience syntax for the macro that would support the initial declaration + adding more tags.

children would return the named tuple that is returned by functor, and rebuilder would return the function that puts the struct back together. tags would return a named tuple similar to children but instead of actual values for each key, it stores a tuple of symbols corresponding to the tags (defaulting to empty).
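A rough sketch of how that split might look, assuming a toy `Conv` type with plain vector fields. Everything here (`children`, `rebuilder`, `tags`, and the helper `map_tagged`) is hypothetical and not existing Functors.jl API, and the walk is only one level deep rather than recursive.

```julia
# Hypothetical sketch of the proposed split; not Functors.jl API.
struct Conv
    weight::Vector{Float64}
    bias::Vector{Float64}
    stride::Int
end

# children: the named tuple that functor returns today.
children(c::Conv) = (weight = c.weight, bias = c.bias, stride = c.stride)

# rebuilder: the function that puts the struct back together.
rebuilder(::Type{Conv}) = nt -> Conv(nt.weight, nt.bias, nt.stride)

# tags: same keys as children, each mapped to a tuple of symbols
# (an empty tuple when the child has no tags).
tags(::Type{Conv}) = (weight = (:trainable, :pruneable),
                      bias   = (:trainable,),
                      stride = ())

# Apply f only to children carrying `tag`; leave the others unchanged.
function map_tagged(f, x::T, tag::Symbol) where {T}
    cs = children(x)
    ts = tags(T)
    vals = map(k -> tag in ts[k] ? f(cs[k]) : cs[k], keys(cs))
    return rebuilder(T)(NamedTuple{keys(cs)}(vals))
end

c = Conv([1.0, 2.0], [0.5], 2)
c_pruned  = map_tagged(zero, c, :pruneable)          # zeroes weight only
c_stepped = map_tagged(v -> v .- 0.5, c, :trainable) # updates weight and bias
```

With this shape, filtering by tag is a single generic function, and a package could add a tag by extending or overriding `tags` for a type. How to do that without clobbering existing tags is left open here, as it is in the discussion.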

@ToucheSir

I like the idea of splitting things up. How would e.g. trainable fit in under this design? @flexiblefunctor is another question mark. rebuilder sounds vaguely like ProjectTo, so perhaps there's something we could learn from that too.

@mcabbott

Maybe @functor T defines functor(::Type{T}, x, ::Val), and @functor T trainable=(x,y) defines that and also functor(::Type{T}, x, ::Val{:trainable}).

Since it's a weird macro anyway, perhaps it can check for the existence of the simplest method before defining it (so as not to overwrite). Or some notation like @functor T +pruneable=(w,) could tell it to define only the ::Val{:pruneable} method?
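As a sketch of what those methods might look like once written out by hand, assuming a toy `Dense` type: the three-argument `functor` methods below are hypothetical (they are what the macro forms above might expand to, not existing Functors.jl methods), and `tagged_functor` is an illustrative helper.

```julia
# Hypothetical hand-written expansion; not existing Functors.jl methods.
struct Dense
    w::Vector{Float64}
    b::Vector{Float64}
    act::Function
end

# Plain `@functor Dense`: fallback for any tag, exposing every field.
functor(::Type{Dense}, x, ::Val) =
    ((w = x.w, b = x.b, act = x.act), y -> Dense(y.w, y.b, y.act))

# `@functor Dense trainable=(w, b)` would additionally define this method.
functor(::Type{Dense}, x, ::Val{:trainable}) =
    ((w = x.w, b = x.b), y -> Dense(y.w, y.b, x.act))

# `@functor Dense +pruneable=(w,)` would define only this method.
functor(::Type{Dense}, x, ::Val{:pruneable}) =
    ((w = x.w,), y -> Dense(y.w, x.b, x.act))

# Illustrative helper: select the method from a runtime symbol.
tagged_functor(x, tag::Symbol) = functor(typeof(x), x, Val(tag))

d = Dense([1.0, 2.0], [3.0], identity)
pc, pre = tagged_functor(d, :pruneable)
d_pruned = pre(map(zero, pc))          # rebuild with w zeroed
ac, _ = tagged_functor(d, :anything)   # unknown tag falls back to all fields
```

Because each tag is its own method, a second package can add `::Val{:pruneable}` without touching the methods defined elsewhere, which addresses the storage question by keeping the tag information in the method table.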

3 participants