import jax.numpy as jnp
import numpy as np # Don't panic, it's for type checking
from jax import jit, custom_vjp
from jax.tree_util import tree_reduce
import jax.nn as jnn
import lineax as lx
def tree_len(pytree):
    r'''
    Count the total number of scalar entries across all leaves of a pytree.

    Each leaf is promoted with ``jnp.atleast_1d`` before its ``size`` is read,
    so a scalar leaf contributes 1.

    Parameters
    ----------
    pytree : pytree
        Any JAX pytree (nested lists/tuples/dicts of arrays or scalars).

    Returns
    -------
    int
        Sum of the element counts of all leaves; 0 for an empty pytree.
    '''
    def _add_leaf_size(running_total, leaf):
        # atleast_1d makes 0-d leaves count as one element.
        return running_total + jnp.atleast_1d(leaf).size

    return tree_reduce(_add_leaf_size, pytree, initializer=0)
def is_ndarray(arr, n=1):
    r'''
    Check whether ``arr`` is a NumPy or JAX array with exactly ``n`` dimensions.

    Parameters
    ----------
    arr : object
        The object to test.
    n : int, optional
        Required number of dimensions. Defaults to 1.

    Returns
    -------
    bool
        ``True`` if ``arr`` is an ``np.ndarray`` or ``jnp.ndarray`` and
        ``arr.ndim == n``; ``False`` otherwise.
    '''
    # Compare against the requested rank ``n`` (the default keeps the
    # 1-D check callers rely on).
    return isinstance(arr, (np.ndarray, jnp.ndarray)) and arr.ndim == n
def sin_or_cos(x, mode):
    r'''
    Choose, element by element, between the sine and cosine of ``x``.

    Wherever ``mode == 1`` the result is ``jnp.sin(x)``; at every other
    position (any other value of ``mode``) it is ``jnp.cos(x)``.
    Used in inverse Fourier Transforms.

    Parameters
    ----------
    x : ndarray
        The data (angles).
    mode : ndarray
        Selector; must broadcast against ``x``. A value of ``1`` selects sine.

    Returns
    -------
    ndarray
        Array of the broadcast shape of ``x`` and ``mode``.
    '''
    sines = jnp.sin(x)
    cosines = jnp.cos(x)
    use_sine = mode == 1
    return jnp.where(use_sine, sines, cosines)
@jit
def norm_helper(vec):
    r'''
    Return the Euclidean length of every vector in a field, and its reciprocal.

    Parameters
    ----------
    vec : ndarray, shape (Nx, Ny, ..., 3)
        The vector field; the last axis holds the components.

    Returns
    -------
    normN_prime_2d : ndarray, shape (Nx, Ny, ...)
        The vector field's length, :math:`|v|`.
    inv_normN_prime_2d : ndarray, shape (Nx, Ny, ...)
        :math:`1/|v|`. Zero-length vectors give ``inf`` (no guarding is done).
    '''
    # Used for the length of the non-unit WS normal |N|. Not to be confused
    # with normN (the plasma surface Jacobian) in REGCOIL.
    length = jnp.linalg.norm(vec, axis=-1)
    return length, 1 / length
@jit
def project_arr_coord(
    operator,
    unit1, unit2, unit3):
    r'''
    Project an array of vector fields on a 2d surface onto a local basis
    ``unit1, unit2, unit3``.

    Parameters
    ----------
    operator : ndarray, shape (n_phi, n_theta, 3, ...)
        An array of (n_phi, n_theta, 3) vector fields.
        ``operator.shape[:3]`` must be ``(n_phi, n_theta, 3)``.
        Otherwise the shape is flexible.
    unit1 : ndarray, shape (n_phi, n_theta, 3)
        Basis vector 1 where the vector field is sampled.
    unit2 : ndarray, shape (n_phi, n_theta, 3)
        Basis vector 2 where the vector field is sampled.
    unit3 : ndarray, shape (n_phi, n_theta, 3)
        Basis vector 3 where the vector field is sampled.

    Returns
    -------
    ndarray, shape (n_phi, n_theta, 3, ...)
        ``out[:, :, k, ...]`` is the dot product, taken over axis 2, of
        ``operator`` with basis vector ``k + 1``. The basis vectors are used
        as given (they are not normalized here).
    '''
    n_phi, n_theta = operator.shape[:2]
    trailing_shape = operator.shape[3:]
    # Collapse all trailing axes into one so every basis vector can be
    # broadcast against the component axis with a single [..., None].
    flat = operator.reshape((n_phi, n_theta, 3, -1))

    def _component(unit):
        # Dot product over the 3 Cartesian components, then restore the
        # caller's trailing axes.
        dotted = jnp.sum(unit[:, :, :, None] * flat, axis=2)
        return dotted.reshape((n_phi, n_theta) + tuple(trailing_shape))

    components = [_component(unit) for unit in (unit1, unit2, unit3)]
    return jnp.stack(components, axis=2)
@jit
def project_arr_cylindrical(
    gamma,
    operator,
):
    r'''
    Project a stack of vector fields onto the cylindrical basis
    :math:`(\hat{R}, \hat{\phi}, \hat{Z})` at a given set of points.

    Parameters
    ----------
    gamma : ndarray, shape (n_phi, n_theta, 3)
        Cartesian (x, y, z) location of the points where the field is sampled.
    operator : ndarray, shape (n_phi, n_theta, 3, ...)
        A stack of (n_phi, n_theta, 3) vector fields.
        ``operator.shape[:3]`` must be ``(n_phi, n_theta, 3)``.
        Otherwise the shape is flexible.

    Returns
    -------
    ndarray, shape (n_phi, n_theta, 3, ...)
        The R, phi and Z components, stacked along axis 2.

    Notes
    -----
    :math:`\hat{R} = (x, y, 0) / \sqrt{x^2 + y^2}`. Points on the z axis
    (``x == y == 0``) have no defined radial direction and give NaN there.
    '''
    # The cylindrical radius is the length of the (x, y) projection only.
    # Dividing by the full 3-D norm |(x, y, z)| would make R-hat (and hence
    # phi-hat) shorter than unit length wherever z != 0.
    xy = gamma[:, :, :-1]
    rho = jnp.linalg.norm(xy, axis=2)
    r_unit = jnp.zeros_like(gamma).at[:, :, :-1].set(xy / rho[:, :, None])
    # Z-hat is the constant (0, 0, 1) at every point.
    z_unit = jnp.zeros_like(gamma).at[:, :, -1].set(1)
    # Z-hat x R-hat = (-y, x, 0) / rho completes the right-handed basis.
    phi_unit = jnp.cross(z_unit, r_unit)
    return project_arr_coord(
        operator,
        unit1=r_unit,
        unit2=phi_unit,
        unit3=z_unit,
    )
def max_lse(x, epsilon, **kwargs):
    r'''
    Smooth approximation of ``max(x)`` using a scaled log-sum-exp.

    Computes :math:`\epsilon \log \sum_i \exp(x_i / \epsilon)`, which is
    always ``>= max(x)`` and approaches it as ``epsilon`` shrinks.

    Parameters
    ----------
    x : ndarray
        Values to reduce.
    epsilon : float
        Smoothing width (temperature).
    **kwargs
        Forwarded to ``jax.nn.logsumexp`` (e.g. ``axis``, ``keepdims``).
        Without them, the reduction runs over all elements.

    Returns
    -------
    ndarray
        The smoothed maximum.
    '''
    scaled = x / epsilon
    return epsilon * jnn.logsumexp(a=scaled, **kwargs)
def abs_lse(x, epsilon, **kwargs):
    r'''
    Smooth, elementwise approximation of ``|x|``.

    Takes the smoothed maximum (``max_lse``) of each pair ``(x, -x)``, i.e.
    :math:`\epsilon \log(e^{x/\epsilon} + e^{-x/\epsilon})`. This is always
    ``>= |x|`` and approaches it as ``epsilon`` shrinks.

    Parameters
    ----------
    x : array_like
        Input values. Python scalars are accepted.
    epsilon : float
        Smoothing width.
    **kwargs
        Forwarded to ``max_lse``. Must NOT contain ``axis``, because the
        reduction is fixed to the new trailing pair axis.

    Returns
    -------
    ndarray
        Same shape as ``x`` when no extra kwargs are given.
    '''
    # Pair each element with its negation on a new trailing axis. Using
    # axis=-1 (instead of x.ndim) is identical for arrays and also works for
    # Python scalars, which have no ``.ndim`` attribute.
    x_stacked = jnp.stack((x, -x), axis=-1)
    return max_lse(
        x_stacked, epsilon, axis=-1, **kwargs
    )
def linf_lse(x, epsilon, **kwargs):
    r'''
    Smooth approximation of the L-infinity norm ``max(|x|)``.

    First smooths ``|x|`` elementwise with ``abs_lse``, then takes a smoothed
    maximum with ``max_lse``. The result is always ``>= max(|x|)`` and
    approaches it as ``epsilon`` shrinks.

    Parameters
    ----------
    x : array_like
        Input values.
    epsilon : float
        Smoothing width used by both stages.
    **kwargs
        Forwarded to the outer ``max_lse`` reduction only (e.g. ``axis``,
        ``keepdims``). Without them, the max runs over all elements.

    Returns
    -------
    ndarray
        The smoothed L-infinity norm.
    '''
    # kwargs are deliberately not passed to abs_lse: it already fixes
    # axis=-1, so a caller's ``axis`` would raise a duplicate-keyword
    # TypeError, and ``keepdims`` would leave a stray length-1 axis.
    abs_x = abs_lse(x, epsilon)
    return max_lse(abs_x, epsilon, **kwargs)
# Linear solve with a hand-written VJP (registered with ``defvjp`` further
# down in this module). The backward rule replaces NaN/inf entries with zeros,
# which makes autodiff more robust to floating point error.
@custom_vjp
def safe_linear_solve(A, b):
    r'''
    Solve ``A x = b`` with lineax and return ``x``.

    Parameters
    ----------
    A : ndarray
        2-D matrix, wrapped in ``lx.MatrixLinearOperator``.
    b : ndarray
        Right-hand side (presumably 1-D with length ``A.shape[0]`` -- confirm
        against callers).

    Returns
    -------
    ndarray
        ``solution.value`` from ``lx.linear_solve``.

    Notes
    -----
    ``lx.AutoLinearSolver(well_posed=False)`` tells lineax the system may not
    be well-posed; which concrete solver it picks is decided by lineax (see
    its documentation).
    '''
    solution = lx.linear_solve(
        lx.MatrixLinearOperator(A),
        b,
        lx.AutoLinearSolver(well_posed=False),
    )
    return solution.value
def safe_linear_solve_fwd(A, b):
    r'''
    Forward rule for the custom VJP of ``safe_linear_solve``.

    Returns
    -------
    x : ndarray
        The primal solution ``safe_linear_solve(A, b)``.
    residuals : tuple
        ``(A, x)``, consumed by ``safe_linear_solve_bwd``. ``b`` is not
        saved because the backward rule only needs ``A`` and ``x``.
    '''
    solution = safe_linear_solve(A, b)
    residuals = (A, solution)
    return solution, residuals
def safe_linear_solve_bwd(res, g):
    r'''
    Backward rule for the custom VJP of ``safe_linear_solve``.

    Given ``x`` solving ``A x = b`` and the cotangent ``g`` of ``x``, it
    solves ``A^T v = g`` and returns ``(dL/dA, dL/db) = (-v x^T, v)``.
    Non-finite entries of ``g`` and ``v`` are replaced by zeros.

    Parameters
    ----------
    res : tuple
        ``(A, x)`` as saved by ``safe_linear_solve_fwd``.
    g : ndarray
        Cotangent of the output ``x``.

    Returns
    -------
    tuple
        ``(dA, db)``, the cotangents for ``A`` and ``b``.

    Notes
    -----
    NOTE(review): ``-v x^T`` is the exact VJP for square, invertible ``A``.
    For non-square or rank-deficient ``A`` (allowed by
    ``well_posed=False``), the least-squares derivative has extra terms that
    are not included here. Confirm this approximation is acceptable.
    '''
    A, x = res

    def _zero_nonfinite(arr):
        # Map NaN, +inf and -inf to 0 so one bad entry cannot poison the
        # gradient.
        return jnp.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)

    cotangent = _zero_nonfinite(g)
    transpose_op = lx.MatrixLinearOperator(A.T)
    adjoint_solver = lx.AutoLinearSolver(well_posed=False)
    v = _zero_nonfinite(
        lx.linear_solve(transpose_op, cotangent, adjoint_solver).value
    )
    grad_A = -jnp.outer(v, x)
    return grad_A, v
# Attach the forward/backward rules so reverse-mode autodiff through
# safe_linear_solve uses the NaN-cleaning backward pass.
safe_linear_solve.defvjp(
    fwd=safe_linear_solve_fwd,
    bwd=safe_linear_solve_bwd,
)