diff --git a/examples/cifar/inspect_factors.py b/examples/cifar/inspect_factors.py index 8684512..9c78059 100644 --- a/examples/cifar/inspect_factors.py +++ b/examples/cifar/inspect_factors.py @@ -9,16 +9,12 @@ def main(): logging.basicConfig(level=logging.INFO) name = "ekfac" - factor = ( - Analyzer.load_file(f"influence_results/cifar10/factors_{name}/activation_covariance.safetensors") - ) + factor = Analyzer.load_file(f"influence_results/cifar10/factors_{name}/activation_covariance.safetensors") plt.matshow(factor["6.0"]) plt.show() - factor = ( - Analyzer.load_file(f"influence_results/cifar10/factors_{name}/gradient_covariance.safetensors") - ) + factor = Analyzer.load_file(f"influence_results/cifar10/factors_{name}/gradient_covariance.safetensors") plt.matshow(factor["6.0"]) plt.show()