Source code for quadcoil.solver

import warnings
import jax.numpy as jnp
import optax
import optax.tree_utils as otu
from functools import partial
from jax import jit, vmap, grad, jacrev
import jax
from jax.lax import while_loop
from jax import config as config_jax
config_jax.update('jax_enable_x64', True)

# Batched least-squares helper: ``jnp.linalg.lstsq`` wrapped in ``vmap`` with
# its default axes, so every argument is mapped over its leading axis.
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably used by importers -- confirm before removing.
lstsq_vmap = vmap(jnp.linalg.lstsq)

# def wl_debug(cond_fun, body_fun, init_val):
#     val = init_val
#     iter_num_wl = 1
#     while cond_fun(val):
#         val = body_fun(val)
#     return val
def delta_normalized(x1, x2):
    r'''
    Element-wise relative difference between ``x1`` and ``x2``.

    Computes ``|x1 - x2| / max(|x1|, |x2|)``. Wherever both entries are
    zero the result is defined to be zero.

    Parameters
    ----------
    x1 : ndarray
        The first operand.
    x2 : ndarray
        The second operand. Must broadcast against ``x1``.

    Returns
    -------
    delta : ndarray
        The relative difference, with the broadcast shape of the inputs.
    '''
    diff = jnp.abs(x1 - x2)
    largest = jnp.maximum(jnp.abs(x1), jnp.abs(x2))
    nonzero = largest > 0
    # ``jnp.where`` evaluates (and differentiates) both branches. Dividing
    # by a denominator that is zero on the unselected entries would yield
    # 0/0 there and poison the gradient with NaN, so the division is done
    # with a dummy denominator of 1 on those entries.
    safe_largest = jnp.where(nonzero, largest, 1)
    return jnp.where(nonzero, diff / safe_largest, largest)
def run_opt_lbfgs(init_params, fun, maxiter, fstop, xstop, gtol, max_linesearch_steps, verbose):
    r'''
    Unconstrained minimization with ``optax.lbfgs`` and a zoom line search.

    Builds an L-BFGS optimizer whose line search is
    ``optax.scale_by_zoom_linesearch`` (initial guess strategy ``'one'``)
    and forwards everything to ``run_opt_optax``.

    Parameters
    ----------
    init_params : ndarray, shape (N,)
        The initial condition.
    fun : Callable
        The objective function.
    maxiter : int
        The maximum iteration number.
    fstop : float
        The objective function convergence rate tolerance.
    xstop : float
        The unknown convergence rate tolerance.
    gtol : float
        The gradient tolerance.
    max_linesearch_steps : int
        The maximum steps count in the LBFGS line search.
    verbose : int
        Output levels of detail.

    Returns
    -------
    x : ndarray, shape (N,)
        The optimum.
    f : float
        The objective at the optimum.
    grad : ndarray, shape (N,)
        The gradient at the optimum.
    count : int
        The iteration number.
    final_dx : float
        The rate of change of x at the optimum.
    final_du : float
        The rate of change of updates at the optimum.
    final_df : float
        The rate of change of f at the optimum.
    '''
    zoom_linesearch = optax.scale_by_zoom_linesearch(
        max_linesearch_steps=max_linesearch_steps,
        initial_guess_strategy='one',
    )
    lbfgs = optax.lbfgs(linesearch=zoom_linesearch)
    return run_opt_optax(
        init_params,
        fun,
        maxiter,
        fstop,
        xstop,
        gtol,
        opt=lbfgs,
        verbose=verbose,
    )
def run_opt_optax(init_params, fun, maxiter, fstop, xstop, gtol, opt, verbose, n_history=10):
    r'''
    A wrapper for performing unconstrained optimization using
    ``optax.base.GradientTransformationExtraArgs``.

    The iteration is a ``jax.lax.while_loop``. It always performs the
    first step, and afterwards keeps going only while **all** of the
    following hold:

    - the optimizer's ``count`` is below ``maxiter``,
    - the norm of the optimizer's stored ``grad`` exceeds ``gtol``,
    - the average decrease over the last ``n_history`` recorded
      objective values exceeds ``fstop``,
    - at least one of ``|dx| > xstop``, ``|du| > xstop`` or
      ``df > fstop`` holds.

    Parameters
    ----------
    init_params : ndarray, shape (N,)
        The initial condition.
    fun : Callable
        The objective function.
    maxiter : int
        The maximum iteration number.
    fstop : float
        The objective function convergence rate tolerance.
    xstop : float
        The unknown convergence rate tolerance.
    gtol : float
        The gradient tolerance.
    opt : optax.base.GradientTransformationExtraArgs
        The optimizer of choice. Its ``update`` is called with the extra
        arguments ``value``, ``grad`` and ``value_fn``, and its state
        must expose ``count`` and ``grad`` entries (read with
        ``optax.tree_utils.tree_get``).
    verbose : int
        Output levels of detail. ``>1`` prints the starting gradient
        norm, ``>2`` additionally prints the stopping criteria each time
        they are evaluated.
    n_history : int, optional, default=10
        The length of the objective value record used for the average
        improvement criterion.

    Returns
    -------
    x : ndarray, shape (N,)
        The optimum.
    f : float
        The objective value recorded by the last step.
        NOTE(review): this value is obtained at the parameters *before*
        the last update is applied, whereas ``x`` is after it -- confirm
        whether callers expect ``fun(x)`` here.
    grad : ndarray, shape (N,)
        The ``grad`` entry of the final optimizer state.
    count : int
        The ``count`` entry of the final optimizer state.
    final_dx : float
        The L2 norm of the last change in x.
    final_du : float
        The L2 norm of the last change in the update.
    final_df : float
        The absolute change between the last two recorded objective
        values.
    '''
    init_val = fun(init_params)
    # Carry is params, update, value, val_rec, dx, du, df, state
    init_carry = (
        init_params,                  # params
        jnp.zeros_like(init_params),  # update
        init_val,                     # value
        # val_rec: dummy history that decreases linearly (for fstop > 0),
        # so the average-improvement test does not stop the loop before a
        # real record has been accumulated.
        jnp.linspace(1000*fstop, 0, n_history) + init_val * 2,
        jnp.zeros_like(init_params),  # dx
        jnp.zeros_like(init_params),  # du
        # df: same dtype as what ``step`` emits, rather than an int that
        # relies on weak-type promotion inside ``while_loop``.
        jnp.zeros_like(init_val),
        opt.init(init_params),        # state
    )
    value_and_grad_fun = optax.value_and_grad_from_state(fun)
    if verbose > 1:
        # Only needed for the diagnostics below; skipped otherwise to
        # save one gradient evaluation.
        g0_norm = jnp.linalg.norm(grad(fun)(init_params))
        jax.debug.print('INNER: starting gradient L2 norm: {a}', a=g0_norm)

    def step(carry):
        params1, updates1, value1, val_rec, _, _, _, state1 = carry
        value2, grad2 = value_and_grad_fun(params1, state=state1)
        updates2, state2 = opt.update(
            grad2, state1, params1,
            value=value2, grad=grad2, value_fn=fun
        )
        params2 = optax.apply_updates(params1, updates2)
        return (
            params2,
            updates2,
            value2,
            # Drop the oldest record and append the newest.
            jnp.append(val_rec[1:], value2),
            jnp.abs(params2 - params1),
            jnp.abs(updates2 - updates1),
            jnp.abs(value2 - value1),
            state2,
        )

    def continuing_criterion(carry):
        _, _, value, val_rec, dx, du, df, state = carry
        iter_num = otu.tree_get(state, 'count')
        err = otu.tree_norm(otu.tree_get(state, 'grad'))
        dx_norm = jnp.linalg.norm(dx)
        du_norm = jnp.linalg.norm(du)
        avg_improvement = jnp.average(val_rec[:-1] - val_rec[1:])
        if verbose > 2:
            jax.debug.print(
                'INNER: L: {l}, dx: {dx}, du: {du}, df: {df}, \n'
                ' grad:{g}, grad/g0:{gnorm}, Average improvement: {adf}\n'
                ' Value record: {val_rec}\n'
                ' Stopping criteria: \n'
                '(iter_num < maxiter): {a}\n'
                '& (err > gtol) : {b}\n'
                '& (avg_improvement > fstop): {ff}\n'
                '& ((dx_norm > xstop) | (du_norm > xstop) | (df > fstop)): {c}, {d}, {e}\n'
                '(dx_norm > xstop): {dx_norm} > {xstop}\n'
                '(du_norm > xstop): {du_norm} > {xstop}\n'
                '(df > fstop): {df} > {fstop}\n'
                '',
                adf=avg_improvement,
                val_rec=val_rec,
                a=(iter_num < maxiter),
                b=(err > gtol),
                c=(dx_norm > xstop),
                d=(du_norm > xstop),
                e=(df > fstop),
                ff=(avg_improvement > fstop),
                l=value,
                dx=jnp.max(dx),
                du=jnp.max(du),
                g=err,
                gnorm=err/g0_norm,
                dx_norm=dx_norm,
                du_norm=du_norm,
                xstop=xstop,
                df=df,
                fstop=fstop,
            )
        # Always take the first step; afterwards every criterion must hold.
        return (iter_num == 0) | (
            (iter_num < maxiter)
            & (err > gtol)
            & (avg_improvement > fstop)
            & ((dx_norm > xstop) | (du_norm > xstop) | (df > fstop))
        )

    final_params, _, final_value, _, final_dx, final_du, final_df, final_state = while_loop(
        continuing_criterion, step, init_carry
    )
    return (
        final_params,
        final_value,
        otu.tree_get(final_state, 'grad'),
        otu.tree_get(final_state, 'count'),
        jnp.linalg.norm(final_dx),  # Changes in x
        jnp.linalg.norm(final_du),  # Changes in u
        final_df,                   # Changes in f
    )
# Thresholding function for g+.
# The original one is gplus_hard.
# Introducing soft thresholding may improve differentiation behavior.
def gplus_hard(x, mu, c, g_ineq):
    r'''
    Hard-thresholded inequality constraint
    :math:`g^+(x) = \max(g(x), -\mu/c)`.

    Parameters
    ----------
    x : ndarray
        The point at which ``g_ineq`` is evaluated.
    mu : ndarray
        The inequality multipliers. Must broadcast against ``g_ineq(x)``.
    c : float
        The penalty factor.
    g_ineq : Callable
        The inequality constraint function.

    Returns
    -------
    gplus : ndarray
        The element-wise maximum of ``g_ineq(x)`` and ``-mu/c``.
    '''
    return jnp.maximum(g_ineq(x), -mu/c)
def gplus_elu(x, mu, c, g_ineq, scale=1):
    r'''
    ELU-style soft version of the thresholded inequality constraint.

    With :math:`s = g(x) + \mu/c`, returns :math:`s - \mu/c` where
    :math:`s \geq 0` and :math:`(e^{\text{scale}\, s} - 1)/\text{scale} - \mu/c`
    where :math:`s < 0`.

    Parameters
    ----------
    x : ndarray
        The point at which ``g_ineq`` is evaluated.
    mu : ndarray
        The inequality multipliers. Must broadcast against ``g_ineq(x)``.
    c : float
        The penalty factor.
    g_ineq : Callable
        The inequality constraint function.
    scale : float, optional, default=1
        The sharpness of the exponential branch.

    Returns
    -------
    gplus : ndarray
        The soft-thresholded constraint values.
    '''
    shift = mu / c
    gval_shifted = g_ineq(x) + shift
    is_negative = gval_shifted < 0
    # ``jnp.where`` evaluates both branches. Feeding large positive values
    # into the exponential overflows to inf, and the zero cotangent of the
    # unselected branch times inf gives a NaN gradient. Clamp the argument
    # to 0 wherever the exponential branch is not selected.
    safe_negative = jnp.where(is_negative, gval_shifted, 0)
    return jnp.where(
        is_negative,
        # expm1 keeps precision when the argument is close to zero.
        jnp.expm1(scale * safe_negative) / scale,
        gval_shifted,
    ) - shift
def gplus_softplus(x, mu, c, g_ineq, scale=1):
    r'''
    Softplus version of the thresholded inequality constraint.

    With :math:`s = g(x) + \mu/c`, returns
    :math:`\log(1 + e^{\text{scale}\, s})/\text{scale} - \mu/c`.

    Parameters
    ----------
    x : ndarray
        The point at which ``g_ineq`` is evaluated.
    mu : ndarray
        The inequality multipliers. Must broadcast against ``g_ineq(x)``.
    c : float
        The penalty factor.
    g_ineq : Callable
        The inequality constraint function.
    scale : float, optional, default=1
        The sharpness of the softplus.

    Returns
    -------
    gplus : ndarray
        The soft-thresholded constraint values.
    '''
    shift = mu / c
    gval_shifted = g_ineq(x) + shift
    # log(1 + exp(z)) written as logaddexp(0, z): the naive form overflows
    # to inf once exp(z) exceeds the floating point range, although the
    # true value is approximately z.
    return jnp.logaddexp(0., scale * gval_shifted) / scale - shift
def solve_constrained(
        x_init,
        # x_unit_init,
        f_obj,
        # run_opt : Callable, optional, default=run_opt_lbfgs
        # The optimizer choice. Must be a wrapper with the
        # same signature as ``run_opt_lbfgs``.
        # run_opt=run_opt_lbfgs,
        # No constraints by default
        c_init=1.,
        c_growth_rate=1.1,
        lam_init=jnp.zeros(0),
        h_eq=lambda x:jnp.zeros(0),
        mu_init=jnp.zeros(0),
        g_ineq=lambda x:jnp.zeros(0),
        xstop_outer=1e-7, # convergence rate tolerance
        # gtol_outer=1e-7, # gradient tolerance
        ctol_outer=1e-7, # constraint tolerance, used in multiplier update
        fstop_inner=1e-7,
        xstop_inner=1e-7,
        gtol_inner=1e-7,
        fstop_inner_last=1e-7,
        xstop_inner_last=1e-7,
        gtol_inner_last=1e-7,
        maxiter_tot=10000,
        maxiter_inner=500,
        # # Uses jax.lax.scan instead of while_loop.
        # # Enables history and forward diff but disables
        # # convergence test.
        max_linesearch_steps=20,
        verbose=0,
        c_k_safe=1e15,
        gplus_mask=gplus_hard,
    ):
    r'''
    Solves the constrained optimization problem:

    .. math::

        \min_x f(x) \\
        \text{subject to } \\
        h(x) = 0, \\
        g(x) \leq 0 \\

    Using the augmented Lagrangian method in
    *Constrained Optimization and Lagrange Multiplier Methods* Chapter 3.
    Please refer to the chapter for notation.

    Each outer iteration minimizes the augmented Lagrangian with
    ``run_opt_lbfgs`` in rescaled unknowns ``x / x_unit``, then updates
    the multipliers, the penalty factor ``c`` and ``x_unit``. After the
    outer ``while_loop`` ends, one more outer iteration is run with the
    ``*_inner_last`` tolerances.

    Parameters
    ----------
    x_init : ndarray, shape (Nx,)
        The initial condition.
    f_obj : Callable
        The objective function.
    c_init : float, optional, default=1.
        The initial :math:`c` factor. Please see
        *Constrained Optimization and Lagrange Multiplier Methods* Chapter 3.
    c_growth_rate : float, optional, default=1.1
        The growth rate of the :math:`c` factor.
    lam_init : ndarray, shape (Nh), optional, default=jnp.zeros(0)
        The initial :math:`\lambda` multiplier for equality constraints.
        No constraints by default.
    h_eq : Callable, optional, default=lambda x:jnp.zeros(0)
        The equality constraint function. Must map ``x`` to an ``ndarray``
        with shape ``(Nh)``. No constraints by default.
    mu_init : ndarray, shape (Ng), optional, default=jnp.zeros(0)
        The initial :math:`\mu` multiplier for inequality constraints.
        No constraints by default.
    g_ineq : Callable, optional, default=lambda x:jnp.zeros(0)
        The inequality constraint function. Must map ``x`` to an ``ndarray``
        with shape ``(Ng)``. No constraints by default.
    xstop_outer : float, optional, default=1e-7
        (Traced) ``x`` convergence rate of the outer augmented Lagrangian
        loop. The loop keeps running while ``outer_dx >= xstop_outer``.
    ctol_outer : float, optional, default=1e-7
        (Traced) Tolerance of the constraints. Used both in the outer
        stopping criterion and in deciding whether ``c`` grows.
    fstop_inner : float, optional, default=1e-7
        (Traced) ``fstop`` passed to the inner L-BFGS solve.
    xstop_inner : float, optional, default=1e-7
        (Traced) ``xstop`` passed to the inner L-BFGS solve.
    gtol_inner : float, optional, default=1e-7
        (Traced) ``gtol`` passed to the inner L-BFGS solve.
    fstop_inner_last : float, optional, default=1e-7
        ``fstop`` of the extra inner solve performed after the outer loop.
    xstop_inner_last : float, optional, default=1e-7
        ``xstop`` of the extra inner solve performed after the outer loop.
    gtol_inner_last : float, optional, default=1e-7
        ``gtol`` of the extra inner solve performed after the outer loop.
    maxiter_tot : int, optional, default=10000
        (Static) Upper limit on ``tot_niter``, which accumulates the
        inner iteration counts of all outer iterations.
    maxiter_inner : int, optional, default=500
        (Static) The maximum of the inner iteration.
    max_linesearch_steps : int, optional, default=20
        (Static) The maximum steps count in the LBFGS line search.
    verbose : int, optional, default=0
        (Static) The verbosity. When >0, prints the initial constraint
        values. When >1, outputs outer iteration convergence info.
        Also forwarded to ``run_opt_lbfgs``.
    c_k_safe : float, optional, default=1e15
        Once ``c`` exceeds this value, constraint violation alone no
        longer keeps the outer loop running.
    gplus_mask : Callable, optional, default=gplus_hard
        The thresholding function for :math:`g^+`, called as
        ``gplus_mask(x, mu, c, g_ineq=g_ineq)``.

    Returns
    -------
    status : dict
        The end state of the iteration. Contains the following entries:

        .. code-block:: python

            {
                'tot_niter': int,  # Accumulated inner iteration count
                'outer_dx': float,  # L2 norm of the change in x over the last outer iteration
                'outer_df': float,  # Norm of the change in f over the last outer iteration
                'outer_dg': float,  # L2 norm of the change in g over the last outer iteration
                'outer_dh': float,  # L2 norm of the change in h over the last outer iteration
                'inner_fin_f': float,  # The value of f at the optimum
                'inner_fin_g': ndarray,  # The value of g at the optimum
                'inner_fin_h': ndarray,  # The value of h at the optimum
                'inner_fin_x': ndarray,  # The optimum
                'inner_fin_l_aug': float,  # Second output of the last run_opt_lbfgs call
                'inner_fin_grad_l_aug': float,  # L2 norm of the third output of the last run_opt_lbfgs call
                'inner_fin_c': float,  # The final value of c
                'inner_fin_lam': ndarray,  # The final value of lambda
                'inner_fin_mu': ndarray,  # The final value of mu
                'inner_fin_niter': int,  # Iteration count of the last inner solve
                'inner_fin_dx_scaled': float,  # Fifth output of the last run_opt_lbfgs call (in scaled x)
                'inner_fin_du': float,  # Sixth output of the last run_opt_lbfgs call
                'inner_fin_dl': float,  # Seventh output of the last run_opt_lbfgs call
                'x_unit': float,  # The x scale for a subsequent iteration
            }
    '''
    # Has shape n_cons_ineq
    # gplus = lambda x, mu, c: jnp.max(jnp.array([g_ineq(x), -mu/c]), axis=0)
    gplus = partial(gplus_mask, g_ineq=g_ineq)
    # Derivatives below are only used by the verbose>1 diagnostics.
    grad_f = grad(f_obj)
    grad_g = jacrev(g_ineq)
    grad_h = jacrev(h_eq)
    if verbose>0:
        jax.debug.print(
            'SOLVER INITIALIZED. \nginit = {g} \nviolating elements: {c}',
            g = g_ineq(x_init),
            c = jnp.sum(jnp.where(g_ineq(x_init) > 0, 1., 0))
        )

    # True when non-convergent.
    # @jit
    def outer_convergence_criterion(dict_in):
        # cond_fun of the outer while_loop: returns True to keep iterating.
        x_k = dict_in['inner_fin_x']
        # x_norm = jnp.linalg.norm(x_k)
        # lam_k = dict_in['inner_fin_lam']
        # mu_k = dict_in['inner_fin_mu']
        # grad_l = dict_in['outer_grad_l']
        outer_dx = dict_in['outer_dx']
        tot_niter = dict_in['tot_niter']
        g_k = dict_in['inner_fin_g']
        h_k = dict_in['inner_fin_h']
        c_k = dict_in['inner_fin_c']
        # outer_dgrad_l = dict_in['outer_dgrad_l']
        # outer_dg = dict_in['outer_dg']
        # outer_dh = dict_in['outer_dh']
        # f_k = dict_in['inner_fin_f']
        # This is the convergence condition (True when not converged yet)
        if verbose>1:
            jax.debug.print(
                'OUTER CONVERGENCE CRITERIA\n'\
                ' (tot_niter == 0): {x1}\n'\
                ' (tot_niter < maxiter_tot): {x2}\n'\
                ' (outer_dx >= xstop_outer): {x3}\n'\
                ' (jnp.any(g_k >= ctol_outer) | jnp.any(jnp.abs(h_k) >= ctol_outer)): {x4}\n'\
                ' (c_k <= c_k_safe): {x5}\n',
                x1 = (tot_niter == 0),
                x2 = (tot_niter < maxiter_tot),
                x3 = (outer_dx >= xstop_outer),
                x4 = (jnp.any(g_k >= ctol_outer) | jnp.any(jnp.abs(h_k) >= ctol_outer)),
                x5 = (c_k <= c_k_safe),
            )
        return(
            # Always run the first outer iteration.
            (tot_niter == 0) | (
                (tot_niter < maxiter_tot)
                # & (outer_dx >= xstop_outer * x_norm)
                & (
                    # Continue iteration when dx is significant
                    (outer_dx >= xstop_outer)
                    # Or when constraint violation is sufficiently strong,
                    # because sometimes the iteration terminates before
                    # c becomes large enough. However, when c_k exceeds
                    # our safe limit, to prevent endless outer iteration,
                    # disable the constraint checking.
                    | (
                        (jnp.any(g_k >= ctol_outer) | jnp.any(jnp.abs(h_k) >= ctol_outer))
                        & (c_k <= c_k_safe)
                    )
                )
            )
        )

    # Recursion
    # @jit
    def body_fun_augmented_lagrangian(
            dict_in,
            gtol_inner=gtol_inner,
            fstop_inner=fstop_inner,
            xstop_inner=xstop_inner
        ):
        # One outer iteration: inner solve, then multiplier/c/x_unit update.
        # The tolerance keywords default to the loose values and are
        # overridden for the final call after the while_loop.
        x_km1 = dict_in['inner_fin_x']
        c_k = dict_in['inner_fin_c']
        lam_k = dict_in['inner_fin_lam']
        mu_k = dict_in['inner_fin_mu']
        f_km1 = dict_in['inner_fin_f']
        g_km1 = dict_in['inner_fin_g']
        h_km1 = dict_in['inner_fin_h']
        x_unit = dict_in['x_unit']
        # normalizing x with the sln from the previous step is not great either
        # abs_x_km1 = jnp.abs(x_km1)
        # mode_scaling = jnp.where(abs_x_km1>1e-5, abs_x_km1, 1e-5)
        # x_unit = x_unit_in * mode_scaling
        # grad_l_val_km1 = dict_in['outer_grad_l']
        # Eq (10) on p160 of Constrained Optimization and Multiplier Method
        # The argument x is in scaled units; x*x_unit is the physical unknown.
        l_k = lambda x, x_unit=x_unit, mu_k=mu_k, c_k=c_k: (
            f_obj(x*x_unit)
            + lam_k@h_eq(x*x_unit)
            + mu_k@gplus(x*x_unit, mu_k, c_k)
            + c_k/2 * (
                jnp.sum(h_eq(x*x_unit)**2)
                + jnp.sum(gplus(x*x_unit, mu_k, c_k)**2)
            )
        )
        # Solving a stage of the problem
        x_k_raw, val_l_k, grad_l_k, niter_inner_k, dx_k, du_k, dL_k = run_opt_lbfgs(
            x_km1/x_unit,
            l_k,
            maxiter_inner,
            fstop_inner,
            xstop_inner,
            gtol_inner,
            max_linesearch_steps=max_linesearch_steps,
            verbose=verbose
        )
        x_k = x_k_raw*x_unit
        x_norm = jnp.linalg.norm(x_k)
        # The next x scale is |x_k|, falling back to 1 when x_k is zero.
        x_unit_new = jnp.where(x_norm!=0, x_norm, 1.)
        f_k = f_obj(x_k)
        g_k = g_ineq(x_k)
        h_k = h_eq(x_k)
        gp_k = gplus(x_k, mu_k, c_k)
        # ----- Updating c and the multipliers
        # c is kept unchanged when all constraints are satisfied to within
        # ctol_outer, or when the inner solve hit maxiter_inner;
        # otherwise c grows by c_growth_rate.
        # NOTE(review): an earlier comment here said the multipliers are
        # updated *only* in that case, but lam_k and mu_k below are updated
        # on every outer iteration regardless of this flag -- confirm that
        # this is intended.
        update_multiplier = (
            (
                # if all constraints are satisfied,
                jnp.all(g_k < ctol_outer)
                & jnp.all(jnp.abs(h_k) < ctol_outer)
            )
            # | (c_k >= c_k_safe)
            | (niter_inner_k >= maxiter_inner)
        )
        c_k_new = jnp.where(update_multiplier, c_k, c_k * c_growth_rate)
        lam_k = lam_k + c_k * h_k
        mu_k = mu_k + c_k * gp_k
        df = jnp.linalg.norm(f_km1 - f_k)
        dg = jnp.linalg.norm(g_km1 - g_k)
        dh = jnp.linalg.norm(h_km1 - h_k)
        if verbose>1:
            # NOTE(review): mu/lam printed here are the already-updated
            # multipliers, and the printed xstop_outer is scaled by |x_k|
            # whereas the stopping criterion uses the unscaled xstop_outer.
            jax.debug.print(
                'OUTER: \n'\
                ' Iteration: {tot_niter}/{maxiter_tot}\n'\
                ' f : {f}\n'\
                ' g : {gmin}, {gmax}\n'\
                ' g+ : {gpmin}, {gpmax}\n'\
                ' h : {hmin}, {hmax}\n'\
                ' |grad f|: {xx}\n'\
                ' |grad g|: {xg}\n'\
                ' |grad h|: {xh}\n'\
                ' mu : {mu1}, {mu2}\n'\
                ' dmu : {dmu1}, {dmu2}\n'\
                ' lam : {lam1}, {lam2}\n'\
                ' dlam : {dlam1}, {dlam2}\n'\
                ' Outer stopping criteria (False = satisfied)\n'\
                ' |x_k - x_km1| >= xstop_outer: {b}\n'\
                ' outer_dx = {outer_dx}\n'\
                ' outer_df = {outer_df}\n'\
                ' outer_dg = {outer_dg}\n'\
                ' outer_dh = {outer_dh}\n'\
                ' xstop_outer = {xstop_outer}\n'\
                # ' grad_l_val: {x}, d_grad_l_val: {dx}\n'\
                ' inner iter #: {z}\n'\
                ' c_k: {c_k}',
                f=f_k,
                gmin=_print_min_blank(g_k),
                gmax=_print_max_blank(g_k),
                gpmin=_print_min_blank(gp_k),
                gpmax=_print_max_blank(gp_k),
                hmin=_print_min_blank(h_k),
                hmax=_print_max_blank(h_k),
                c_k=c_k,
                mu1=_print_min_blank(mu_k),
                mu2=_print_max_blank(mu_k),
                lam1=_print_min_blank(lam_k),
                lam2=_print_max_blank(lam_k),
                dmu1=_print_min_blank(c_k * gp_k),
                dmu2=_print_max_blank(c_k * gp_k),
                dlam1=_print_min_blank(c_k * h_k),
                dlam2=_print_max_blank(c_k * h_k),
                xx=jnp.linalg.norm(grad_f(x_k)),
                xg=jnp.linalg.norm(grad_g(x_k)),
                xh=jnp.linalg.norm(grad_h(x_k)),
                z=niter_inner_k,
                tot_niter=dict_in['tot_niter']+niter_inner_k,
                maxiter_tot=maxiter_tot,
                outer_dx=jnp.linalg.norm(x_k - x_km1),
                outer_df=df,
                outer_dg=dg,
                outer_dh=dh,
                xstop_outer=xstop_outer * jnp.linalg.norm(x_k),
                b=(jnp.linalg.norm(x_k - x_km1) >= xstop_outer),
            )
        # Must have the same keys and structure as init_dict, because it is
        # the carry of the outer while_loop.
        dict_out = {
            # Accumulates inner iteration counts, not outer iterations.
            'tot_niter': dict_in['tot_niter']+niter_inner_k,
            'outer_dx': jnp.linalg.norm(x_k - x_km1),
            'outer_df': df,
            'outer_dg': dg,
            'outer_dh': dh,
            'inner_fin_f': f_k,
            'inner_fin_g': g_k,
            'inner_fin_h': h_k,
            'inner_fin_x': x_k,
            'inner_fin_l_aug': val_l_k,
            'inner_fin_grad_l_aug': jnp.linalg.norm(grad_l_k),
            'inner_fin_c': c_k_new,
            'inner_fin_lam': lam_k,
            'inner_fin_mu': mu_k,
            'inner_fin_niter': niter_inner_k,
            'inner_fin_dx_scaled': dx_k,
            'inner_fin_du': du_k,
            'inner_fin_dl': dL_k,
            # The scaling factor for the next iteration
            'x_unit': x_unit_new,
        }
        return(dict_out)

    init_dict = {
        'tot_niter': 0,
        # Changes in x between the kth and k-1th iteration
        'outer_dx': 0.,
        # Changes in f, g, h between the kth and k-1th iteration
        'outer_df': 0.,
        'outer_dg': 0.,
        'outer_dh': 0.,
        'inner_fin_f': f_obj(x_init),
        # Value of f, g, h after the kth iteration
        'inner_fin_g': g_ineq(x_init),
        'inner_fin_h': h_eq(x_init),
        'x_unit': 1.,
        'inner_fin_x': x_init,
        'inner_fin_l_aug': 0.,
        'inner_fin_grad_l_aug': 0.,
        'inner_fin_c': c_init,
        'inner_fin_lam': lam_init,
        'inner_fin_mu': mu_init,
        'inner_fin_niter': 0,
        'inner_fin_dx_scaled': 0.,
        'inner_fin_du': 0.,
        'inner_fin_dl': 0.,
    }
    # Apply a looser tolerance for most of the iteration
    result_dict = while_loop(
        cond_fun=outer_convergence_criterion,
        body_fun=body_fun_augmented_lagrangian,
        init_val=init_dict,
    )
    # Apply tight tolerance in the last iteration
    result_dict = body_fun_augmented_lagrangian(
        result_dict,
        gtol_inner=gtol_inner_last,
        fstop_inner=fstop_inner_last,
        xstop_inner=xstop_inner_last
    )
    return(result_dict)
def _print_min_blank(a):
    '''Return the smallest entry of ``a``, or NaN when ``a`` has no entries.'''
    if a.size > 0:
        return jnp.min(a)
    return jnp.nan


def _print_max_blank(a):
    '''Return the largest entry of ``a``, or NaN when ``a`` has no entries.'''
    if a.size > 0:
        return jnp.max(a)
    return jnp.nan