Skip to content

Commit

Permalink
Fix initialization of ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
insuhan committed Mar 9, 2022
1 parent 9dc3536 commit e13d9d9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 19 deletions.
Empty file added experimental/__init__.py
Empty file.
33 changes: 19 additions & 14 deletions experimental/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@
import neural_tangents
from neural_tangents import stax

from pkg_resources import parse_version
if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
from neural_tangents._src.utils import utils, dataclasses
from neural_tangents._src.stax.linear import _pool_kernel, Padding
from neural_tangents._src.stax.linear import _Pooling as Pooling
else:
from neural_tangents.utils import utils, dataclasses
from neural_tangents.stax import _pool_kernel, Padding, Pooling

from sketching import TensorSRHT2, PolyTensorSRHT
# from pkg_resources import parse_version
# if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'):
# from neural_tangents._src.utils import utils, dataclasses
# from neural_tangents._src.stax.linear import _pool_kernel, Padding
# from neural_tangents._src.stax.linear import _Pooling as Pooling
# else:
# from neural_tangents.utils import utils, dataclasses
# from neural_tangents.stax import _pool_kernel, Padding, Pooling
from neural_tangents._src.utils import dataclasses
# from neural_tangents._src.utils.typing import Optional
from typing import Optional
from neural_tangents._src.stax.linear import _pool_kernel, Padding
from neural_tangents._src.stax.linear import _Pooling as Pooling

from experimental.sketching import TensorSRHT2, PolyTensorSRHT
""" Implementation for NTK Sketching and Random Features """


Expand Down Expand Up @@ -50,13 +55,13 @@ def kappa1(x):

@dataclasses.dataclass
class Features:
nngp_feat: np.ndarray
ntk_feat: np.ndarray
nngp_feat: Optional[np.ndarray] = None
ntk_feat: Optional[np.ndarray] = None

batch_axis: int = dataclasses.field(pytree_node=False)
channel_axis: int = dataclasses.field(pytree_node=False)

replace = ... # type: Callable[..., 'Features']
replace = ...


def _inputs_to_features(x: np.ndarray,
Expand All @@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray,
nngp_feat = x / x.shape[channel_axis]**0.5
ntk_feat = np.empty((), dtype=nngp_feat.dtype)

return Features(nngp_feat=nngp_feat,
return Features.replace(nngp_feat=nngp_feat,
ntk_feat=ntk_feat,
batch_axis=batch_axis,
channel_axis=channel_axis)
Expand Down
6 changes: 3 additions & 3 deletions experimental/sketching.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs):
degree = len(coeffs) - 1
self.degree = degree

self.tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())]
self.tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())]
self.tree_rand_signs: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())]
self.tree_rand_inds: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())]
rng1, rng2, rng3 = random.split(rng, 3)

ske_dim_ = sketch_dim // 4
Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs):
def sketch(self, x):
n = x.shape[0]
log_degree = len(self.tree_rand_signs)
V = [0 for i in range(log_degree)]
V: Optional[np.ndarray] = [0 for i in range(log_degree)]
E1 = np.concatenate((np.ones(
(n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)),
1)
Expand Down
2 changes: 1 addition & 1 deletion experimental/test_fc_ntk.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
config.update("jax_enable_x64", True)
from neural_tangents import stax

from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial
from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial

seed = 1
n, d = 6, 4
Expand Down
2 changes: 1 addition & 1 deletion experimental/test_myrtle_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jax import random

from neural_tangents import stax
from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features
from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features

layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
width = 1
Expand Down

0 comments on commit e13d9d9

Please sign in to comment.