Skip to content

Commit

Permalink
Merge pull request #69 from sandialabs/ralberd/path_dependent_adjoint
Browse files Browse the repository at this point in the history
ralberd/path dependent adjoint
  • Loading branch information
ralberd authored Jan 17, 2024
2 parents ea24f41 + ebb5fb8 commit a2f8efe
Show file tree
Hide file tree
Showing 7 changed files with 1,000 additions and 1 deletion.
2 changes: 1 addition & 1 deletion optimism/Interpolants.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_lobatto_nodes_1d(degree):
p = onp.polynomial.Legendre.basis(degree, domain=[0.0, 1.0])
dp = p.deriv()
xInterior = dp.roots()
xn = np.hstack(([0.0], xInterior, [1.0]))
xn = np.hstack((np.array([0.0]), xInterior, np.array([1.0])))
return xn


Expand Down
30 changes: 30 additions & 0 deletions optimism/inverse/AdjointFunctionSpace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from optimism import FunctionSpace
from optimism import Interpolants
from optimism import Mesh
from optimism.FunctionSpace import compute_element_volumes
from optimism.FunctionSpace import compute_element_volumes_axisymmetric
from optimism.FunctionSpace import map_element_shape_grads
from jax import vmap

def construct_function_space_for_adjoint(coords, mesh, quadratureRule, mode2D='cartesian'):

shapeOnRef = Interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss)

shapes = vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values)

shapeGrads = vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients)

if mode2D == 'cartesian':
el_vols = compute_element_volumes
isAxisymmetric = False
elif mode2D == 'axisymmetric':
el_vols = compute_element_volumes_axisymmetric
isAxisymmetric = True
vols = vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)

# unpack mesh and remake a mesh to make sure we get all the AD
mesh = Mesh.Mesh(coords=coords, conns=mesh.conns, simplexNodesOrdinals=mesh.simplexNodesOrdinals,
parentElement=mesh.parentElement, parentElement1d=mesh.parentElement1d, blocks=mesh.blocks,
nodeSets=mesh.nodeSets, sideSets=mesh.sideSets)

return FunctionSpace.FunctionSpace(shapes, vols, shapeGrads, mesh, quadratureRule, isAxisymmetric)
99 changes: 99 additions & 0 deletions optimism/inverse/MechanicsInverse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from collections import namedtuple

from optimism.JaxConfig import *
from optimism import FunctionSpace
from optimism import Interpolants
from optimism.TensorMath import tensor_2D_to_3D

IvsUpdateInverseFunctions = namedtuple('IvsUpdateInverseFunctions',
['ivs_update_jac_ivs_prev',
'ivs_update_jac_disp_vjp',
'ivs_update_jac_coords_vjp'])

PathDependentResidualInverseFunctions = namedtuple('PathDependentResidualInverseFunctions',
['residual_jac_ivs_prev_vjp',
'residual_jac_coords_vjp'])

ResidualInverseFunctions = namedtuple('ResidualInverseFunctions',
['residual_jac_coords_vjp'])

def _compute_element_field_gradient(U, elemShapeGrads, elemConnectivity, modify_element_gradient):
elemNodalDisps = U[elemConnectivity]
elemGrads = vmap(FunctionSpace.compute_quadrature_point_field_gradient, (None, 0))(elemNodalDisps, elemShapeGrads)
elemGrads = modify_element_gradient(elemGrads)
return elemGrads

def _compute_field_gradient(shapeGrads, conns, nodalField, modify_element_gradient):
return vmap(_compute_element_field_gradient, (None,0,0,None))(nodalField, shapeGrads, conns, modify_element_gradient)

def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_state_new, output_shape):
dgQuadPointRavel = dispGrads.reshape(dispGrads.shape[0]*dispGrads.shape[1],*dispGrads.shape[2:])
stQuadPointRavel = states.reshape(states.shape[0]*states.shape[1],*states.shape[2:])
statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt)
return statesNew.reshape(output_shape)


def create_ivs_update_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None):
fs = functionSpace
shapeOnRef = Interpolants.compute_shapes(fs.mesh.parentElement, fs.quadratureRule.xigauss)

if mode2D == 'plane strain':
grad_2D_to_3D = vmap(tensor_2D_to_3D)
elif mode2D == 'axisymmetric':
raise NotImplementedError

modify_element_gradient = grad_2D_to_3D
if pressureProjectionDegree is not None:
raise NotImplementedError

def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=0.0):
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_gradient = jacfwd(materialModel.compute_state_new, argnums=1)
grad_shape = stateVariables.shape + (stateVariables.shape[2],)
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_gradient, grad_shape)

def compute_ivs_update_parameterized(U, stateVariables, coords, dt=0.0):
shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients)
dispGrads = _compute_field_gradient(shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_func = materialModel.compute_state_new
output_shape = stateVariables.shape
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_func, output_shape)

compute_partial_ivs_update_partial_coords = jit(lambda u, ivs, x, av, dt=0.0:
vjp(lambda z: compute_ivs_update_parameterized(u, ivs, z, dt), x)[1](av)[0])

def compute_ivs_update(U, stateVariables, dt=0.0):
dispGrads = _compute_field_gradient(fs.shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_func = materialModel.compute_state_new
output_shape = stateVariables.shape
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_func, output_shape)

compute_partial_ivs_update_partial_disp = jit(lambda x, ivs, av, dt=0.0:
vjp(lambda z: compute_ivs_update(z, ivs, dt), x)[1](av)[0])

return IvsUpdateInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev),
compute_partial_ivs_update_partial_disp,
compute_partial_ivs_update_partial_coords
)

def create_path_dependent_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_ivs_prev = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, z, x), iv)[1](vx)[0])

compute_partial_residual_partial_coords = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, iv, z), x)[1](vx)[0])

return PathDependentResidualInverseFunctions(compute_partial_residual_partial_ivs_prev,
compute_partial_residual_partial_coords
)

def create_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_coords = jit(lambda u, q, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, z), x)[1](vx)[0])

return ResidualInverseFunctions(compute_partial_residual_partial_coords)
67 changes: 67 additions & 0 deletions optimism/inverse/test/FiniteDifferenceFixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from optimism.test.MeshFixture import MeshFixture
from collections import namedtuple
import numpy as onp

class FiniteDifferenceFixture(MeshFixture):
def assertFiniteDifferenceCheckHasVShape(self, errors, tolerance=1e-6):
minError = min(errors)
self.assertLess(minError, tolerance, "Smallest finite difference error not less than tolerance.")
self.assertLess(minError, errors[0], "Finite difference error does not decrease from initial step size.")
self.assertLess(minError, errors[-1], "Finite difference error does not increase after reaching minimum. Try more finite difference steps.")

def build_direction_vector(self, numDesignVars, seed=123):

onp.random.seed(seed)
directionVector = onp.random.uniform(-1.0, 1.0, numDesignVars)
normVector = directionVector / onp.linalg.norm(directionVector)

return onp.array(normVector)

def compute_finite_difference_error(self, stepSize, initialParameters):
storedState = self.forward_solve(initialParameters)
originalObjective = self.compute_objective_function(storedState, initialParameters)
gradient = self.compute_gradient(storedState, initialParameters)

directionVector = self.build_direction_vector(initialParameters.shape[0])
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

perturbedParameters = initialParameters + stepSize * directionVector
storedState = self.forward_solve(perturbedParameters)
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

fd_value = (perturbedObjective - originalObjective) / stepSize
error = abs(directionalDerivative - fd_value)

return error

def compute_finite_difference_errors(self, stepSize, steps, initialParameters, printOutput=True):
storedState = self.forward_solve(initialParameters)
originalObjective = self.compute_objective_function(storedState, initialParameters)
gradient = self.compute_gradient(storedState, initialParameters)

directionVector = self.build_direction_vector(initialParameters.shape[0])
directionalDerivative = onp.tensordot(directionVector, gradient, axes=1)

fd_values = []
errors = []
for i in range(0, steps):
perturbedParameters = initialParameters + stepSize * directionVector
storedState = self.forward_solve(perturbedParameters)
perturbedObjective = self.compute_objective_function(storedState, perturbedParameters)

fd_value = (perturbedObjective - originalObjective) / stepSize
fd_values.append(fd_value)

error = abs(directionalDerivative - fd_value)
errors.append(error)

stepSize *= 1e-1

if printOutput:
print("\n grad'*dir | FD approx | abs error")
print("--------------------------------------------------------------------------------")
for i in range(0, steps):
print(f" {directionalDerivative} | {fd_values[i]} | {errors[i]}")

return errors

Loading

0 comments on commit a2f8efe

Please sign in to comment.