From 4f6a145c4ec1a142a929c93142ad61671d8c1b9b Mon Sep 17 00:00:00 2001 From: RobertDeibel Date: Sat, 2 Mar 2024 19:08:08 +0100 Subject: [PATCH] [DOC] Example: Second-order extensions for new layers (#320) * loosen allclose threshold * add second-order extension to example * add image for second-order example * [FMT] Apply `black` and `isort` * [REF] One pass through the tutorial * [FIX] Long lines * Try decreasing atol in ResNet example * Try decreasing atol --------- Co-authored-by: Felix Dangel --- .../use_cases/example_custom_module.py | 530 +++++++++++++++++- docs_src/images/comp_graph.jpg | Bin 0 -> 7641 bytes 2 files changed, 502 insertions(+), 28 deletions(-) create mode 100644 docs_src/images/comp_graph.jpg diff --git a/docs_src/examples/use_cases/example_custom_module.py b/docs_src/examples/use_cases/example_custom_module.py index e59b10c8..1aa8bc2d 100644 --- a/docs_src/examples/use_cases/example_custom_module.py +++ b/docs_src/examples/use_cases/example_custom_module.py @@ -1,19 +1,28 @@ """Custom module example ========================================= -This tutorial shows how to support a custom module in a simple fashion. -We focus on `BackPACK's first-order extensions `_. -They don't backpropagate additional information and thus require less functionality be -implemented. +This tutorial explains how to support new layers in BackPACK. + +We will write a custom module and show how to implement first-order extensions, +specifically :py:class:`BatchGrad `, and second-order +extensions, specifically :py:class:`DiagGGNExact `. Let's get the imports out of our way. """ # noqa: B950 +from typing import Tuple + import torch +from einops import einsum +from torch.nn.utils.convert_parameters import parameters_to_vector from backpack import backpack, extend from backpack.extensions import BatchGrad from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.extensions.module_extension import ModuleExtension +from backpack.extensions.secondorder.diag_ggn import DiagGGNExact +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list # make deterministic torch.manual_seed(0) @@ -32,42 +41,55 @@ class ScaleModule(torch.nn.Module): """Defines the module.""" - def __init__(self, weight=2.0): + def __init__(self, weight: float = 2.0): """Store scalar weight. Args: - weight(float, optional): Initial value for weight. Defaults to 2.0. + weight: Initial value for weight. Defaults to 2.0. """ super(ScaleModule, self).__init__() self.weight = torch.nn.Parameter(torch.tensor([weight])) - def forward(self, input): + def forward(self, input: torch.Tensor) -> torch.Tensor: """Defines forward pass. Args: - input(torch.Tensor): input + input: The layer input. Returns: - torch.Tensor: product of input and weight + Product of input and weight. """ return input * self.weight # %% -# You don't necessarily need to write a custom layer. Any PyTorch layer can be extended -# as described (it should be a :py:class:`torch.nn.Module `'s because -# BackPACK uses module hooks). +# We choose this custom simple layer as its related operations for backpropagation are +# easy to understand. Of course, you don't have to define a new layer if it already +# exists within :py:mod:`torch.nn`. 
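+#
+# To make "easy to understand" concrete, here is a small, self-contained check (the
+# ``demo_*`` names are made up for illustration and not part of BackPACK or the rest
+# of this tutorial): each output element depends on the matching input element with
+# slope ``weight``, and on ``weight`` with slope equal to that input element. These
+# are exactly the Jacobians the extensions below will work with.
+
+demo_layer = ScaleModule(weight=3.0)
+demo_x = torch.ones(2, 4, requires_grad=True)
+demo_out = demo_layer(demo_x)  # the input, scaled by 3
+
+demo_sum = demo_out.sum()
+# d sum(output) / d input: every entry equals the weight (here 3)
+(demo_grad_x,) = torch.autograd.grad(demo_sum, demo_x, retain_graph=True)
+# d sum(output) / d weight: the sum of all input entries (here 8)
+(demo_grad_w,) = torch.autograd.grad(demo_sum, demo_layer.weight)
+print(demo_grad_x, demo_grad_w)
+
+# %%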
+# +# It is important to understand though that BackPACK relies on module hooks and therefore +# can only be extended on the modular level: If your desired functionality is not a +# :py:class:`torch.nn.Module ` yet, you need to wrap it in a +# :py:class:`torch.nn.Module `. +# +# First-order extensions +# ---------------------- +# First we focus on `BackPACK's first-order extensions +# `_. +# They don't backpropagate additional information and thus require less functionality. # -# Custom module extension -# ----------------------- # Let's make BackPACK support computing individual gradients for ``ScaleModule``. # This is done by the :py:class:`BatchGrad ` extension. # To support the new module, we need to create a module extension that implements -# how individual gradients are extracted with respect to ``ScaleModule``'s parameter. +# how individual gradients are extracted with respect to ``ScaleModule``'s parameter +# called ``weight``. # # The module extension must implement methods named after the parameters passed to the -# constructor. Here it goes. +# constructor (in this case ``weight``). For a module with additional parametes, e.g. a +# ``bias``, an additional method named like the parameter has to be added. +# +# Here it goes. class ScaleModuleBatchGrad(FirstOrderModuleExtension): @@ -75,24 +97,36 @@ class ScaleModuleBatchGrad(FirstOrderModuleExtension): def __init__(self): """Store parameters for which individual gradients should be computed.""" - # specify parameter names super().__init__(params=["weight"]) - def weight(self, ext, module, g_inp, g_out, bpQuantities): + def weight( + self, + ext: BatchGrad, + module: ScaleModule, + g_inp: Tuple[torch.Tensor], + g_out: Tuple[torch.Tensor], + bpQuantities: None, + ) -> torch.Tensor: """Extract individual gradients for ScaleModule's ``weight`` parameter. Args: - ext(BatchGrad): extension that is used - module(ScaleModule): module that performed forward pass - g_inp(tuple[torch.Tensor]): input gradient tensors - g_out(tuple[torch.Tensor]): output gradient tensors - bpQuantities(None): additional quantities for second-order + ext: BackPACK extension that is used. + module: The module that performed forward pass. + g_inp: Input gradient tensors. + g_out: Output gradient tensors. + bpQuantities: The quantity backpropagated for the extension by BackPACK. + ``None`` for ``BatchGrad``. Returns: - torch.Tensor: individual gradients + The per-example gradients w.r.t. to the ``weight`` parameters. + Has shape ``[batch_size, *weight.shape]``. """ - show_useful = True + # The ``BatchGrad`` extension supports considering only a sub-set of + # data in the mini-batch. We will not account for this here for simplicity + # and therefore raise an exception if this feature is active. + assert ext.get_subsampling() is None + show_useful = True if show_useful: print("Useful quantities:") # output is saved under field output @@ -103,10 +137,16 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): print("\tg_out[0].shape: ", g_out[0].shape) # actual computation - return (g_out[0] * module.input0).flatten(start_dim=1).sum(axis=1).unsqueeze(-1) + return einsum(g_out[0], module.input0, "batch d,batch d->batch").unsqueeze(-1) # %% +# +# Note that we have access to the layer's inputs and outputs from the forward pass, as +# they are stored by BackPACK. The computation itself basically +# computes vector-Jacobian-products of the incoming gradient with the layer's +# output-parameter Jacobian for each sample in the batch. 
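+#
+# To see this concretely, the short, self-contained check below (the ``demo_*`` names
+# are made up for illustration) confirms that the ``einsum`` above amounts to one dot
+# product, i.e. one vector-Jacobian product, per sample in the mini-batch.
+
+demo_g_out = torch.randn(3, 5)  # stands in for ``g_out[0]``
+demo_input = torch.randn(3, 5)  # stands in for ``module.input0``
+
+# the einsum used in ``ScaleModuleBatchGrad.weight`` ...
+demo_einsum = einsum(demo_g_out, demo_input, "batch d,batch d->batch").unsqueeze(-1)
+# ... equals an explicit loop of per-sample dot products
+demo_loop = torch.stack(
+    [(g * x).sum().unsqueeze(-1) for g, x in zip(demo_g_out, demo_input)]
+)
+assert torch.allclose(demo_einsum, demo_loop)
+
+# %%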
+# # Lastly, we need to register the mapping between layer (``ScaleModule``) and layer # extension (``ScaleModuleBatchGrad``) in an instance of # :py:class:`BatchGrad `. @@ -121,8 +161,8 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): # gradients with respect to ``ScaleModule``'s ``weight`` parameter. # %% -# Test custom module -# ------------------ +# Verifying first-order extensions +# -------------------------------- # Here, we verify the custom module extension on a small net with random inputs. # Let's create these. @@ -196,3 +236,437 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): "Individual gradients don't match:" + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}" ) + +# %% +# Second-order extension +# ---------------------- +# Next, we focus on `BackPACK's second-order extensions +# `_. +# They backpropagate additional information and thus require more functionality to be +# implemented and a more in-depth understanding of BackPACK's internals and +# the quantity of interest. +# +# Let's make BackPACK support computing the exact diagonal of the generalized +# Gauss-Newton (GGN) matrix +# (:py:class:`DiagGGNExact `) for ``ScaleModule``. +# +# To do that, we need to write a module extension that implements how the exact +# GGN diagonal is computed for ``ScaleModule``'s parameter called ``weight``. +# Also, we need to implement how information is propagated from the layer's output +# to the layer's input. +# +# We need to understand the following details about +# :py:class:`DiagGGNExact `: +# +# 1. The extension backpropagates a matrix square root factorization of the loss +# function's Hessian w.r.t. its input via vector-Jacobian products. +# 2. To compute the GGN diagonal for a parameter, we need to multiply the incoming +# matrix square root of the GGN with the output-parameter Jacobian of the layer, +# then square it to obtain the GGN, and take its diagonal. +# +# These details vary between different second-order extensions and a good place to get +# started understanding their details is the BackPACK paper. +# +# We now describe the details for the GGN diagonal. +# +# Definition of the GGN +# ^^^^^^^^^^^^^^^^^^^^^ +# +# The GGN is calculated by multiplying the neural network's Jacobian (w.r.t. the +# parameters) with the Hessian of the loss function w.r.t. its prediction, +# +# .. math:: +# \mathbf{G}(\mathbf{\theta}) +# = +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top\; +# \nabla^2_{f_\mathbf{\theta}(x)} \ell (f_\mathbf{\theta}(x, y) \; +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))\,. +# +# The Jacobian (left & right of RHS) is the matrix of all first-order derivatives +# of the function (neural network) w.r.t. the parameters. +# The Hessian (center of RHS) is the matrix of all second-order derivatives of the +# loss function w.r.t. the neural network's output. +# The GGN (LHS) is a matrix with dimension :math:`p \times p` where :math:`p` is the +# number of parameters. Note that in the presence of multiple data (a batch), the GGN +# is a sum/mean over per-sample GGNs. We will focus on the GGN for one sample, but +# also handle the parallel computation over all samples in the batch in the code. +# +# Our goal is to compute the diagonal of that matrix. To do that, we will re-write it +# in terms of a self-outer product as follows: Note that the loss function is convex. +# Let the neural network's prediction be +# :math:`f_\mathbf{\theta}(x) \in \mathbb{R}^C` where :math:`C` is the number of +# classes. 
Due to the convexity of :math:`\ell`, we can find a symmetric factorization +# of its Hessian, +# +# .. math:: +# \exists \mathbf{S} \in \mathbb{R}^{C \times C} +# \text{ s.t. } +# \mathbf{S} \mathbf{S}^\top +# = +# \nabla^2_{f_\mathbf{\theta}(x)} \ell (f_\mathbf{\theta}(x), y)\,. +# +# For our purposes, we will use a loss that is already supported within BackPACK, +# and there we don't need to be concerned how to compute this factorization. +# +# With that, we can define +# :math:`\mathbf{V}= (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top\;\mathbf{S}` +# and write the GGN as +# +# .. math:: +# \mathbf{G}(\mathbf{\theta}) = \mathbf{V} \mathbf{V}^\top\,. +# +# Instead of computing the GGN, we will compute :math:`\mathbf{V}` by backpropagating +# :math:`\mathbf{S}` via vector-Jacobian products, then square-and-take-the-diagonal +# to obtain the GGN's diagonal. +# +# Backpropagation for the GGN diagonal +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# To break down the multiplication with +# :math:`(\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top` to the per-layer level, +# we will use the chain rule. +# +# Consider the following computation graph, where :math:`x = x^{(0)}`: +# +# .. image:: ../../images/comp_graph.jpg +# :width: 75% +# :align: center +# +# Each node in the graph represents a tensor. The arrows represent the flow of +# information and the computation associated with the incoming and outgoing tensors: +# :math:`f_{\mathbf{\theta}^{(k)}}^{(k)}(x^{(k)}) = x^{(k+1)}`. The intermediates +# correspond to the outputs of the neural network layers. +# +# The parameter vector :math:`\mathbf{\theta}` contains all NN parameters, flattened +# and concatenated over layers, +# +# .. math:: +# \mathbf{\theta} +# = +# \begin{pmatrix} +# \mathbf{\theta}^{(1)} +# \\ +# \mathbf{\theta}^{(2)} +# \\ +# \vdots +# \\ +# \mathbf{\theta}^{(l)} +# \end{pmatrix}\,. +# +# The Jacobian inherits this structure and is a stack of Jacobians of each layer, +# +# .. math:: +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top +# = +# \begin{pmatrix} +# (\mathbf{J}_{\mathbf{\theta}^{(1)}} f_{\mathbf{\theta}}(x))^\top +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(2)}} f_{\mathbf{\theta}}(x))^\top +# \\ +# \vdots +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(l)}} f_\mathbf{\theta}(x))^\top +# \end{pmatrix}\,. +# +# The same holds for the matrix :math:`\mathbf{V}`, +# +# .. math:: +# \mathbf{V} +# = +# \begin{pmatrix} +# \mathbf{V}_{\mathbf{\theta}^{(1)}} +# \\ +# \mathbf{V}_{\mathbf{\theta}^{(2)}} +# \\ +# \vdots +# \\ +# \mathbf{V}_{\mathbf{\theta}^{(l)}} +# \end{pmatrix} +# = +# \begin{pmatrix} +# (\mathbf{J}_{\mathbf{\theta}^{(1)}} f_{\mathbf{\theta}}(x))^\top \mathbf{S} +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(2)}} f_{\mathbf{\theta}}(x))^\top \mathbf{S} +# \\ +# \vdots +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(l)}} f_\mathbf{\theta}(x))^\top \mathbf{S} +# \end{pmatrix}\,. +# +# With the chain rule recursions +# +# .. math:: +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} f_{\mathbf{\theta}}(x))^\top +# = +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} x^{(k)})^\top +# \;(\mathbf{J}_{x^{(k)}} f_{\mathbf{\theta}}(x))^\top +# +# and +# +# .. math:: +# (\mathbf{J}_{x^{(k-1)}} f_{\mathbf{\theta}}(x))^\top +# = +# (\mathbf{J}_{x^{(k-1)}} x^{(k)})^\top +# \;(\mathbf{J}_{x^{(k)}} f_{\mathbf{\theta}}(x))^\top +# +# we can identify the following recursions for the blocks of :math:`\mathbf{V}`: +# +# .. 
math:: +# \mathbf{V}_{\mathbf{\theta}^{(k)}} +# = +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} x^{(k)})^\top +# \mathbf{V}_{x^{(k)}} +# +# and +# +# .. math:: +# \mathbf{V}_{x^{(k-1)}} +# = +# (\mathbf{J}_{x^{(k-1)}} x^{(k)})^\top +# \mathbf{V}_{x^{(k)}}\,. +# +# The above two recursions are the backpropagations performed by BackPACK's +# :py:class:`DiagGGNExact `. Layer :math:`k` +# receives the backpropagated quantity :math:`\mathbf{V}_{x^{(k)}}`, then +# (i) computes :math:`\mathbf{V}_{\mathbf{\theta}^{(k)}}`, then +# :math:`\mathrm{diag}(\mathbf{V}_{\mathbf{\theta}^{(k)}} +# \mathbf{V}_{\mathbf{\theta}^{(k)}}^\top)`, which is the GGN diagonal for +# the layer's parameters, and (ii) computes :math:`\mathbf{V}_{x^{(k-1)}}` +# which is sent to its parent layer :math:`k-1` which proceeds likewise. +# +# Implementation +# ^^^^^^^^^^^^^^ +# +# Now, let's create a module extension that specifies two methods: +# Step (i) from above is implemented by a function whose name +# matches the layer parameter's name (``weight`` in our case). Step (ii) +# is implemented by a function named ``backpropagate``. + + +class ScaleModuleDiagGGNExact(ModuleExtension): + """Backpropagation through ``ScaleModule`` for computing the GGN diagonal.""" + + def __init__(self): + """Store parameter names for which the GGN diagonal will be computed.""" + super().__init__(params=["weight"]) + + def backpropagate( + self, + ext: DiagGGNExact, + module: ScaleModule, + grad_inp: Tuple[torch.Tensor], + grad_out: Tuple[torch.Tensor], + bpQuantities: torch.Tensor, + ) -> torch.Tensor: + """Propagate GGN diagonal information from layer output to input. + + Args: + ext: The GGN diagonal extension. + module: Layer through which to perform backpropagation. + grad_inp: Input gradients. + grad_out:: Output gradients. + bpQuantities: Backpropagation information. For the GGN diagonal + this is a tensor V of shape ``[C, *module.output.shape]`` where + ``C`` is the neural network's output dimension and the layer's + output shape is typically something like ``[batch_size, D_out]``. + + Returns: + The GGN diagonal's backpropagated quantity V for the layer input. + Has shape ``[C, *layer.input0.shape]``. + """ + # The GGN diagonal extension supports considering only a sub-set of + # data in the mini-batch. We will not account for this here for simplicity + # and therefore raise an exception if this feature is active. + assert ext.get_subsampling() is None + + # Layer: + # - Input to the layer has shape ``[batch_size, D_in]`` + # - Output of the layer has shape ``[batch_size, D_out]`` + + # Loss function: + # - Neural networks prediction has shape ``[batch_size, C]`` + + # Quantity backpropagated by ``DiagGGNExact`` has shape + # ``[C, batch_size, D_out]`` imagine this as a set of ``C`` vectors + # which all have the same shape as the layer's output that represent + # the rows of the incoming V. + + # What we need to to do: + # - Take each of the C vectors + # - Multiply each of them with the layer's output-input Jacobian. 
+ # The result of each VJP will have shape ``[batch_size, D_in]`` + # - Stack them back together into a tensor of shape + # ``[C, batch_size, D_in]`` that represents the outgoing V + + input0 = module.input0 + output = module.output + weight = module.weight + V_out = bpQuantities + + C = V_out.shape[0] + batch_size, D_in = input0.shape + assert V_out.shape == (C, *output.shape) + + show_useful = True + if show_useful: + print("backpropagate: Useful quantities:") + print(f" module.output.shape: {output.shape}") + print(f" module.input.shape: {input0.shape}") + print(f" V_out.shape: {V_out.shape}") + print(f" V_in.shape: {(C, *input0.shape)}") + + V_in = torch.zeros( + (C, batch_size, D_in), device=input0.device, dtype=input0.dtype + ) + + # forward pass computation performs: ``X * weight`` + # (``[batch_size, D_in] * [1] [batch_size, D_out=D_in]``) + for c in range(C): + V_in[c] = bpQuantities[c] * weight + # NOTE We could do this more efficiently with the following: + # V_in = V_out * weight + assert V_in.shape == (C, *input0.shape) + + return V_in + + def weight( + self, + ext: DiagGGNExact, + module: ScaleModule, + g_inp: Tuple[torch.Tensor], + g_out: Tuple[torch.Tensor], + bpQuantities: torch.Tensor, + ) -> torch.Tensor: + """Extract the GGN diagonal for the ``weight`` parameter. + + Args: + ext: The BackPACK extension. + module: Module through which to perform backpropagation. + grad_inp: Input gradients. + grad_out: Output gradients. + bpQuantities: Backpropagation information. For the GGN diagonal + this is a tensor V of shape ``[C, *module.output.shape]`` where + ``C`` is the neural network's output dimension and the layer's + output shape is typically something like ``[batch_size, D_out]``. + + Returns: + The GGN diagonal w.r.t. the layer's ``weight``. + Has shape ``[batch_size, *weight.shape]``. + """ + input0 = module.input0 + output = module.output + weight = module.weight + V_out = bpQuantities + + C = bpQuantities.shape[0] + assert V_out.shape == (C, *output.shape) + + show_useful = True + if show_useful: + print("weight: Useful quantities:") + print(f" module.output.shape {output.shape}") + print(f" module.input.shape {input0.shape}") + print(f" module.weight.shape {weight.shape}") + print(f" bpQuantities.shape {bpQuantities.shape}") + print(f" returned.shape {weight.shape}") + + # forward pass computation performs: ``X * weight`` + # (``[batch_size, D_in] * [1] = [batch_size, D_out]``) + V_theta = einsum(V_out, input0, "c batch d, batch d -> c batch") + # compute diag( V_theta @ V_theta^T ) + weight_ggn_diag = einsum(V_theta, V_theta, "c batch, c batch ->").unsqueeze(0) + + assert weight_ggn_diag.shape == weight.shape + return weight_ggn_diag + + +# %% +# After we have implemented the module extension we need to register the mapping +# between layer (``ScaleModule``) and layer extension (``ScaleModuleDiagGGNExact``) +# in an instance of :py:class:`DiagGGNExact `. + +extension = DiagGGNExact() +extension.set_module_extension(ScaleModule, ScaleModuleDiagGGNExact()) + +# %% +# We can then use this extension to compute the exact GGN diagonal for +# ``ScaleModule``s. +# +# +# Verifying second-order extensions +# --------------------------------- +# +# Here, we verify the custom module extension on a small net with random inputs. 
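+#
+# The loss we use below is cross-entropy. As a quick aside (not required for using
+# BackPACK; the ``demo_*`` names are made up for this check), we can verify
+# numerically that the symmetric factorization :math:`\mathbf{S}` assumed in the
+# derivation exists for this loss: for a single sample with softmax probabilities
+# :math:`p`, the Hessian w.r.t. the logits is :math:`\mathrm{diag}(p) - p p^\top`,
+# and :math:`\mathbf{S} = \mathrm{diag}(\sqrt{p}) - p \sqrt{p}^\top` reproduces it.
+
+demo_logits = torch.randn(4)
+demo_target = torch.tensor([1])
+demo_hessian = torch.autograd.functional.hessian(
+    lambda z: torch.nn.functional.cross_entropy(z.unsqueeze(0), demo_target),
+    demo_logits,
+)
+demo_p = torch.softmax(demo_logits, dim=0)
+demo_S = torch.diag(demo_p.sqrt()) - torch.outer(demo_p, demo_p.sqrt())
+assert torch.allclose(demo_hessian, demo_S @ demo_S.T, atol=1e-5)
+
+# %%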
+# First, the setup: + +batch_size = 10 +input_size = 4 + +inputs = torch.randn(batch_size, input_size, device=device) +targets = torch.randint(0, 2, (batch_size,), device=device) + +reduction = ["mean", "sum"][1] + +my_module = ScaleModule().to(device) +lossfunc = torch.nn.CrossEntropyLoss(reduction=reduction).to(device) + +# %% +# As ground truth, we compute the GGN diagonal using GGN-vector products +# which exclusively rely on PyTorch's autodiff: +params = list(my_module.parameters()) +ggn_dim = sum(p.numel() for p in params) +diag_ggn_flat = torch.zeros(ggn_dim, device=inputs.device, dtype=inputs.dtype) + +outputs = my_module(inputs) +loss = lossfunc(outputs, targets) + +# compute GGN-vector products with all one-hot vectors +for d in range(ggn_dim): + # create unit vector d + e_d = torch.zeros(ggn_dim, device=inputs.device, dtype=inputs.dtype) + e_d[d] = 1.0 + # convert to list format + e_d = vector_to_parameter_list(e_d, params) + + # multiply GGN onto the unit vector -> get back column d of the GGN + ggn_e_d = ggn_vector_product(loss, outputs, my_module, e_d) + # flatten + ggn_e_d = parameters_to_vector(ggn_e_d) + + # extract the d-th entry (which is on the GGN's diagonal) + diag_ggn_flat[d] = ggn_e_d[d] + +print(f"Tr(GGN): {diag_ggn_flat.sum():.3f}") + +# %% +# Now we can use BackPACK to compute the GGN diagonal: + +my_module = extend(my_module) +lossfunc = extend(lossfunc) + +outputs = my_module(inputs) +loss = lossfunc(outputs, targets) + +with backpack(extension): + loss.backward() + +diag_ggn_flat_backpack = parameters_to_vector( + [p.diag_ggn_exact for p in my_module.parameters()] +) +print(f"Tr(GGN, BackPACK): {diag_ggn_flat_backpack.sum():.3f}") + +# %% +# +# Finally, let's compare the two results. + +match = torch.allclose(diag_ggn_flat, diag_ggn_flat_backpack) +print(f"Do manual and BackPACK GGN match? {match}") + +if not match: + raise AssertionError( + "Exact GGN diagonals do not match:" + + f"\n{diag_ggn_flat}\nvs.\n{diag_ggn_flat_backpack}" + ) + +# %% +# +# That's all for now. diff --git a/docs_src/images/comp_graph.jpg b/docs_src/images/comp_graph.jpg new file mode 100644 index 0000000000000000000000000000000000000000..23d9be235a49c343f628ea49661f2a2bc1c0a430 GIT binary patch literal 7641 zcmd6LcTiN@((fKX0m+~uaX>JVk)#8RNJa$&$w?T3GUOp=L4u?qC?HWuN|rF>paL_4 zOLJlod#%C6qFSJ5C{N3 zgbO&G0Szn5%9^Nas3|BvlK)o&79bFbAOP6fIXh`6KH$*R)8`SX$tI}!pX zva|$%t&ader3U~E-2iaG;IB3U`wwH|AT+TO#$`{qtN?jMrWLuKqTi$&WP;%`E%#V$u3+VCnG1nKuJS+fr6TXoSce|ikgO&mX7uU zB|QT@Edv45o;d-XSrU_w5e#W5$SDZY|4Tak0??5GF`y-4&{cql4n#}`I&B8n&SH&- z7<3kj|3jqb$wl5D5_xh!{jnK}>v(pmC;6a_&4S9X%NXr_jZ#GUS&o zb4fhWaZ5lA2y2+$dEm%+SM%iug6+jWeEyH?Gb!Egp8l4O0+hrAzjVZO02IJKrlX~! 
z1O2~amoK%|Re#zjO6JE^wm~vnLO=CnYo*}|B1ckZsdSl!x8APp5r3A>iU#pqob~l9 zLpd4S2k+KjHWGJiFJ1ZFS zT`_oS|2@wDfwFODw}kF{{rWtBAuH3Hcr#L>!BGgO-+1^@m4#NL*`76t39V)%&Fh>Uw6>GqaX>*1HFW*XfRH3j&5ShO?Vn%G~L( z$}(R46&-#aLpN<}D&&k>=4KdllE<9wmK(y2G5&j#LqtOg8|)a-4hL!y>Iqqi*i&F$ zK!ddhLpvO;?j`8uqf*3C%49DqA{OBnN*x~^jr{l~ifc)$&OJH zFPge;`)vkobw`C_QzMW?Yij#d8J27}dvS7HCM%uZPpv?D}Rf~m5G8KDUGnZkZ z#LR=3D@-iFa983;^gWR>6gU@JmgHVoCFN%J-Tq2Dn^ZZfN+>mY<|FoQCtQE{eFfLs zPDihM_DHLaS46Ckuq^?eW%djS`;c{OaT{@6+(aWIvX*?4zE$=5aHVxhs@srd_>XAf zElG?Y4+?*6a-qYUyveZySqwKv)tG;^Plz;>WOpQj0?XPwE~fzRDe#KX`*K^*h@Q(Z zJZ%#_na*Y>uCO;3(5$-Y#ftI&G1ZjhxwzuR-a9WZaXw zNLFwK+JKgGCZDG;HYcXSGAK<5{j*IihO>EFTG?a>v|8!WwH31k=4{Qk62*!bL-X_X z(705CRYAR^HcQI5mA=P}KZJftMk``TXocOh#wBP5KHhW6O0lZ2{TQ>6$hgEs5tcNr zIhwDx(V_Yxk`XspvJ#y|Uk=JnQ> zJun?-@iRomyO*oH(voRpKUAO07$rNkt<-(UUy+y?gIHN&wK8*cF^@0pKP#qL%JF^R zT0#50<1i|+^IG%%%Z&nJ@*t+nk_dUuc$X(iQ-=`gZg|Lks845Iog^{=>Erfm2>w#KpY^U7v13LYtf;U|2%`IlBH4V{E` z^oD<;54HKuOL*TcL4Apq5r2!uz)JSKeQZ>(2YEhy@!&f?U;A}Q3(2GB@GEE$V6dF< zyLu0RoYen)xIG)6`mk!4+~RdUGRuiwQNGjSnYr_?y`4Cz8x;z5@!RuhDa^VL4OyAD zWJogt@5&=LPihnToo&t+rjHicj7X}A)$@2HM^L46iZ}*0xcRH8I+CQv`uZwM#RXcA z%VKo})W_`kaUNxmg#KT9dE1Vo<||8S{H~WhI>4%%>uZpVD??z9AGk@rn>;r^8Xlt$ zD|sE`WtC^wv|BePEyo9N;SgltM3R^ykMZ6CHO$}rTo!7Tt1%5@>-wSyX8!}fV$5|b zF`3Ub@+OZS_R*JCzVGPs4OXI!pe~N8mmB0V!x>kJu5w&GV~dr1X~A)yti*Y9uf|ib zSzUXEJduDoN1Xzj*&aQMnQ3nP0^42K@x@vPdyY9L!ytQ}D&x_7R;&4W^v;!Kk$VOa z$zu^BKo6r&6)MsVi*iYWiYS=Fe$+h8HE@A3#rC;WZR0EO z(8tIdkoUf7)EpDjXyxe2Y}}7Ire)D$?aqKJKjXxZN7sJnJBl z!zbTctmzw9?U=I+fYXW`v;Kq+7B38Sk8DCY zUb-7TGUj4LW28d+(9Fv|tbpouRGNp8NN;pE$U7Fh&%U0YvMUj*%oFb)8cG)W zHxv0d5O1F>WSEZG7G$X(m-e!oyM4u0lL6V<#wHX{Oqx#AMB6 z+@(e9zMLG1+~4e6r26FTskxMO7U773vVgBMh*&*-c6#CC@QqSI>G$QcY03w5T%HG) zVso{t4>h;ITc^O|li8B@dB69^Uz-G7E$*I=Hngp9dVC-}QJ6cuzHq&eBG3eKd&s*p zGV6$F`Wpz!_BJ+GCm%cldk0N8q}f><(E7)Zm*h_W4f>6@MrvO3g;Ni)=BeRBg%*f6 zoBOqIU>!r*dcC1GO_c4yn2At9uMP*|58D^ILYih=R*VeUNtra+qB0qtg{V(-w{e@LEtY6!)_+b<}FHft;xBn>QsIkwlYUTOuPyc(xt{E5jSwxwi zn5;m@K5A)+;3Yd=r$Br_ zVDz5>eKy1;%#QB78t$a&xW+|GqgWeSiV^k`*a5Hh<>mc-UJ=%VmxrO%W zLn2kW?s2*DvcJuEpTK2tK9-HP<%f_WC^7tT;|Wn|_XWit+Ue z@Y1D7C@IssSdH<+;h%~B3z-pOkuB$E}KAd#cSUK4ZRk}RV;WNyJ)c6iz zi^CjeuF};S~K7m!?%4eb5&WJaWX){rbg^t zW|S<~m6L{F`dfnufETbAby820F<6LswTH3C_vJVRc~icl%GVgYQS#2+OHU)28!>}X zySXX78Uk6hjcf$F3JA|13Zd|}n*B7F>?K2>j7t!fQjev?YQ$w&R6#)$^GfPxSL2;> ziFbzc{!;>pbEFyfIOXL7mkr~@G-Jh#1-h2knsI{NU*j+;W9-{ZhK9{06}s!^_mo=o zxSix=c%{ecQbw4y9exK~;H>RVX8xU@Us(qYSR#a4IU|vNCN4$p@NgL^B5ySlV`PuV zEZXFaaMA2i+FD`A^lGfsDPWUX{6ZUlsACqzs#Rxu)co!icj;htXVjIVlIc(xH`Gne zG>+OH`O5>xHyzVHE}1;X%S66B1+u)P`I%`7bf?V9Ge1Zro&wIXr4ssT_mNC$RugVA zFaA6QNaSP+*c*=Tmjy_K;k9g>6z$OXT# z>W7}Zc9^%)?t1P)8Qj=~b3L()$q8`EsjW233r9*n(bwWYc66DZ(-$)wfWHOW?LeM#Hw~2$$>3W@h zKRT-_q+JYtX`T!AzLtR)mu)Fd+5!PBmkxo3S3E?j>-}d+jUhd#BB7qohKNsEmo}Lv z$b<&CRv5`DtYyM>`gr%}U*{&nAE9`wcyl6bB zQ^7kiGH2p7v42>d#mTgdy8jgYl)u3i!-eRTKF=!)DM8|@$LXAin5@sHB&wEC$9v3! 
zdq)XmrjZ593-7-`&-JZMr7EQU^!{;l(sLm5^4G?Cqt!3&wB5dqo%+sGV7BqA4U~OZ zuY<>TVmo}8yl@F3TD=pYEmvJC%UU7H;`yrM=Vi#`_2*r3QgK-65=di3++t;z2xTW& z^kuV&7ygrt_E5fY>oCfI{Nmg@J=I&r^AA{QLY}z}0_gb9?(Z>Nt+g!6PTp)v0211= zIa#+8RqMl>nHku&8#`78_#u>ZM8Pw*evncp?b%K*OQK0;!xYD5=2eyt?Io9)-}Kh5 z(ktBGnXvTWej^g_gEp`&b4>mvh|wfH5NbtbM`dsWiNU2aUE(vp*?@CTKUw53@2-)- z@6mrg1-_^HMtOcb5q8ws?&CUfIj7|fH^_I1icd}a2;=tVcUGeYtvJF*m5%hM^%eVR zWi)Aex}Kf_Xw};%N#=f?#ZbI23;stm{xufj2H2=u>-Uclk0p2LHmQXUIaR_aEn3ac zgrsMM@lh+L4kO9vA)!o@NJXRd#BsB|@v?Nehk>AO4OVyTjtqI&t!q^5b})(*-Ci%w z>`o4ESKs_-|8Py8r-r!4u<*OE$mUVqX%s^Aydigbyu3(ua-W$ssy?Ov6p*leW2g77PyEia8U^iKUC=qt4OCCL zgm1|tHTBlk%6Y3VG2`j+`-j)!-QJ?t%0J1^k=kan(}20DNRR8gcpA4Be7Os>i%)P& z?2Q2;r$EE-Q%?ytq{>1{8XO#4=BGQS;&2b2XiPI-fV!E2Zcy+BTY9d-B)>`v>%f(c z;hwW!7wtpgTlaC;jB)5g5KL8OtkZE)T?_%SuDN4_Udqrg+A-^+X=P`lr*$4Lj&o%g zpgGs33RrbEmHptY&s@LI2n(C)T7Mp&WwOV=YS{00sd&MyOFh1tLdw4~@~b0`h54ky zHbh}1vdxTfvJQ&{=fYuOb#_wwba$U`@*xVy$Qs1PUf?Y)m@u_*d@1-KF%s))te#N67vDiwE&bQ~6X& zVrFIfO)T7+avM8J-yqcRkgp`(0gNLa20t4^q+DRtI(#U{G>;@O^+cC`H5FChB~crn zC4tYo#1E-NO4W4l1_fn(Z<_7i&DxYY$np=ZIXY?b9}wKBV9OMjr3=gGa_U&#(IcT1 z|16o#g@r4cQ@0+eyOLGytwccu9#rUD@~j`#;j{=FgkwjpCf)F#uSdS%?BEC>lcLoz z(sb$VN8`ZNc~aMrbu>0V#|T)#if zwR3&d*TnI+P@9tpoB0kXXGQP_uh&HF%EL z^&IFj1ncHzo-{*)%QFyj&t;@BebCnW#J$(`jE>FoB#nloW3MIhTV*vnYP_h^tG9)Gnw;zQRJtx>i*c6~Yy1i$-^BkcVrZ}-OTs1BF(1z>0H471M_QzBk_ z9YQ(0@gw{26g}|mO>LG33-#K+5clL8Ka1gQUJV%8DZvw2qH`^4N4WIN$~&n6RJ(I5 z+E;HbFdu&>so6|~nr}>0Ru%8*9hm29>7!VogbNsm?_VP=--})vG^gU&)?Y7{x496u zHX5|L@K!`x8oe|qlmumMHL$LIfja3@ZTU39D>3f#WOkP0_19lmrf%d|fwWI&P%enh z)S&1_k^s&uIm+#rhcxoi=m>lsPvN0A0gbCm8!>Iwg6rz*5>uK741u}kh-F3PZy&Sb z)sD<(iWhSfkCy(ihrLbg3S2WX4u!O4Oq1E% zeFe*)v5#xyd``&h=>49i^f3Y(u5H*Jp)%Yve~oWi(CdS$h}FReYNTI*1GmdbeAUT? zCqJj|SLd&ddH))V#n-8FF}@lp(SejJExxHH8H(Zy`m2npU@1*p@*rwk?j*oJ*$i9h!$259hw$S;`4{0BuF?G!HQvJTRr)#%gFYo6o!Ffj<(Tgoe-tSQ zX=RC>`YhXxcg~WP%uf85s*SwF8Ln18{l0fp$3AjpdWUwtx){bvCKb?K`bMU2i8;RZ zgY|GK&yyBR(=0_-4yjXalb}HF3JQuty6Cf-CGoMUw?xdwEQfV6Wp}S zgnRYc%pJi#aus3ROLjzq~VkqR+wwRU{bm7Ul&>E7J`avgu)FbctOhLJ z`qD(-pZrma{;Q&?b{7(er+#!%;2~eCAN2M2mK4lF3GJZm(k?8hyVQ6-OlfZEMs=F7 z(CYGsp)W=u%$wS^Xi+CpCd94yHp>DEJ?}}gp0a4{chX0~t=*sk`w4R!I1J{?Jh literal 0 HcmV?d00001