Created experimental folder and added all NTK sketching codes #142

Open · wants to merge 49 commits into `main`

Changes from 34 commits

Commits (49):
- `c64f307` add NTK Random Features and Sketching codes (insuhan, Feb 24, 2022)
- `d1a9266` Add NTK Random Features and Sketching codes (insuhan, Feb 24, 2022)
- `9dc3536` Delete cache files (insuhan, Feb 24, 2022)
- `aa12545` Resolve pytype tests (insuhan, Mar 9, 2022)
- `12b828e` ntk sketch with polynomial approximation to the end kernel function (Mar 16, 2022)
- `4a5cddd` Fix simple issues from code reviews (v1) (insuhan, Mar 16, 2022)
- `932987d` Fix simple issues from code reviews (v2) (insuhan, Mar 16, 2022)
- `6a156ed` Fix simple issues from code reviews (v2) (insuhan, Mar 16, 2022)
- `9f2a3e5` Automatically preprocess init_fn/feature_fn (insuhan, Mar 17, 2022)
- `c2bed91` Update for raw inputs (insuhan, Mar 17, 2022)
- `d874cde` Update FlattenFeatures for raw inputs (insuhan, Mar 17, 2022)
- `78038f9` changes to the poly sketching alg (Mar 24, 2022)
- `09ea575` fc ntk sketch (Mar 24, 2022)
- `71a0946` poly fitting using jaxopt (Mar 24, 2022)
- `6404c1b` poly fitting minor edit (Mar 24, 2022)
- `5492d68` Make poly_fitting jittable (insuhan, Mar 25, 2022)
- `049336d` Fix typo (insuhan, Mar 25, 2022)
- `a14f8aa` Edit format of sketching.py (insuhan, Mar 25, 2022)
- `5f5ac18` Fix typo in alpha_ computation (insuhan, Mar 25, 2022)
- `7e0f580` Delete unnecessaries (insuhan, Mar 25, 2022)
- `391a1b8` Update FC NTK features and check pytype (insuhan, Mar 27, 2022)
- `d52ae47` Make jit-able test_fc_ntk.py (insuhan, Mar 27, 2022)
- `ec43fc7` Add ReluNTKFeatures (one-pass sketching) (insuhan, Mar 28, 2022)
- `e0f5bfe` Reflect comments in PR conversation (insuhan, Mar 29, 2022)
- `cb90151` Merge remote-tracking branch 'upstream/main' into NT (insuhan, Mar 29, 2022)
- `61bc32d` Add JAXopt package (insuhan, Mar 29, 2022)
- `1cfbe5a` Reflect comments in PR conversation (v2) (insuhan, Mar 30, 2022)
- `5485e17` Compare ReluNTKFeatures to neural_tangents.empirical_ntk_fn (insuhan, Mar 30, 2022)
- `5756f92` Merge remote-tracking branch 'upstream/main' into NT (insuhan, Apr 1, 2022)
- `3acf87a` test (amirzandieh, Apr 5, 2022)
- `e99f7d0` Merge branch 'NT' of https://github.com/insuhan/neural-tangents into NT (insuhan, Apr 5, 2022)
- `07f563d` Extend to ConvFeatures with rectangular filter shape (insuhan, Apr 5, 2022)
- `a65d729` Fix complex dtype warning (insuhan, Apr 5, 2022)
- `671c27d` Update Cholesky decomposition safely (insuhan, Apr 6, 2022)
- `c84ed65` Fix nans issue -- complex data type (insuhan, May 9, 2022)
- `50212c2` recover previous commit (insuhan, May 9, 2022)
- `52aacb5` Add bias term for DenseFeatures (insuhan, Jun 1, 2022)
- `0995169` Fix typo (insuhan, Jun 1, 2022)
- `70f53fe` Add bias in ConvFeatures (insuhan, Jun 7, 2022)
- `5564dac` Merge remote-tracking branch 'upstream/main' into NT (insuhan, Jun 8, 2022)
- `6d709ff` Add aggregate features for graph neural nets (insuhan, Jun 8, 2022)
- `a6724cb` Fix features_test (insuhan, Jun 8, 2022)
- `420f836` Fix features_test (insuhan, Jun 8, 2022)
- `9a14e60` Update dynamic axis (insuhan, Jul 13, 2022)
- `e35e551` Update neural-tangents v=0.6.0 (insuhan, Jul 13, 2022)
- `7f0b5f7` Add setup.py (insuhan, Jul 13, 2022)
- `2036f86` Add jaxopt in setup.py (insuhan, Jul 13, 2022)
- `af524d2` Change the third argument of init_fn (insuhan, Jul 13, 2022)
- `a16098c` Add ReluNTKFeatures test (insuhan, Jul 13, 2022)
experimental/README.md: 152 additions, 0 deletions
@@ -0,0 +1,152 @@
# Efficient Feature Map of Neural Tangent Kernels via Sketching and Random Features

Implementations developed in [[1]](#1-scaling-neural-tangent-kernels-via-sketching-and-random-features). The library is written for users familiar with the [JAX](https://github.com/google/jax) and [Neural Tangents](https://github.com/google/neural-tangents) libraries. The code is compatible with Neural Tangents v0.5.0.

A [PyTorch](https://pytorch.org/) implementation can be found [here](https://github.com/insuhan/ntk-sketch-rf).


## Examples

### Fully-connected NTK approximation via Random Features:

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

relufeat_arg = {
    'method': 'RANDFEAT',
    'feature_dim0': 64,
    'feature_dim1': 128,
    'sketch_dim': 256,
}

init_fn, feature_fn = serial(
    DenseFeatures(512), ReluFeatures(**relufeat_arg),
    DenseFeatures(512), ReluFeatures(**relufeat_arg),
    DenseFeatures(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, (5, 4))

_, feat_fn_inputs = init_fn(key2, x.shape)
feats = feature_fn(x, feat_fn_inputs)
# feats.nngp_feat is a feature map of the NNGP kernel
# feats.ntk_feat is a feature map of the NTK
assert feats.nngp_feat.shape == (5, relufeat_arg['feature_dim1'])
assert feats.ntk_feat.shape == (5, relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim'])
```

For more details on fully-connected NTK features, please check `test_fc_ntk.py`.

### Convolutional NTK approximation via Random Features:

```python
from jax import random
from experimental.features import (AvgPoolFeatures, ConvFeatures, DenseFeatures,
                                   FlattenFeatures, ReluFeatures, serial)

# reuses relufeat_arg from the previous example
init_fn, feature_fn = serial(
    ConvFeatures(512, filter_shape=(3, 3)), ReluFeatures(**relufeat_arg),
    AvgPoolFeatures((2, 2), strides=(2, 2)), FlattenFeatures(),
    DenseFeatures(512)
)

n, H, W, C = 5, 8, 8, 3
key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

_, feat_fn_inputs = init_fn(key2, x.shape)
feats = feature_fn(x, feat_fn_inputs)
# feats.nngp_feat is a feature map of the NNGP kernel
# feats.ntk_feat is a feature map of the NTK
assert feats.nngp_feat.shape == (5, (H // 2) * (W // 2) * relufeat_arg['feature_dim1'])
assert feats.ntk_feat.shape == (5, (H // 2) * (W // 2) * (relufeat_arg['feature_dim1'] + relufeat_arg['sketch_dim']))
```
For more complex CNTK features, please check `test_myrtle_networks.py`.

# Modules

All modules return a pair of functions `(init_fn, feature_fn)`. In place of the kernel function `kernel_fn` from the [Neural Tangents](https://github.com/google/neural-tangents) library, we provide the feature-map function `feature_fn`. We do not return an `apply_fn` function.

- `init_fn` takes (1) a random seed and (2) an input shape. It returns (1) a pair of shapes of the NNGP and NTK features and (2) the parameters used for approximating the features (e.g., random vectors for the Random Features approach).
- `feature_fn` takes (1) a feature structure `features.Feature` and (2) the parameters for the feature approximation (initialized by `init_fn`). It returns a `features.Feature` containing the approximate features of the corresponding module (see the sketch below).
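
A minimal sketch of this contract, reusing the modules from the examples above (the exact structure of the returned shapes and parameters is an implementation detail and may differ slightly from this sketch):

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

init_fn, feature_fn = serial(DenseFeatures(32), ReluFeatures(method='exact'))

key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (4, 8))

# init_fn: (rng, input_shape) -> (feature shapes, approximation parameters)
_, params = init_fn(key2, x.shape)
# feature_fn: (inputs, params) -> features.Feature with .nngp_feat and .ntk_feat
feats = feature_fn(x, params)
```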


## [`features.DenseFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L88)
`features.DenseFeatures` provides features for a fully-connected dense layer and corresponds to the `stax.Dense` module in [Neural Tangents](https://github.com/google/neural-tangents). We assume that the input is a tabular dataset (i.e., an `N x D` matrix). Its `feature_fn` updates the NTK features by concatenating the NNGP and NTK features. This is because `stax.Dense` updates the new NTK kernel matrix by adding the previous NNGP and NTK kernel matrices. The features of the dense layer are exact; no approximation is applied.

```python
from jax import numpy as np, random
from neural_tangents import stax
from experimental.features import DenseFeatures, serial

width = 1
key1 = random.PRNGKey(1)
x = random.normal(key1, shape=(3, 2))
_, _, kernel_fn = stax.Dense(width)
nt_kernel = kernel_fn(x)

_, feat_fn = serial(DenseFeatures(width))
feat = feat_fn(x, ())

assert np.linalg.norm(nt_kernel.nngp - feat.nngp_feat @ feat.nngp_feat.T) <= 1e-12
assert feat.ntk_feat == np.zeros(())
```

## [`features.ReluFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/ea23f8575a61f39c88aa57723408c175dbba0045/features.py#L119)
`features.ReluFeatures` is a key module of the NTK approximation. We implement feature approximations based on (1) Random Features of arc-cosine kernels [[2]](#2) and (2) the Polynomial Sketch [[3]](#3). Parameters used for the feature approximation are initialized in `init_fn`. We support tabular and image datasets. For a tabular dataset, the input features form an `N x D` matrix and the approximations are applied to the `D`-dimensional rows.

For an image dataset, the inputs are 4-D tensors of shape `N x H x W x D`, where `N` is the batch size, `H` is the image height, `W` is the image width and `D` is the feature dimension. We reshape the image features into a 2-D tensor of shape `NHW x D`, apply the appropriate feature approximation, and reshape the result back to a 4-D tensor of shape `N x H x W x D'`, where `D'` is the output dimension of the feature approximation.
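
The reshape pattern is sketched below with a hypothetical `approx_fn` standing in for the actual approximation (the library's internal code may differ):

```python
from jax import numpy as np, random

# Hypothetical stand-in for a feature approximation applied row-wise;
# here it simply maps D -> D' = 2*D.
def approx_fn(feat_2d):
  return np.concatenate([np.cos(feat_2d), np.sin(feat_2d)], axis=-1)

N, H, W, D = 2, 4, 4, 3
feat = random.normal(random.PRNGKey(0), (N, H, W, D))

feat_2d = feat.reshape(-1, D)            # (N*H*W, D)
feat_2d = approx_fn(feat_2d)             # (N*H*W, D')
feat_out = feat_2d.reshape(N, H, W, -1)  # (N, H, W, D')
assert feat_out.shape == (N, H, W, 2 * D)
```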

To use the Random Features approach, set the parameter `method` to `RANDFEAT` (the default), e.g.,

```python
from jax import random
from experimental.features import DenseFeatures, ReluFeatures, serial

key1 = random.PRNGKey(1)
x = random.normal(key1, shape=(3, 32))

init_fn, feat_fn = serial(
    DenseFeatures(1),
    ReluFeatures(method='RANDFEAT', feature_dim0=10, feature_dim1=20, sketch_dim=30)
)

_, params = init_fn(key1, x.shape)

out_feat = feat_fn(x, params)

assert out_feat.nngp_feat.shape == (3, 20)
assert out_feat.ntk_feat.shape == (3, 30)
```

To use the exact feature map (based on Cholesky decomposition), set the parameter `method` to `exact`, e.g.,

```python
init_fn, feat_fn = serial(DenseFeatures(1), ReluFeatures(method='exact'))
_, params = init_fn(key1, x.shape)
out_feat = feat_fn(x, params)

assert out_feat.nngp_feat.shape == (3, 3)
assert out_feat.ntk_feat.shape == (3, 3)
```

(This is for debugging. The dimension of the exact feature map equals the number of inputs, i.e., `N` for a tabular dataset and `NHW` for an image dataset.)


## [`features.ConvFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L236)

`features.ConvFeatures` is similar to `features.DenseFeatures` in that it updates the NTK features of the next layer by concatenating the NNGP and NTK features of the previous one. However, it additionally requires kernel pooling operations. Precisely, [[4]](#4) showed that computing the NNGP/NTK kernel matrices requires the trace of submatrices of size `filter_size`. This can be seen as a convolution with an identity matrix of size `filter_size`. On the feature side, it can instead be done by concatenating shifted features, so the resulting feature dimension becomes `filter_size` times larger. Moreover, since images are 2-D, the kernel pooling is applied along both spatial axes; hence the output feature has shape `N x H x W x (d * filter_size**2)`, where `filter_size` is the side length of the convolution filter and `d` is the input feature dimension.
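
A minimal sketch of this dimension growth, assuming `ConvFeatures` can be used on its own inside `serial` and that the feature dimension grows exactly as described above:

```python
from jax import random
from experimental.features import ConvFeatures, serial

n, H, W, C = 2, 8, 8, 3
key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

init_fn, feature_fn = serial(ConvFeatures(512, filter_shape=(3, 3)))
_, params = init_fn(key2, x.shape)
feats = feature_fn(x, params)

# the feature dimension grows by filter_size**2 = 9 (per the description above)
assert feats.nngp_feat.shape == (n, H, W, C * 3**2)
```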


## [`features.AvgPoolFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L269)

`features.AvgPoolFeatures` applies average pooling to both the NNGP and NTK features. It calls the [`_pool_kernel`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L3143) function in [Neural Tangents](https://github.com/google/neural-tangents) as a subroutine.
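
A minimal sketch, under the assumption that `AvgPoolFeatures` accepts raw 4-D inputs when placed at the start of `serial`:

```python
from jax import random
from experimental.features import AvgPoolFeatures, serial

n, H, W, C = 2, 8, 8, 3
key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

init_fn, feature_fn = serial(AvgPoolFeatures((2, 2), strides=(2, 2)))
_, params = init_fn(key2, x.shape)
feats = feature_fn(x, params)

# a (2, 2) window with stride (2, 2) halves the spatial dimensions
assert feats.nngp_feat.shape == (n, H // 2, W // 2, C)
```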

## [`features.FlattenFeatures`](https://github.com/insuhan/ntk-sketching-neural-tangents/blob/447cf2f6add6cf9f8374df4ea8530bf73d156c1b/features.py#L304)

`features.FlattenFeatures` flattens the features into 2-D tensors. As with the [`Flatten`](https://github.com/google/neural-tangents/blob/dd7eabb718c9e3c6640c47ca2379d93db6194214/neural_tangents/_src/stax/linear.py#L1641) module in [Neural Tangents](https://github.com/google/neural-tangents), the flattened features are rescaled by the square root of the number of flattened elements. For example, if `nngp_feat` has shape `N x H x W x C`, it returns an `N x HWC` matrix whose entries are divided by `(H*W*C)**0.5`.
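
A minimal sketch, again assuming raw 4-D inputs are accepted at the start of `serial`; the `(H*W*C)**0.5` rescaling factor is taken from the description above:

```python
from jax import random
from experimental.features import FlattenFeatures, serial

n, H, W, C = 2, 4, 4, 3
key1, key2 = random.split(random.PRNGKey(1))
x = random.normal(key1, shape=(n, H, W, C))

init_fn, feature_fn = serial(FlattenFeatures())
_, params = init_fn(key2, x.shape)
feats = feature_fn(x, params)

# flattened to (n, H*W*C); entries are scaled down by (H*W*C)**0.5
# relative to a plain reshape
assert feats.nngp_feat.shape == (n, H * W * C)
```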


## References
#### [1] [Scaling Neural Tangent Kernels via Sketching and Random Features](https://arxiv.org/pdf/2106.07880.pdf)
#### [2] [Kernel methods for deep learning](https://cseweb.ucsd.edu/~saul/papers/nips09_kernel.pdf)
#### [3] [Oblivious Sketching of High-Degree Polynomial Kernels](https://arxiv.org/pdf/1909.01410.pdf)
#### [4] [On Exact Computation with an Infinitely Wide Neural Net](https://arxiv.org/pdf/1904.11955.pdf)

Empty file added: experimental/__init__.py