Skip to content

Commit

Permalink
Simplify running KerasNLP with Keras 3 (keras-team#1308)
Browse files Browse the repository at this point in the history
* Simplify running KerasNLP with Keras 3

We should not land this until Keras 3, TensorFlow 2.15, and
keras-nlp-nightly are released.

* Address comments

* Tweaks

* Add link

* fix link
  • Loading branch information
mattdangerw committed Dec 7, 2023
1 parent 9f94b02 commit 3d3a211
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 92 deletions.
69 changes: 59 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@
![Tensorflow](https://img.shields.io/badge/tensorflow-v2.5.0+-success.svg)
[![contributions welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat)](https://github.com/keras-team/keras-nlp/issues)

KerasNLP is a natural language processing library that works natively
with TensorFlow, JAX, or PyTorch. Built on [Keras Core](https://keras.io/keras_core/announcement/),
these models, layers, metrics, callbacks, etc., can be trained and serialized
in any framework and re-used in another without costly migrations. See "Using
KerasNLP with Keras Core" below for more details on multi-framework KerasNLP.
KerasNLP is a natural language processing library that works natively
with TensorFlow, JAX, or PyTorch. Built on Keras 3, these models, layers,
metrics, and tokenizers can be trained and serialized in any framework and
re-used in another without costly migrations.

KerasNLP supports users through their entire development cycle. Our workflows
are built from modular components that have state-of-the-art preset weights and
Expand Down Expand Up @@ -47,17 +46,28 @@ We are a new and growing project and welcome [contributions](CONTRIBUTING.md).

## Installation

To install the latest official release:
KerasNLP supports both Keras 2 and Keras 3. We recommend Keras 3 for all new
users, as it enables using KerasNLP models and layers with JAX, TensorFlow and
PyTorch.

### Keras 2 Installation

To install the latest KerasNLP release with Keras 2, simply run:

```
pip install keras-nlp --upgrade
pip install --upgrade keras-nlp
```

To install the latest unreleased changes to the library, we recommend using
pip to install directly from the master branch on github:
### Keras 3 Installation

There are currently two ways to install Keras 3 with KerasNLP. To install the
stable versions of KerasNLP and Keras 3, you should install Keras 3 **after**
installing KerasNLP. This is a temporary step while TensorFlow is pinned to
Keras 2, and will no longer be necessary after TensorFlow 2.16.

```
pip install git+https://github.com/keras-team/keras-nlp.git --upgrade
pip install --upgrade keras-nlp
pip install --upgrade keras>=3
```
## Using KerasNLP with Keras Core

Expand Down Expand Up @@ -88,12 +98,28 @@ Until Keras Core is officially released as Keras 3.0, KerasNLP will use
`.keras/keras_nlp.json`. You will need to restart the Python runtime for changes
to take effect.

To install the latest nightly changes for both KerasNLP and Keras, you can use
our nightly package.

```
pip install --upgrade keras-nlp-nightly
```

> [!IMPORTANT]
> Keras 3 will not function with TensorFlow 2.14 or earlier.
Read [Getting started with Keras](https://keras.io/getting_started/) for more information
on installing Keras 3 and compatibility with different frameworks.

## Quickstart

Fine-tune BERT on a small sentiment analysis task using the
[`keras_nlp.models`](https://keras.io/api/keras_nlp/models/) API:

```python
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # Or "jax" or "torch"!

import keras_nlp
import tensorflow_datasets as tfds

Expand All @@ -116,6 +142,29 @@ classifier.predict(["What an amazing movie!", "A total waste of my time."])

For more in depth guides and examples, visit https://keras.io/keras_nlp/.

## Configuring your backend

If you have Keras 3 installed in your environment (see installation above),
you can use KerasNLP with any of JAX, TensorFlow and PyTorch. To do so, set the
`KERAS_BACKEND` environment variable. For example:

```shell
export KERAS_BACKEND=jax
```

Or in Colab, with:

```python
import os
os.environ["KERAS_BACKEND"] = "jax"

import keras_nlp
```

> [!IMPORTANT]
> Make sure to set the `KERAS_BACKEND` before import any Keras libraries, it
> will be used to set up Keras when it is first imported.
## Compatibility

We follow [Semantic Versioning](https://semver.org/), and plan to
Expand Down
16 changes: 9 additions & 7 deletions keras_nlp/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
"""
Keras backend module.
This module adds a temporarily Keras API surface that is fully under KerasNLP
control. This allows us to switch between `keras_core` and `tf.keras`, as well
as add shims to support older version of `tf.keras`.
This module adds a temporary Keras API surface that is fully under KerasNLP
control. The goal is to allow us to write Keras 3-like code everywhere, while
still supporting Keras 2. We do this by using the `keras_core` package with
Keras 2 to backport Keras 3 numerics APIs (`keras.ops` and `keras.random`) into
Keras 2. The sub-modules exposed are as follows:
- `config`: check which backend is being run.
- `keras`: The full `keras` API (via `keras_core` or `tf.keras`).
- `ops`: `keras_core.ops`, always tf backed if using `tf.keras`.
- `random`: `keras_core.random`, always tf backed if using `tf.keras`.
- `config`: check which version of Keras is being run.
- `keras`: The full `keras` API with compat shims for older Keras versions.
- `ops`: `keras.ops` for Keras 3 or `keras_core.ops` for Keras 2.
- `random`: `keras.random` for Keras 3 or `keras_core.ops` for Keras 2.
"""

from keras_nlp.backend import config
Expand Down
73 changes: 12 additions & 61 deletions keras_nlp/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

_MULTI_BACKEND = False
_USE_KERAS_3 = False

# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if "KERAS_HOME" in os.environ:
_keras_dir = os.environ.get("KERAS_HOME")
else:
_keras_base_dir = os.path.expanduser("~")
if not os.access(_keras_base_dir, os.W_OK):
_keras_base_dir = "/tmp"
_keras_dir = os.path.join(_keras_base_dir, ".keras")

# Attempt to read KerasNLP config file.
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_nlp.json"))
if os.path.exists(_config_path):
try:
with open(_config_path) as f:
_config = json.load(f)
except ValueError:
_config = {}
_MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND)

# Save config file, if possible.
if not os.path.exists(_keras_dir):
try:
os.makedirs(_keras_dir)
except OSError:
# Except permission denied and potential race conditions
# in multi-threaded environments.
pass

if not os.path.exists(_config_path):
_config = {
"multi_backend": _MULTI_BACKEND,
}
try:
with open(_config_path, "w") as f:
f.write(json.dumps(_config, indent=4))
except IOError:
# Except permission denied.
pass

# If KERAS_BACKEND is set in the environment use multi-backend keras.
if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]:
_MULTI_BACKEND = True


def detect_if_tensorflow_uses_keras_3():
# We follow the version of keras that tensorflow is configured to use.
Expand All @@ -84,29 +36,28 @@ def detect_if_tensorflow_uses_keras_3():


_USE_KERAS_3 = detect_if_tensorflow_uses_keras_3()
if _USE_KERAS_3:
_MULTI_BACKEND = True

if not _USE_KERAS_3:
backend = os.environ.get("KERAS_BACKEND")
if backend and backend != "tensorflow":
raise RuntimeError(
"When running Keras 2, the `KERAS_BACKEND` environment variable "
f"must either be unset or `'tensorflow'`. Received: `{backend}`. "
"To set another backend, please install Keras 3. See "
"https://github.com/keras-team/keras-nlp#installation"
)


def keras_3():
"""Check if Keras 3 is being used."""
return _USE_KERAS_3


def multi_backend():
"""Check if multi-backend Keras is enabled."""
return _MULTI_BACKEND


def backend():
"""Check the backend framework."""
if not multi_backend():
return "tensorflow"
if not keras_3():
import keras_core

return keras_core.config.backend()
return "tensorflow"

from tensorflow import keras
import keras

return keras.config.backend()
2 changes: 0 additions & 2 deletions keras_nlp/backend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

if config.keras_3():
from keras import * # noqa: F403, F401
elif config.multi_backend():
from keras_core import * # noqa: F403, F401
else:
from tensorflow.keras import * # noqa: F403, F401

Expand Down
12 changes: 9 additions & 3 deletions keras_nlp/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,26 @@ def pytest_collection_modifyitems(config, items):
not run_extra_large_tests,
reason="need --run_extra_large option to run",
)
skip_tf_only = pytest.mark.skipif(
tf_only = pytest.mark.skipif(
not backend_config.backend() == "tensorflow",
reason="tests only run on tf backend",
)
keras_3_only = pytest.mark.skipif(
not backend_config.keras_3(),
reason="tests only run on with multi-backend keras",
)
for item in items:
if "large" in item.keywords:
item.add_marker(skip_large)
if "extra_large" in item.keywords:
item.add_marker(skip_extra_large)
if "tf_only" in item.keywords:
item.add_marker(skip_tf_only)
item.add_marker(tf_only)
if "keras_3_only" in item.keywords:
item.add_marker(keras_3_only)


# Disable traceback filtering for quicker debugging of tests failures.
tf.debugging.disable_traceback_filtering()
if backend_config.multi_backend():
if backend_config.keras_3():
keras.config.disable_traceback_filtering()
4 changes: 2 additions & 2 deletions keras_nlp/layers/modeling/cached_multi_head_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def test_layer_behaviors(self):
expected_output_shape=(2, 4, 6),
expected_num_trainable_weights=8,
expected_num_non_trainable_variables=1,
# tf.keras does not handle mixed precision correctly when not set
# Keras 2 does not handle mixed precision correctly when not set
# globally.
run_mixed_precision_check=config.multi_backend(),
run_mixed_precision_check=config.keras_3(),
)

def test_cache_call_is_correct(self):
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def bold_text(x):
print_fn(console.end_capture(), line_break=False)

# Avoid `tf.keras.Model.summary()`, so the above output matches.
if config.multi_backend():
if config.keras_3():
super().summary(
line_length=line_length,
positions=positions,
Expand Down
10 changes: 5 additions & 5 deletions keras_nlp/tests/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def call(self, x):
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
model.fit(input_data, output_data, verbose=0)

if config.multi_backend():
if config.keras_3():
# Build test.
layer = layer_cls(**init_kwargs)
if isinstance(input_data, dict):
Expand Down Expand Up @@ -205,8 +205,8 @@ def run_class_serialization_test(self, instance):
revived_cfg = revived_instance.get_config()
revived_cfg_json = json.dumps(revived_cfg, sort_keys=True, indent=4)
self.assertEqual(cfg_json, revived_cfg_json)
# Dir tests only work on keras-core.
if config.multi_backend():
# Dir tests only work with Keras 3.
if config.keras_3():
self.assertEqual(ref_dir, dir(revived_instance))

# serialization roundtrip
Expand All @@ -218,8 +218,8 @@ def run_class_serialization_test(self, instance):
revived_cfg = revived_instance.get_config()
revived_cfg_json = json.dumps(revived_cfg, sort_keys=True, indent=4)
self.assertEqual(cfg_json, revived_cfg_json)
# Dir tests only work on keras-core.
if config.multi_backend():
# Dir tests only work with Keras 3.
if config.keras_3():
new_dir = dir(revived_instance)[:]
for lst in [ref_dir, new_dir]:
if "__annotations__" in lst:
Expand Down
2 changes: 1 addition & 1 deletion keras_nlp/utils/tensor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def is_tensor_type(x):


def standardize_dtype(dtype):
if config.multi_backend():
if config.keras_3():
return keras.backend.standardize_dtype(dtype)
if hasattr(dtype, "name"):
return dtype.name
Expand Down

0 comments on commit 3d3a211

Please sign in to comment.