Source code for quadcoil.io.jax

from jax import custom_jvp
from functools import partial
from quadcoil import quadcoil
import jax.numpy as jnp
import warnings
import jax
# Names of the differentiable arguments of QUADCOIL.
# Any of these names found in the kwargs of gen_quadcoil_for_diff is
# dropped (with a warning) instead of being bound into the partials.
# The order matches the positional parameters of the ``quadcoil_for_diff``
# function generated by gen_quadcoil_for_diff.
QUADCOIL_DIFF_ARGS = [
    'plasma_dofs',
    'net_poloidal_current_amperes',
    'net_toroidal_current_amperes',
    'Bnormal_plasma',
    'plasma_coil_distance',
    'winding_dofs',
    'objective_weight',
    'constraint_value',
]

def gen_quadcoil_for_diff(**kwargs):
    '''
    Generate a ``quadcoil`` wrapper that takes only the differentiable
    arguments, together with its ``custom_jvp`` rule.

    The generated function takes, positionally and in this order:
    ``plasma_dofs``, ``net_poloidal_current_amperes``,
    ``net_toroidal_current_amperes``, ``Bnormal_plasma``,
    ``plasma_coil_distance``, ``winding_dofs``, ``objective_weight``,
    ``constraint_value``.

    Parameters
    ----------
    **kwargs
        Non-differentiable keyword arguments that are bound into every
        ``quadcoil`` call. Any key listed in ``QUADCOIL_DIFF_ARGS`` is
        dropped and a warning is issued.

    Returns
    -------
    quadcoil_values : functools.partial
        ``quadcoil`` with ``value_only=True`` and the filtered kwargs bound.
    quadcoil_for_diff : jax.custom_jvp
        Function of the eight differentiable arguments. Returns a dict that
        maps every key of the first ``quadcoil`` output to that entry's
        ``'value'``.
    '''
    # Copy the kwargs, leaving out the differentiable arguments.
    partial_kwargs = {}
    for key, value in kwargs.items():
        if key in QUADCOIL_DIFF_ARGS:
            warnings.warn(
                f"{key} found in kwargs. This is a differentiable "
                "argument and will be overridden."
            )
        else:
            partial_kwargs[key] = value

    # A partial of quadcoil taking only the differentiable arguments,
    # which does not output derivatives.
    quadcoil_values = partial(quadcoil, value_only=True, **partial_kwargs)
    # A partial of quadcoil taking only the differentiable arguments,
    # which preserves the full output.
    quadcoil_full = partial(quadcoil, value_only=False, **partial_kwargs)

    # A wrapper of quadcoil taking only the differentiable arguments,
    # which outputs the values in a simpler dict.
    @custom_jvp
    def quadcoil_for_diff(
        plasma_dofs,
        net_poloidal_current_amperes,
        net_toroidal_current_amperes,
        Bnormal_plasma,
        plasma_coil_distance,
        winding_dofs,
        objective_weight,
        constraint_value,
    ):
        # NOTE: the remaining outputs (qp, cp_mn, solve_results) are not
        # differentiable and are discarded here. That may change after
        # DESC ``_Objective`` can retain memory.
        out_dict, _, _, _ = quadcoil_values(
            plasma_dofs=plasma_dofs,
            net_poloidal_current_amperes=net_poloidal_current_amperes,
            net_toroidal_current_amperes=net_toroidal_current_amperes,
            Bnormal_plasma=Bnormal_plasma,
            plasma_coil_distance=plasma_coil_distance,
            winding_dofs=winding_dofs,
            objective_weight=objective_weight,
            constraint_value=constraint_value,
        )
        return {key_i: entry['value'] for key_i, entry in out_dict.items()}

    @quadcoil_for_diff.defjvp
    def quadcoil_for_diff_jvp(primals, tangents):
        '''
        The ``custom_jvp`` of ``quadcoil_for_diff``.

        Returns the primal output dict and a dict with the same keys that
        holds, per key, the sum over all arguments of
        ``sum(grad['df_d<arg>'] * <arg>_dot)``.
        '''
        (
            plasma_dofs,
            net_poloidal_current_amperes,
            net_toroidal_current_amperes,
            Bnormal_plasma,
            plasma_coil_distance,
            winding_dofs,
            objective_weight,
            constraint_value,
        ) = primals
        (
            plasma_dofs_dot,
            net_poloidal_current_amperes_dot,
            net_toroidal_current_amperes_dot,
            Bnormal_plasma_dot,
            plasma_coil_distance_dot,
            winding_dofs_dot,
            objective_weight_dot,
            constraint_value_dot,
        ) = tangents
        out_dict_full, _, _, _ = quadcoil_full(
            plasma_dofs=plasma_dofs,
            net_poloidal_current_amperes=net_poloidal_current_amperes,
            net_toroidal_current_amperes=net_toroidal_current_amperes,
            Bnormal_plasma=Bnormal_plasma,
            plasma_coil_distance=plasma_coil_distance,
            winding_dofs=winding_dofs,
            objective_weight=objective_weight,
            constraint_value=constraint_value,
        )
        # Tangents that are contracted directly against the matching
        # gradient entry, paired with that entry's key.
        plain_tangents = (
            ('df_dplasma_dofs', plasma_dofs_dot),
            ('df_dnet_poloidal_current_amperes',
             net_poloidal_current_amperes_dot),
            ('df_dnet_toroidal_current_amperes',
             net_toroidal_current_amperes_dot),
            ('df_dBnormal_plasma', Bnormal_plasma_dot),
            ('df_dplasma_coil_distance', plasma_coil_distance_dot),
            ('df_dwinding_dofs', winding_dofs_dot),
        )
        # Primal output, recreated from the full output.
        out_dict_primal = {}
        # Tangent output.
        out_dict_dot = {}
        for key_i, entry in out_dict_full.items():
            out_dict_primal[key_i] = entry['value']
            # The shape of the second layer differs depending on the
            # inputs, so every contribution is reduced to a scalar sum.
            # ``entry['grad']`` is only read when a contribution is
            # actually accumulated.
            jvp_i = 0
            for grad_key, tangent in plain_tangents:
                if tangent is not None:
                    jvp_i += jnp.sum(entry['grad'][grad_key] * tangent)
            # Converting an empty list/tuple to an array will produce a
            # NaN, hence the extra guards on the two sequence-like inputs.
            if (
                objective_weight_dot is not None
                and not jnp.isscalar(objective_weight_dot)
            ):
                jvp_i += jnp.sum(
                    entry['grad']['df_dobjective_weight']
                    * jnp.array(objective_weight_dot)
                )
            if (
                constraint_value_dot is not None
                and len(constraint_value_dot) > 0
            ):
                jvp_i += jnp.sum(
                    entry['grad']['df_dconstraint_value']
                    * jnp.array(constraint_value_dot)
                )
            out_dict_dot[key_i] = jvp_i
        # NOTE: the remaining outputs (qp, cp_mn, solve_results) are not
        # differentiable and are currently not returned. That may change
        # after DESC ``_Objective`` can retain memory.
        return out_dict_primal, out_dict_dot

    return quadcoil_values, quadcoil_for_diff