From e13d9d912fc7439a76d05500fd2cff32571b22ac Mon Sep 17 00:00:00 2001 From: insuhan Date: Thu, 10 Mar 2022 00:22:42 +0900 Subject: [PATCH] Fix initialization of ndarray --- experimental/__init__.py | 0 experimental/features.py | 33 +++++++++++++++++------------ experimental/sketching.py | 6 +++--- experimental/test_fc_ntk.py | 2 +- experimental/test_myrtle_network.py | 2 +- 5 files changed, 24 insertions(+), 19 deletions(-) create mode 100644 experimental/__init__.py diff --git a/experimental/__init__.py b/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/experimental/features.py b/experimental/features.py index 196236da..fb19f49e 100644 --- a/experimental/features.py +++ b/experimental/features.py @@ -6,16 +6,21 @@ import neural_tangents from neural_tangents import stax -from pkg_resources import parse_version -if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'): - from neural_tangents._src.utils import utils, dataclasses - from neural_tangents._src.stax.linear import _pool_kernel, Padding - from neural_tangents._src.stax.linear import _Pooling as Pooling -else: - from neural_tangents.utils import utils, dataclasses - from neural_tangents.stax import _pool_kernel, Padding, Pooling - -from sketching import TensorSRHT2, PolyTensorSRHT +# from pkg_resources import parse_version +# if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'): +# from neural_tangents._src.utils import utils, dataclasses +# from neural_tangents._src.stax.linear import _pool_kernel, Padding +# from neural_tangents._src.stax.linear import _Pooling as Pooling +# else: +# from neural_tangents.utils import utils, dataclasses +# from neural_tangents.stax import _pool_kernel, Padding, Pooling +from neural_tangents._src.utils import dataclasses +# from neural_tangents._src.utils.typing import Optional +from typing import Optional +from neural_tangents._src.stax.linear import _pool_kernel, Padding +from neural_tangents._src.stax.linear import _Pooling as Pooling + +from experimental.sketching import TensorSRHT2, PolyTensorSRHT """ Implementation for NTK Sketching and Random Features """ @@ -50,13 +55,13 @@ def kappa1(x): @dataclasses.dataclass class Features: - nngp_feat: np.ndarray - ntk_feat: np.ndarray + nngp_feat: Optional[np.ndarray] = None + ntk_feat: Optional[np.ndarray] = None batch_axis: int = dataclasses.field(pytree_node=False) channel_axis: int = dataclasses.field(pytree_node=False) - replace = ... # type: Callable[..., 'Features'] + replace = ... def _inputs_to_features(x: np.ndarray, @@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray, nngp_feat = x / x.shape[channel_axis]**0.5 ntk_feat = np.empty((), dtype=nngp_feat.dtype) - return Features(nngp_feat=nngp_feat, + return Features.replace(nngp_feat=nngp_feat, ntk_feat=ntk_feat, batch_axis=batch_axis, channel_axis=channel_axis) diff --git a/experimental/sketching.py b/experimental/sketching.py index 29f60cf4..1d21961b 100644 --- a/experimental/sketching.py +++ b/experimental/sketching.py @@ -61,8 +61,8 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs): degree = len(coeffs) - 1 self.degree = degree - self.tree_rand_signs = [0 for i in range((self.degree - 1).bit_length())] - self.tree_rand_inds = [0 for i in range((self.degree - 1).bit_length())] + self.tree_rand_signs: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())] + self.tree_rand_inds: Optional[np.ndarray] = [0 for i in range((self.degree - 1).bit_length())] rng1, rng2, rng3 = random.split(rng, 3) ske_dim_ = sketch_dim // 4 @@ -92,7 +92,7 @@ def __init__(self, rng, input_dim, sketch_dim, coeffs): def sketch(self, x): n = x.shape[0] log_degree = len(self.tree_rand_signs) - V = [0 for i in range(log_degree)] + V: Optional[np.ndarray] = [0 for i in range(log_degree)] E1 = np.concatenate((np.ones( (n, 1), dtype=x.dtype), np.zeros((n, x.shape[-1] - 1), dtype=x.dtype)), 1) diff --git a/experimental/test_fc_ntk.py b/experimental/test_fc_ntk.py index 23e5d575..f045365e 100644 --- a/experimental/test_fc_ntk.py +++ b/experimental/test_fc_ntk.py @@ -6,7 +6,7 @@ config.update("jax_enable_x64", True) from neural_tangents import stax -from features import _inputs_to_features, DenseFeatures, ReluFeatures, serial +from experimental.features import _inputs_to_features, DenseFeatures, ReluFeatures, serial seed = 1 n, d = 6, 4 diff --git a/experimental/test_myrtle_network.py b/experimental/test_myrtle_network.py index 29d34c73..197ef7eb 100644 --- a/experimental/test_myrtle_network.py +++ b/experimental/test_myrtle_network.py @@ -12,7 +12,7 @@ from jax import random from neural_tangents import stax -from features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features +from experimental.features import ReluFeatures, ConvFeatures, AvgPoolFeatures, serial, FlattenFeatures, DenseFeatures, _inputs_to_features layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]} width = 1