
Register custom __repr__ functions outside of their classes #3393

Open
danijar opened this issue Jun 20, 2024 · 2 comments


danijar commented Jun 20, 2024

TLDR: Overriding __repr__ of Python types is not always possible (and where it is possible, it may change library behavior unexpectedly), so it would be valuable if rich provided a way to register custom formatters. For example, a custom __repr__ for Numpy arrays is currently impossible and would become possible with this feature.

How would you improve Rich?

Add a function to register custom representation functions with rich, e.g.:

rich.pretty.set_custom_formatter(condition, formatter)

Example usage:

import rich
import rich.pretty

def array_formatter(x):
  if x.size <= 100:
    return repr(x)
  return f'{type(x)}(shape={x.shape}, dtype={x.dtype})'

rich.pretty.set_custom_formatter(
    lambda x: hasattr(x, 'shape') and hasattr(x, 'dtype'),
    array_formatter)

import numpy as np
import jax.numpy as jnp

obj = {'foo': np.zeros((10, 10, 10)), 'bar': jnp.ones((1000,))}
rich.pretty.pprint(obj)
# {
#    'foo': np.ndarray(shape=(10, 10, 10), dtype=float64),
#    'bar': jax.Array(shape=(1000,), dtype=float32),
# }

Or equivalently:

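import functools

# condition is a predicate such as the lambda above. The decorated name only keeps
# pointing at the function if set_custom_formatter returns the formatter it registers.
@functools.partial(rich.pretty.set_custom_formatter, condition)
def formatter(x):
  ...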
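To make the proposal concrete, here is a minimal sketch of how such a registry could work, written outside rich so it runs on its own. `set_custom_formatter`, `to_repr` and `format_value` are hypothetical names chosen to mirror the proposal; they are not part of rich's API as of this issue, and a real implementation would have to hook into rich's own pretty-printing traversal. A small `FakeArray` class stands in for Numpy/JAX arrays so the sketch has no dependencies.

```python
import functools
from typing import Any, Callable, List, Tuple

# Hypothetical registry of (condition, formatter) pairs; NOT part of rich.
Condition = Callable[[Any], bool]
Formatter = Callable[[Any], str]
_FORMATTERS: List[Tuple[Condition, Formatter]] = []


def set_custom_formatter(condition: Condition, formatter: Formatter) -> Formatter:
    """Register a formatter. Returning it lets this double as a decorator."""
    _FORMATTERS.append((condition, formatter))
    return formatter


def to_repr(obj: Any) -> str:
    """Use the most recently registered matching formatter, else repr()."""
    for condition, formatter in reversed(_FORMATTERS):
        if condition(obj):
            return formatter(obj)
    return repr(obj)


def format_value(obj: Any) -> str:
    """Recursively format dicts and lists so nested values go through to_repr."""
    if isinstance(obj, dict):
        items = ", ".join(f"{format_value(k)}: {format_value(v)}" for k, v in obj.items())
        return "{" + items + "}"
    if isinstance(obj, list):
        return "[" + ", ".join(format_value(v) for v in obj) + "]"
    return to_repr(obj)


class FakeArray:
    """Stand-in for an array type: only exposes shape and dtype."""

    def __init__(self, shape, dtype):
        self.shape, self.dtype = shape, dtype


def looks_like_array(x: Any) -> bool:
    return hasattr(x, "shape") and hasattr(x, "dtype")


@functools.partial(set_custom_formatter, looks_like_array)
def array_formatter(x: Any) -> str:
    return f"{type(x).__name__}(shape={x.shape}, dtype={x.dtype})"


rendered = format_value({"foo": FakeArray((10, 10, 10), "float64"), "n": 3})
print(rendered)
# {'foo': FakeArray(shape=(10, 10, 10), dtype=float64), 'n': 3}
```

Returning the formatter from `set_custom_formatter` is what makes the `functools.partial` decorator form work; without it, `array_formatter` would be rebound to `None`. Checking the most recent registration first lets later formatters override earlier ones, which is one possible design choice among several.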

What problem does it solve for you?

We're using rich.traceback.install(show_locals=True) in a large code base that makes heavy use of Numpy and JAX. While it's generally very helpful, it frequently prints arrays that take up a whole screen height in the stack trace. As a result, we find ourselves repeatedly commenting the rich traceback hook in and out.

For our own Python classes, we could just implement __repr__ or __rich_repr__. However, for objects from external dependencies, we'd have to monkey-patch these methods, which poses the risk of changing behavior in unexpected ways. More importantly, the methods cannot be overridden for np.ndarray or other types implemented in C unless the library provides a mechanism for that (which Numpy doesn't).
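The difference between Python classes and C-implemented types can be seen with a small experiment. To keep it dependency-free, this sketch uses the built-in `int` as a stand-in for a C type such as `np.ndarray`; the assumption is that Numpy's static extension type rejects the assignment in the same way. `Point` is a hypothetical example class.

```python
# A plain Python class accepts a replacement __repr__ at runtime...
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y


Point.__repr__ = lambda self: f"Point({self.x}, {self.y})"
print(repr(Point(1, 2)))  # Point(1, 2)

# ...but a type implemented in C (here the built-in int) rejects the assignment.
try:
    int.__repr__ = lambda self: "patched"
    patched = True
except TypeError as err:
    patched = False
    print(f"TypeError: {err}")

print(patched)  # False
```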

Moreover, monkey-patching rich.pretty from the outside is also difficult. I believe the function that would need to change is to_repr(), which is nested inside traverse() in rich/pretty.py, and a nested function cannot be replaced from outside because it is redefined every time the outer function runs.
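The nested-function problem can be reproduced with a toy version of that structure. This `traverse`/`to_repr` pair only mimics the shape described above; it is not rich's actual code.

```python
def traverse(obj):
    # to_repr is a local name, created anew every time traverse() runs.
    def to_repr(value):
        return repr(value)

    return to_repr(obj), to_repr


first_result, first_fn = traverse([1, 2])
_, second_fn = traverse([1, 2])
print(first_fn is second_fn)  # False: a fresh function object per call

# Setting an attribute on the outer function does not affect the local name,
# so this "patch" is silently ignored.
traverse.to_repr = lambda value: "patched"
print(traverse([1, 2])[0])  # [1, 2]
```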

The suggested feature would let users adjust tracebacks (and pretty printing in general) to their needs, including the formatting of Numpy arrays, which is currently not possible, without risk of affecting code behavior.


Thank you for your issue. Give us a little time to review it.

PS. You might want to check the FAQ if you haven't done so already.

This is an automated reply, generated by FAQtory


leogott commented Sep 8, 2024

Do you need any info from these specific locals? If so, then you can stop reading here.

But if not, there's already this mechanism for excluding certain locals from the traceback by name:

rich/traceback.py in the rich repository, lines 436 to 441 at commit 22c2cff:

for key, value in iter_locals:
    if locals_hide_dunder and key.startswith("__"):
        continue
    if locals_hide_sunder and key.startswith("_"):
        continue
    yield key, value

That could be extended with a new show_locals_predicate argument (or something similar) to traceback.install.
(Alternatively, widen the type of show_locals to bool | Callable.)

# ... L440
if not show_locals_predicate(key, value):
    continue

That would feel pretty clean to me.
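The extended loop could look roughly like this standalone sketch, which copies the two existing checks and adds the proposed one. `filter_locals` and `show_locals_predicate` are hypothetical names from this discussion, not existing rich parameters, and `hide_big_lists` is only an illustrative predicate.

```python
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple

# Hypothetical predicate type: (local name, value) -> keep the local?
LocalsPredicate = Callable[[str, Any], bool]


def filter_locals(
    iter_locals: Iterable[Tuple[str, Any]],
    locals_hide_dunder: bool = True,
    locals_hide_sunder: bool = False,
    show_locals_predicate: Optional[LocalsPredicate] = None,
) -> Iterator[Tuple[str, Any]]:
    for key, value in iter_locals:
        if locals_hide_dunder and key.startswith("__"):
            continue
        if locals_hide_sunder and key.startswith("_"):
            continue
        # Proposed check: let the user veto individual locals by name and value.
        if show_locals_predicate is not None and not show_locals_predicate(key, value):
            continue
        yield key, value


def hide_big_lists(name: str, value: Any) -> bool:
    return not (isinstance(value, list) and len(value) > 3)


frame_locals = {"__name__": "demo", "_tmp": 1, "small": [1, 2], "big": list(range(100))}
visible = dict(filter_locals(frame_locals.items(), show_locals_predicate=hide_big_lists))
print(visible)  # {'_tmp': 1, 'small': [1, 2]}
```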
This predicate could have the type

Callable[[str, Any], bool]

That would even allow something like:

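def predicate(name, ref):
    # Hide numpy/jax values. type(ref).__module__ always exists, and comparing the
    # top-level package also matches submodules (e.g. jax's concrete array type lives under jaxlib).
    top_level = type(ref).__module__.split('.')[0]
    return top_level not in ('numpy', 'jax', 'jaxlib')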

Edit: Of course, it could also be a hide_locals_predicate, which might make more sense now that I think about it.
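For the alternative of widening show_locals to bool | Callable, install() could normalize the argument once up front. This is a minimal sketch under that assumption; `normalize_show_locals` is a hypothetical helper, not part of rich. Since bool values are not callable, the callable check can safely come first.

```python
from typing import Any, Callable, Optional, Union

# Hypothetical predicate type: (local name, value) -> keep the local?
LocalsPredicate = Callable[[str, Any], bool]


def normalize_show_locals(
    show_locals: Union[bool, LocalsPredicate],
) -> Optional[LocalsPredicate]:
    """Map the widened argument onto an optional predicate.

    False    -> None (locals are not rendered at all)
    True     -> a predicate that keeps every local
    callable -> used as-is
    """
    if callable(show_locals):
        return show_locals
    if show_locals:
        return lambda name, value: True
    return None
```

With this shape, existing calls such as install(show_locals=True) keep their meaning, while a callable opts into per-local filtering.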
