Skip to content

Commit

Permalink
adding gradient check for target curve difference L2 norm objective
Browse files Browse the repository at this point in the history
  • Loading branch information
ralberd committed Jan 3, 2024
1 parent 1bb5080 commit 056a392
Showing 1 changed file with 113 additions and 1 deletion.
114 changes: 113 additions & 1 deletion optimism/inverse/test/test_J2Plastic_gradient_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,106 @@ def energy_function_all_dofs(U, ivs, coords):

return gradient.ravel()

def test_gradient_with_adjoint_solve(self):
def compute_L2_norm_difference(self, uSteps, ivsSteps, bcsSteps, coordinates, nodal_forces):
index = (self.mesh.nodeSets['left'], 1) # arbitrarily choosing left side nodeset for reaction force

numerator = 0.0
denominator= 0.0
for i in range(0, len(self.targetSteps)):
step = self.targetSteps[i]
Uu = uSteps[step]
bc_data = bcsSteps[step]
ivs = ivsSteps[step]

U = self.dofManager.create_field(Uu, bc_data)
force = np.sum(np.array(nodal_forces(U, ivs, coordinates).at[index].get()))

diff = force - self.targetForces[i]
numerator += diff*diff
denominator += self.targetForces[i]*self.targetForces[i]

return np.sqrt(numerator/denominator)

def target_curve_objective(self, storedState, parameters):
parameters = parameters.reshape(self.mesh.coords.shape)

def energy_function_all_dofs(U, ivs, coords):
adjoint_func_space = AdjointFunctionSpace.construct_function_space_for_adjoint(coords, self.mesh, self.quadRule)
mech_funcs = Mechanics.create_mechanics_functions(adjoint_func_space, mode2D='plane strain',materialModel=self.materialModel)
return mech_funcs.compute_strain_energy(U, ivs)

nodal_forces = jax.jit(jax.grad(energy_function_all_dofs, argnums=0))

uSteps = np.stack([storedState[i][0] for i in range(0, self.steps+1)], axis=0)
ivsSteps = np.stack([storedState[i][1].state_data for i in range(0, self.steps+1)], axis=0)
bcsSteps = np.stack([storedState[i][1].bc_data for i in range(0, self.steps+1)], axis=0)

return self.compute_L2_norm_difference(uSteps, ivsSteps, bcsSteps, parameters, nodal_forces)

def target_curve_gradient(self, storedState, parameters):

def energy_function_coords(Uu, p, ivs_prev, coords):
adjoint_func_space = AdjointFunctionSpace.construct_function_space_for_adjoint(coords, self.mesh, self.quadRule)
mech_funcs = Mechanics.create_mechanics_functions(adjoint_func_space, mode2D='plane strain', materialModel=self.materialModel)
U = self.dofManager.create_field(Uu, p.bc_data)
return mech_funcs.compute_strain_energy(U, ivs_prev)

def energy_function_all_dofs(U, ivs, coords):
adjoint_func_space = AdjointFunctionSpace.construct_function_space_for_adjoint(coords, self.mesh, self.quadRule)
mech_funcs = Mechanics.create_mechanics_functions(adjoint_func_space, mode2D='plane strain',materialModel=self.materialModel)
return mech_funcs.compute_strain_energy(U, ivs)

nodal_forces = jax.jit(jax.grad(energy_function_all_dofs, argnums=0))

functionSpace = FunctionSpace.construct_function_space(self.mesh, self.quadRule)
ivsUpdateInverseFuncs = MechanicsInverse.create_ivs_update_inverse_functions(functionSpace,
"plane strain",
self.materialModel)

residualInverseFuncs = MechanicsInverse.create_path_dependent_residual_inverse_functions(energy_function_coords)

parameters = parameters.reshape(self.mesh.coords.shape)

# derivatives of F
uSteps = np.stack([storedState[i][0] for i in range(0, self.steps+1)], axis=0)
ivsSteps = np.stack([storedState[i][1].state_data for i in range(0, self.steps+1)], axis=0)
bcsSteps = np.stack([storedState[i][1].bc_data for i in range(0, self.steps+1)], axis=0)
df_du, df_dc, gradient = jax.grad(self.compute_L2_norm_difference, (0, 1, 3))(uSteps, ivsSteps, bcsSteps, parameters, nodal_forces)

mu = np.zeros(ivsSteps[0].shape)
adjointLoad = np.zeros(uSteps[0].shape)

for step in reversed(range(1, self.steps+1)):
Uu = uSteps[step]
p = storedState[step][1]
p_prev = storedState[step-1][1]
ivs_prev = ivsSteps[step-1]

dc_dcn = ivsUpdateInverseFuncs.ivs_update_jac_ivs_prev(self.dofManager.create_field(Uu, p.bc_data), ivs_prev)

mu += df_dc[step]
adjointLoad -= df_du[step]
adjointLoad -= self.dofManager.get_unknown_values(ivsUpdateInverseFuncs.ivs_update_jac_disp_vjp(self.dofManager.create_field(Uu, p.bc_data), ivs_prev, mu))

n = self.dofManager.get_unknown_size()
p_objective = Objective.Params(bc_data=p.bc_data, state_data=p_prev.state_data) # remember R is a function of ivs_prev
self.objective.p = p_objective
self.objective.update_precond(Uu) # update preconditioner for use in cg (will converge in 1 iteration as long as the preconditioner is not approximate)
dRdu = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.hessian_vec(Uu, V)))
dRdu_decomp = linalg.LinearOperator((n, n), lambda V: onp.asarray(self.objective.apply_precond(V)))
adjointVector = linalg.cg(dRdu, onp.array(adjointLoad, copy=False), tol=1e-10, atol=0.0, M=dRdu_decomp)[0]

gradient += residualInverseFuncs.residual_jac_coords_vjp(Uu, p, ivs_prev, parameters, adjointVector)
gradient += ivsUpdateInverseFuncs.ivs_update_jac_coords_vjp(self.dofManager.create_field(Uu, p.bc_data), ivs_prev, parameters, mu)

mu = np.einsum('ijk,ijkn->ijn', mu, dc_dcn)
mu += residualInverseFuncs.residual_jac_ivs_prev_vjp(Uu, p, ivs_prev, parameters, adjointVector)

adjointLoad = np.zeros(storedState[0][0].shape)

return gradient.ravel()

def test_total_work_gradient_with_adjoint_solve(self):
self.compute_objective_function = self.total_work_objective
self.compute_gradient = self.total_work_gradient

Expand All @@ -194,6 +293,19 @@ def test_gradient_with_adjoint_solve(self):
errors = self.compute_finite_difference_errors(initialStepSize, numSteps, self.initialMesh.coords.ravel())
self.assertFiniteDifferenceCheckHasVShape(errors)

def test_target_curve_gradient_with_adjoint_solve(self):
self.compute_objective_function = self.target_curve_objective
self.compute_gradient = self.target_curve_gradient

self.targetSteps = [1, 2]
self.targetForces = [4.5, 5.5] # [4.542013626078756, 5.7673988583067555] actual forces

initialStepSize = 1e-6
numSteps = 4

errors = self.compute_finite_difference_errors(initialStepSize, numSteps, self.initialMesh.coords.ravel())
self.assertFiniteDifferenceCheckHasVShape(errors)



if __name__ == '__main__':
Expand Down

0 comments on commit 056a392

Please sign in to comment.