Source code for quadcoil.wrapper

import quadcoil.quantity
from quadcoil.quantity.quantity import _Quantity
import jax.numpy as jnp
from jax import jit
from functools import partial
from . import max_lse
import warnings


def _resolve_quadpoints(
    nfp,
    Bnormal_plasma,
    plasma_quadpoints_phi,
    plasma_quadpoints_theta,
    winding_quadpoints_phi,
    winding_quadpoints_theta,
    quadpoints_phi,
    quadpoints_theta,
    plasma_coil_distance,
    winding_dofs,
):
    """
    Fill in missing quadrature grids and check the winding-surface inputs.

    Any grid passed as ``None`` is replaced by a default:

    * plasma grid: taken from the shape of ``Bnormal_plasma`` when it is
      given (user-supplied plasma grids are then ignored with a warning),
      otherwise 32 points over ``[0, 1/nfp)`` in phi and 34 points over
      ``[0, 1)`` in theta;
    * winding grid: ``32*nfp`` points over ``[0, 1)`` in phi and 34 points
      over ``[0, 1)`` in theta;
    * evaluation grid: the first ``len(winding_quadpoints_phi)//nfp`` winding
      phi points, and the winding theta points.

    Returns
    -------
    tuple
        ``(plasma_quadpoints_phi, plasma_quadpoints_theta,
        winding_quadpoints_phi, winding_quadpoints_theta,
        quadpoints_phi, quadpoints_theta)``

    Raises
    ------
    TypeError
        If ``Bnormal_plasma`` is given but is not 2-dimensional.
    ValueError
        If ``plasma_coil_distance`` and ``winding_dofs`` are both ``None``
        or both provided.
    """
    # --- Plasma grid -------------------------------------------------------
    if Bnormal_plasma is None:
        if plasma_quadpoints_phi is None:
            plasma_quadpoints_phi = jnp.linspace(0, 1/nfp, 32, endpoint=False)
        if plasma_quadpoints_theta is None:
            plasma_quadpoints_theta = jnp.linspace(0, 1, 34, endpoint=False)
    else:
        if jnp.ndim(Bnormal_plasma) != 2:
            raise TypeError(
                f'Bnormal_plasma must be a 2D array-like, got shape {jnp.shape(Bnormal_plasma)}.'
            )
        plasma_grid_supplied = not (
            plasma_quadpoints_phi is None and plasma_quadpoints_theta is None
        )
        if plasma_grid_supplied:
            warnings.warn(
                'Bnormal_plasma provided, inputs for plasma_quadpoints_phi '
                'and plasma_quadpoints_theta will be ignored.'
            )
        # The field samples dictate the plasma grid resolution.
        n_phi = Bnormal_plasma.shape[0]
        n_theta = Bnormal_plasma.shape[1]
        plasma_quadpoints_phi = jnp.linspace(0, 1/nfp, n_phi, endpoint=False)
        plasma_quadpoints_theta = jnp.linspace(0, 1, n_theta, endpoint=False)

    # --- Winding grid ------------------------------------------------------
    if winding_quadpoints_phi is None:
        winding_quadpoints_phi = jnp.linspace(0, 1, 32*nfp, endpoint=False)
    if winding_quadpoints_theta is None:
        winding_quadpoints_theta = jnp.linspace(0, 1, 34, endpoint=False)

    # --- Evaluation grid, derived from the winding grid --------------------
    if quadpoints_phi is None:
        n_first_block = len(winding_quadpoints_phi)//nfp
        quadpoints_phi = winding_quadpoints_phi[:n_first_block]
    if quadpoints_theta is None:
        quadpoints_theta = winding_quadpoints_theta

    # --- Exactly one way of specifying the winding surface ------------------
    has_distance = plasma_coil_distance is not None
    has_dofs = winding_dofs is not None
    if not (has_distance or has_dofs):
        raise ValueError('At least one of plasma_coil_distance and winding_dofs must be provided.')
    if has_distance and has_dofs:
        raise ValueError('Only one of plasma_coil_distance and winding_dofs can be provided.')

    return (
        plasma_quadpoints_phi, plasma_quadpoints_theta,
        winding_quadpoints_phi, winding_quadpoints_theta,
        quadpoints_phi, quadpoints_theta,
    )

def get_quantity(func_name: str):
    r'''
    Takes a string as input and returns the object with the
    same name in ``quadcoil.quantity``. Raises an error if
    no attribute with that name can be found, or if the attribute
    is not a ``_Quantity`` instance. Used to parse ``str`` in
    ``quadcoil.quadcoil``.

    Parameters
    ----------
    func_name : str
        Name of the quantity to find.

    Returns
    -------
    _Quantity
        The ``_Quantity`` instance with the same name in ``quadcoil.quantity``.

    Raises
    ------
    ValueError
        If ``quadcoil.quantity`` has no attribute named ``func_name``, or
        if that attribute is not an instance of ``_Quantity``.
    '''
    if not hasattr(quadcoil.quantity, func_name):
        raise ValueError(f'\'{func_name}\' not found in quadcoil.quantity.')
    func = getattr(quadcoil.quantity, func_name)
    if not isinstance(func, _Quantity):
        # Report the offending object's type, as the message promises,
        # rather than its string representation.
        raise ValueError(
            f'\'{func_name}\' exists in quadcoil.quantity but is '
            'not properly implemented as an instance of _Quantity. '
            f'Instead, it\'s of type: {type(func)}'
        )
    return func
def merge_callables(
    callables,
    merge_constraints=False,
    smoothing=None,
    smoothing_params=None
):
    r'''
    Merge a tuple of ``callable``s into one that takes 2 arguments
    (``qp`` and ``dofs``) by flattening and concatenating their
    outputs into a 1D ``array``. Used to construct constraints.

    Parameters
    ----------
    callables : tuple of callables
        The callables to merge. Each is called as ``fn(qp, dofs)``.
        ``None`` entries are skipped.
    merge_constraints : bool, optional, default=False
        When ``True``, all outputs are reduced to a single value with
        ``max_lse`` instead of being concatenated.
    smoothing : str, optional, default=None
        Smoothing mode. Must be ``'approx'`` when ``merge_constraints``
        is ``True``.
    smoothing_params : dict, optional, default=None
        Smoothing parameters. ``smoothing_params['lse_epsilon']`` is
        passed to ``max_lse`` when ``merge_constraints`` is ``True``.

    Returns
    -------
    callable
        A jitted callable ``merged_fn(qp, dofs)`` that returns a 1D ``array``.
        The array is empty when no callable produces an output.

    Raises
    ------
    AttributeError
        Raised when the returned callable is traced with
        ``merge_constraints=True`` and ``smoothing != 'approx'``.
    '''
    @partial(jit, static_argnums=(2,))
    def merged_fn(
        qp,
        dofs,
        callables=callables,
        merge_constraints=merge_constraints,
        smoothing=smoothing,
        smoothing_params=smoothing_params
    ):
        # Evaluate every non-None callable; promote scalars to 1D and
        # flatten arrays so that everything can be joined along axis 0.
        outputs = [
            jnp.atleast_1d(fn(qp, dofs)).ravel()
            for fn in callables
            if fn is not None
        ]
        if not outputs:
            return jnp.zeros(0)
        if not merge_constraints:
            return jnp.concatenate(outputs, axis=0)
        # Option to merge multiple inequality constraints into one.
        if smoothing != 'approx':
            raise AttributeError(
                'Merging constraints is only available '
                'when using smoothing==\'approx\'.'
            )
        # Concatenate rather than stack: the flattened outputs generally
        # have different lengths, which cannot be stacked into one array.
        # NOTE(review): assumes max_lse reduces over all elements of its
        # input, so the result matches stacking when lengths agree - confirm.
        return jnp.array([
            max_lse(
                jnp.concatenate(outputs, axis=0),
                smoothing_params['lse_epsilon'],
            ),
        ])
    return merged_fn
def _add_quantity(name, unit, use_case, smoothing, smoothing_params):
    '''
    Finds a quantity from quadcoil.quantity, unpacks and scales it.
    Also checks compatibility.

    Parameters
    ----------
    name : str
        The name of the quantity to find.
    unit : scalar or None
        The unit of the quantity. When ``None``, the unit is the value
        of the quantity evaluated with ``{'phi': jnp.zeros(qp.ndofs)}``.
    use_case : str
        The current type of use case (``'f'``, ``'=='``, ``'<='`` or ``'>='``).
    smoothing : str
        Smoothing mode.
    smoothing_params
        Smoothing parameters.

    Returns
    -------
    val_scaled : Callable(qp, dofs)
        The scaled implementation of the quantity, with the unit applied.
    g_ineq_list_scaled, h_eq_list_scaled : List[Callable]
        The lists of inequality and equality constraints attached to the
        quantity. Each list is empty when the quantity provides none.
    unit_callable : Callable(qp: QuadcoilParams)
        The unit as a Callable, in case the scaling factor of the
        quantity needs to be used later and the unit was left as ``None``.
        The only place where this is currently used is constraint value scaling.
    scaled_slack_dofs_out : dict{str: Callable(qp, dofs)}
        The auxiliary variables' init functions with the unit applied.
        Empty when the quantity has no auxiliary variables.

    Raises
    ------
    ValueError
        If ``use_case`` is not listed in the quantity's ``compatibility``.
    TypeError
        If ``unit`` is neither ``None`` nor a scalar.
    '''
    quantity = get_quantity(name)
    scaled_f_impl = quantity.scaled_f_impl(smoothing=smoothing, smoothing_params=smoothing_params)
    scaled_g_ineq_impl = quantity.scaled_g_ineq_impl(smoothing=smoothing)
    scaled_h_eq_impl = quantity.scaled_h_eq_impl(smoothing=smoothing)
    scaled_slack_dofs = quantity.scaled_slack_dofs_init(smoothing=smoothing)
    compatibility = quantity.compatibility

    # Checking compatibility
    if use_case not in compatibility:
        if use_case == 'f':
            raise ValueError(f'{name} cannot be used as an objective term.')
        elif use_case in ['<=', '==', '>=']:
            raise ValueError(f'{name} cannot be used in a {use_case} constraint.')
        else:
            raise ValueError(f'{use_case} is not a valid type of constraint.')

    # Perform scaling.
    # When the unit of a quantity is left blank,
    # automatically scale that quantity by its value
    # with only net poloidal/toroidal currents.
    # To accommodate this with the shortest amount of code,
    # we make unit a callable regardless of whether it's a scalar or None.
    if unit is None:
        c0_impl = quantity.__call__
        unit_callable = lambda qp: c0_impl(qp, {'phi': jnp.zeros(qp.ndofs)})
    elif jnp.isscalar(unit):
        unit_callable = lambda qp: jnp.abs(unit)
    else:
        raise TypeError(
            f'Unit for {name} has incorrect type. The supported '
            f'types are scalar and None. The provided value is a {type(unit)}.'
        )

    # Apply units (now known) to a function with signature
    # Callable(qp: QuadcoilParams, dofs: dict, unit: float).
    # fun and unit_callable are bound as defaults to avoid late binding.
    def apply_unit(fun, unit_callable=unit_callable):
        def with_unit(qp, dofs, fun=fun, unit_callable=unit_callable):
            return fun(qp, dofs, unit=unit_callable(qp))
        return with_unit

    # Scaling value
    val_scaled = apply_unit(scaled_f_impl)

    # Scaling auxiliary variables' initial values: substitute the presently
    # known unit into every auxiliary variable's init function.
    scaled_slack_dofs_out = {}
    if scaled_slack_dofs is not None:
        scaled_slack_dofs_out = {
            key: apply_unit(scaled_slack_dofs[key])
            for key in scaled_slack_dofs
        }

    # Scaling g and h
    g_ineq_list_scaled = (
        [apply_unit(scaled_g_ineq_impl)] if scaled_g_ineq_impl is not None else []
    )
    h_eq_list_scaled = (
        [apply_unit(scaled_h_eq_impl)] if scaled_h_eq_impl is not None else []
    )
    return (
        val_scaled,
        g_ineq_list_scaled,
        h_eq_list_scaled,
        unit_callable,
        scaled_slack_dofs_out
    )


def _parse_objectives(
    objective_name,
    objective_unit,
    objective_weight,
    smoothing,
    smoothing_params,
):
    r'''
    Parses a tuple of ``str`` quantity names (or a single ``str`` for
    one objective only), an array of weights, and a tuple of units into
    a ``callable`` that outputs the weighted sum of the corresponding
    quantities.

    Parameters
    ----------
    objective_name : str or tuple of str
        The name of the quantities to combine into an objective function.
        The corresponding functions must all return scalars.
    objective_unit : float or tuple of float, optional, default=None
        The normalization factor of each objective term.
        If set to ``None`` or a ``tuple`` with ``None``,
        then the corresponding objective will be normalized with its value
        when the current is uniform
        (or in other words, :math:`\Phi_{sv}=0`).
    objective_weight : float or array of float, optional, default=1
        The weight(s) of each objective term. When ``objective_name`` is a
        single ``str``, the weight is replaced by 1.
    smoothing : str
        Smoothing mode.
    smoothing_params
        Smoothing parameters.

    Returns
    -------
    f_tot : callable(QuadcoilParams, dofs)
        The jitted, weighted objective function.
    g_list : List[Callable]
    h_list : List[Callable]
        Inequality/equality constraints attached to the objective terms.
    scaled_slack_dofs : dict{str: Callable}
        The auxiliary variables' init functions collected from all terms.

    Raises
    ------
    ValueError
        If ``objective_weight`` does not have one entry per objective, or
        ``objective_unit`` has fewer entries than there are objectives.
    '''
    if isinstance(objective_name, str):
        objective_name = (objective_name,)
        objective_unit = (objective_unit,)
        # A single objective term always carries unit weight.
        objective_weight = jnp.array([1.,])
    elif objective_unit is None:
        # A bare None means "normalize every term automatically".
        objective_unit = (None,) * len(objective_name)
    # Extra trailing units were silently ignored before and still are;
    # only a shortage (previously an IndexError mid-loop) is rejected.
    if (
        len(objective_name) != len(objective_weight)
        or len(objective_unit) < len(objective_name)
    ):
        raise ValueError('objective, objective_weight and objective_unit must have the same length.')
    scaled_slack_dofs = {}
    f_list = []
    g_list = []
    h_list = []
    for i, name_i in enumerate(objective_name):
        (
            val_scaled,
            g_ineq_list_scaled,
            h_eq_list_scaled,
            _,
            scaled_slack_dofs_i,
        ) = _add_quantity(
            name=name_i,
            unit=objective_unit[i],
            use_case='f',
            smoothing=smoothing,
            smoothing_params=smoothing_params,
        )
        scaled_slack_dofs = scaled_slack_dofs | scaled_slack_dofs_i
        f_list.append(val_scaled)
        g_list = g_list + g_ineq_list_scaled
        h_list = h_list + h_eq_list_scaled

    def f_tot(
            qp, dofs,
            f_list=f_list,
            objective_weight=objective_weight
        ):
        out = 0
        for i, f_i in enumerate(f_list):
            out = out + f_i(qp, dofs) * objective_weight[i]
        return out
    return jit(f_tot), g_list, h_list, scaled_slack_dofs


def _parse_constraints(
    constraint_name,
    constraint_type,
    constraint_unit,
    constraint_value,
    smoothing,
    smoothing_params,
):
    r'''
    Parses a series of tuples and arrays specifying the quantities,
    types (``'>=', '<=', '=='``), units and thresholds of the constraints
    into lists of callables for the solver:

        min f(x) subject to h(x) = 0, g(x) <= 0

    Parameters
    ----------
    constraint_name : tuple of str
        A tuple of quantity names. The corresponding quantity can be
        either a scalar or a vector field (``ndarray``).
    constraint_type : tuple of str
        A tuple of strings. Must consist of ``'>=', '<=', '=='`` only.
    constraint_unit : array/tuple of float, may contain None
        A tuple of float/ints giving the constraints' order of magnitude.
        If a corresponding element is None, will normalize by the
        value of the quantity when the poloidal/toroidal current is uniform.
    constraint_value : array(float)
        An array of constraint thresholds.
    smoothing : str
        Smoothing mode.
    smoothing_params
        Smoothing parameters. ``smoothing_params['lse_epsilon']`` is read
        when ``smoothing == 'approx'``.

    Returns
    -------
    g_list : List[Callable]
    h_list : List[Callable]
        Lists of ``Callable(qp, dofs)`` for the inequality/equality
        constraints. Inequality returns will be greater than 0 when the
        constraints are violated.
    scaled_slack_dofs : dict{str: Callable}
        The auxiliary variables' init functions collected from all
        constraints.

    Raises
    ------
    ValueError
        If the four constraint inputs differ in length, or a constraint
        type is not one of ``'>=', '<=', '=='``.
    '''
    # First, we parse the constraints from strings into functions.
    n_cons_total = len(constraint_name)
    # Detecting input shape issues
    if (
        n_cons_total != len(constraint_type)
        or n_cons_total != len(constraint_unit)
        or n_cons_total != len(constraint_value)
    ):
        raise ValueError(
            'constraint_name, constraint_type, constraint_unit, '
            'and constraint_value must have the same length.'
        )
    # Lists of callables that map (QuadcoilParams, dofs)
    # to arrays or scalars
    # that are =0 or <=0 when the constraint is satisfied.
    g_ineq_list = []
    h_eq_list = []
    scaled_slack_dofs = {}
    for name_i, cons_type_i, unit_i, cons_val_i in zip(
        constraint_name, constraint_type, constraint_unit, constraint_value
    ):
        (
            cons_func_i_scaled,
            aux_g_ineq_i_scaled,
            aux_h_eq_i_scaled,
            unit_callable_i,
            scaled_slack_dofs_i
        ) = _add_quantity(
            name=name_i,
            unit=unit_i,
            use_case=cons_type_i,
            smoothing=smoothing,
            smoothing_params=smoothing_params,
        )
        scaled_slack_dofs = scaled_slack_dofs | scaled_slack_dofs_i
        g_ineq_list = g_ineq_list + aux_g_ineq_i_scaled
        h_eq_list = h_eq_list + aux_h_eq_i_scaled
        # Flipping the sign of >= constraints.
        sign_i = -1 if cons_type_i == '>=' else 1

        # Per-iteration values are bound as default arguments so that each
        # callable keeps its own copy (avoids the late-binding closure issue).
        def cons_func_centered_i(
                qp, dofs,
                cons_func_i_scaled=cons_func_i_scaled,
                cons_val_i=cons_val_i,
                unit_callable_i=unit_callable_i,
                sign=sign_i,
                smoothing=smoothing,
            ):
            # Scaling and centering constraints.
            # Now all constraints are <= constraints.
            out_array = sign * (cons_func_i_scaled(qp, dofs) - cons_val_i/unit_callable_i(qp))
            # If smoothing mode is approx, then process the resulting array with logsumexp
            # to replace pointwise constraints with a single constraint.
            if smoothing == 'approx' and (not jnp.isscalar(out_array)):
                out_array = max_lse(out_array, smoothing_params['lse_epsilon'])
            return out_array

        # Sorting the function into h or g.
        if cons_type_i == '==':
            h_eq_list.append(cons_func_centered_i)
        elif cons_type_i in ['<=', '>=']:
            g_ineq_list.append(cons_func_centered_i)
        else:
            raise ValueError('Constraint type can only be \"<=\", \">=\", or \"==\"')
    # The lists are returned unmerged; combining them into single
    # callables is left to the caller.
    return g_ineq_list, h_eq_list, scaled_slack_dofs