Skip to content

Commit

Permalink
moving functionality into MechanicsInverse and improving jitting
Browse files Browse the repository at this point in the history
  • Loading branch information
ralberd committed Nov 10, 2023
1 parent 6cfd43d commit f618c1b
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 693 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@
from optimism.FunctionSpace import compute_element_volumes
from optimism.FunctionSpace import compute_element_volumes_axisymmetric
from optimism.FunctionSpace import map_element_shape_grads
import jax

from jax import vmap

def construct_function_space_for_adjoint(coords, mesh, quadratureRule, mode2D='cartesian'):

shapeOnRef = Interpolants.compute_shapes(mesh.parentElement, quadratureRule.xigauss)

shapes = jax.vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values)
shapes = vmap(lambda elConns, elShape: elShape, (0, None))(mesh.conns, shapeOnRef.values)

shapeGrads = jax.vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients)
shapeGrads = vmap(map_element_shape_grads, (None, 0, None, None))(coords, mesh.conns, mesh.parentElement, shapeOnRef.gradients)

if mode2D == 'cartesian':
el_vols = compute_element_volumes
elif mode2D == 'axisymmetric':
el_vols = compute_element_volumes_axisymmetric
vols = jax.vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)
vols = vmap(el_vols, (None, 0, None, 0, None))(coords, mesh.conns, mesh.parentElement, shapes, quadratureRule.wgauss)

# unpack mesh and remake a mesh to make sure we get all the AD
mesh = Mesh.Mesh(coords=coords, conns=mesh.conns, simplexNodesOrdinals=mesh.simplexNodesOrdinals,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from collections import namedtuple

from optimism.JaxConfig import *
from optimism import Mechanics
from optimism import FunctionSpace
from optimism import Interpolants
from optimism.TensorMath import tensor_2D_to_3D

MechanicsInverseFunctions = namedtuple('MechanicsInverseFunctions',
IvsUpdateInverseFunctions = namedtuple('IvsUpdateInverseFunctions',
['ivs_update_jac_ivs_prev',
'ivs_update_jac_disp_vjp',
'ivs_update_jac_coords_vjp',
'nodal_forces_parameterized',
'residual_jac_ivs_prev_vjp',
'residual_jac_coords_vjp'])
'ivs_update_jac_coords_vjp'])

PathDependentResidualInverseFunctions = namedtuple('PathDependentResidualInverseFunctions',
['residual_jac_ivs_prev_vjp',
'residual_jac_coords_vjp'])

ResidualInverseFunctions = namedtuple('ResidualInverseFunctions',
['residual_jac_coords_vjp'])

def _compute_quadrature_point_field_gradient(u, shapeGrad):
dg = np.tensordot(u, shapeGrad, axes=[0,0])
Expand All @@ -33,44 +36,8 @@ def _compute_updated_internal_variables_gradient(dispGrads, states, dt, compute_
statesNew = vmap(compute_state_new, (0, 0, None))(dgQuadPointRavel, stQuadPointRavel, dt)
return statesNew.reshape(output_shape)

def _compute_strain_energy(functionSpace, coords, shapeGrads, vols,
UField, stateField, dt,
compute_energy_density, modify_element_gradient):
L = Mechanics.strain_energy_density_to_lagrangian_density(compute_energy_density)
return _integrate_over_block(functionSpace, coords, shapeGrads, vols,
UField, stateField, dt,
L, slice(None), modify_element_gradient)

def _integrate_over_block(functionSpace, coords, shapeGrads, vols,
U, stateVars, dt,
func, block, modify_element_gradient):
vals = _evaluate_on_block(functionSpace, coords, shapeGrads,
U, stateVars, dt,
func, block, modify_element_gradient)
return np.dot(vals.ravel(), vols[block].ravel())

def _evaluate_on_block(functionSpace, coords, shapeGrads,
U, stateVars, dt,
func, block, modify_element_gradient):
compute_elem_values = vmap(_evaluate_on_element, (None, None, 0, None, 0, 0, 0, None, None))

blockValues = compute_elem_values(U, coords, stateVars[block], dt,
functionSpace.shapes[block], shapeGrads[block], functionSpace.mesh.conns[block],
func, modify_element_gradient)
return blockValues

def _evaluate_on_element(U, coords, elemStates, dt,
elemShapes, elemShapeGrads, elemConn,
kernelFunc, modify_element_gradient):
elemVals = FunctionSpace.interpolate_to_element_points(U, elemShapes, elemConn)
elemGrads = _compute_element_field_gradient(U, elemShapeGrads, elemConn, modify_element_gradient)
elemXs = FunctionSpace.interpolate_to_element_points(coords, elemShapes, elemConn)
vmapArgs = 0, 0, 0, 0, None
fVals = vmap(kernelFunc, vmapArgs)(elemVals, elemGrads, elemStates, elemXs, dt)
return fVals


def create_mechanics_inverse_functions(functionSpace, createField, mode2D, materialModel, pressureProjectionDegree=None, dt=0.0):

def create_ivs_update_inverse_functions(functionSpace, mode2D, materialModel, pressureProjectionDegree=None, dt=0.0):
fs = functionSpace
shapeOnRef = Interpolants.compute_shapes(fs.mesh.parentElement, fs.quadratureRule.xigauss)

Expand All @@ -90,8 +57,7 @@ def compute_partial_ivs_update_partial_ivs_prev(U, stateVariables, dt=dt):
return _compute_updated_internal_variables_gradient(dispGrads, stateVariables, dt,\
update_gradient, grad_shape)

def compute_ivs_update_parameterized(U, stateVariables, coordinates, dt=dt):
coords = coordinates.reshape(fs.mesh.coords.shape)
def compute_ivs_update_parameterized(U, stateVariables, coords, dt=dt):
shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients)
dispGrads = _compute_field_gradient(shapeGrads, fs.mesh.conns, U, modify_element_gradient)
update_func = materialModel.compute_state_new
Expand All @@ -112,27 +78,26 @@ def compute_ivs_update(U, stateVariables, dt=dt):
compute_partial_ivs_update_partial_disp = jit(lambda x, ivs, av:
vjp(lambda z: compute_ivs_update(z, ivs), x)[1](av)[0])

def compute_strain_energy_parameterized(U, stateVariables, coordinates, dt=dt):
coords = coordinates.reshape(fs.mesh.coords.shape)
shapes = vmap(lambda elConns, elShape: elShape, (0, None))(fs.mesh.conns, shapeOnRef.values)
vols = vmap(FunctionSpace.compute_element_volumes, (None, 0, None, 0, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapes, fs.quadratureRule.wgauss)
shapeGrads = vmap(FunctionSpace.map_element_shape_grads, (None, 0, None, None))(coords, fs.mesh.conns, fs.mesh.parentElement, shapeOnRef.gradients)
return _compute_strain_energy(fs, coords, shapeGrads, vols, U, stateVariables, dt, materialModel.compute_energy_density, modify_element_gradient)
return IvsUpdateInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev),
compute_partial_ivs_update_partial_disp,
compute_partial_ivs_update_partial_coords
)

def compute_strain_energy_for_residual(Uu, p, stateVariables, coordinates, dt=dt):
U = createField(Uu, p)
return compute_strain_energy_parameterized(U, stateVariables, coordinates, dt)
def create_path_dependent_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_ivs_prev = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(compute_strain_energy_for_residual, 0)(u, q, z, x), iv)[1](vx)[0])
vjp(lambda z: grad(energyFunction, 0)(u, q, z, x), iv)[1](vx)[0])

compute_partial_residual_partial_coords = jit(lambda u, q, iv, x, vx:
vjp(lambda z: grad(compute_strain_energy_for_residual, 0)(u, q, iv, z), x)[1](vx)[0])
vjp(lambda z: grad(energyFunction, 0)(u, q, iv, z), x)[1](vx)[0])

return MechanicsInverseFunctions(jit(compute_partial_ivs_update_partial_ivs_prev),
compute_partial_ivs_update_partial_disp,
compute_partial_ivs_update_partial_coords,
jit(grad(compute_strain_energy_parameterized)),
compute_partial_residual_partial_ivs_prev,
return PathDependentResidualInverseFunctions(compute_partial_residual_partial_ivs_prev,
compute_partial_residual_partial_coords
)
)

def create_residual_inverse_functions(energyFunction):

compute_partial_residual_partial_coords = jit(lambda u, q, x, vx:
vjp(lambda z: grad(energyFunction, 0)(u, q, z), x)[1](vx)[0])

return ResidualInverseFunctions(compute_partial_residual_partial_coords)
Loading

0 comments on commit f618c1b

Please sign in to comment.