Skip to content

Commit

Permalink
Python set sorts dict keys. Use list-comprehension to find difference…
Browse files Browse the repository at this point in the history
…s in keys
  • Loading branch information
GStechschulte committed Nov 1, 2023
1 parent ce93b35 commit be0a641
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions bambi/interpret/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def log_interpret_defaults(func):
"""
Decorator for functions that compute default values.
Logs outpout to console if 'bmb.interpret.logger.messages = True' and when
Logs output to console if 'bmb.interpret.logger.messages = True' and when
default values are computed for the variable of interest, i.e., 'contrast'
or 'wrt' of 'comparisons' and 'slopes', as well as the 'conditional'
parameter of 'comparisons', 'predictions', and 'slopes'.
Expand All @@ -17,29 +17,29 @@ def wrapper(*args, **kwargs):
if not logger.messages:
return func(*args, **kwargs)

name_key = None
term_name = None
arg_name = None
covariate_name = None

if func.__name__ in ["set_default_values", "make_group_panel_values"]:
data_dict = kwargs.get("data_dict", args[1])
keys_before = set(data_dict.keys())
keys_after = set(func(*args, **kwargs).keys())
term_name = ", ".join(keys_after - keys_before)
keys_before = list(data_dict.keys())
keys_after = list(func(*args, **kwargs).keys())
covariate_name = ", ".join([key for key in keys_after if key not in keys_before])

if len(term_name) > 0:
name_key = "unspecified" if func.__name__ == "set_default_values" else "group/panel"
if len(covariate_name) > 0:
arg_name = "unspecified" if func.__name__ == "set_default_values" else "group/panel"

elif func.__name__ == "make_main_values":
term_name = args[1]
name_key = "main"
covariate_name = args[1]
arg_name = "main"

elif func.__name__ == "set_default_variable_values":
variables = {"comparisons": "contrast", "slopes": "wrt"}
name_key = variables.get(args[0].kind)
term_name = args[0].name
arg_name = variables.get(args[0].kind)
covariate_name = args[0].name

if name_key:
interpret_logger.info("Default computed for %s variable: %s", name_key, term_name)
if arg_name:
interpret_logger.info("Default computed for %s variable: %s", arg_name, covariate_name)

return func(*args, **kwargs)

Expand Down

0 comments on commit be0a641

Please sign in to comment.