import jax.numpy as jnp
import jax
# import matplotlib.pyplot as plt
from jax import jit, lax, vmap
from jax.lax import scan
from functools import partial
from .surface import dof_to_rz_op, SurfaceRZFourierJAX
from .math_utils import safe_linear_solve
@partial(jit, static_argnames=['nfp', 'stellsym', 'mpol', 'ntor'])
def fit_surfacerzfourier(
    phi_grid, theta_grid,
    r_fit, z_fit,
    nfp: int, stellsym: bool,
    mpol: int = 5, ntor: int = 5,
    lam_tikhonov=0.,
    custom_weight=None,):
    """Least-squares fit of an RZ-Fourier surface to sampled (R, Z) points.

    Solves the (optionally weighted, Tikhonov-regularized) normal equations

        (A^T A + lam_tikhonov * diag(m_2_n_2)) x = A^T b

    where ``A`` (from ``dof_to_rz_op``) maps surface dofs to (R, Z) at the
    given (phi, theta) points and ``b`` stacks ``r_fit`` and ``z_fit``.

    Parameters
    ----------
    phi_grid, theta_grid : arrays
        Toroidal / poloidal coordinates of the sample points, passed unchanged
        to ``dof_to_rz_op``. Presumably shaped like ``r_fit`` -- confirm
        against ``dof_to_rz_op``.
    r_fit, z_fit : 2D arrays, shape (nphi, ntheta)
        Cylindrical R and Z of the points to fit.
    nfp, stellsym, mpol, ntor :
        Surface symmetry and resolution; static under ``jit``.
    lam_tikhonov : float or scalar array
        Weight of the Tikhonov penalty ``diag(m_2_n_2)`` that damps higher
        harmonics. It is a traced (non-static) argument, so it may itself be a
        traced value from an enclosing ``jit``, and changing it does not
        trigger recompilation.
    custom_weight : array of shape (nphi, ntheta) or None
        Per-point weight multiplied into both ``A`` and ``b``. A weight of 0
        removes the point from the fit.

    Returns
    -------
    The solution returned by ``safe_linear_solve`` (the fitted dofs).

    Raises
    ------
    ValueError
        If ``custom_weight.shape`` differs from ``A_lstsq.shape[:2]``.
    """
    A_lstsq, m_2_n_2 = dof_to_rz_op(
        theta_grid=theta_grid,
        phi_grid=phi_grid,
        nfp=nfp,
        stellsym=stellsym,
        mpol=mpol,
        ntor=ntor,
    )
    # A_lstsq: [nphi, ntheta, 2 (r, z), ndof]; depends only on the grids.
    # b_lstsq: [nphi, ntheta, 2 (r, z)]; differentiable w.r.t. r_fit, z_fit.
    b_lstsq = jnp.stack([r_fit, z_fit], axis=2)
    if custom_weight is not None:
        if custom_weight.shape != A_lstsq.shape[:2]:
            raise ValueError(
                'custom_weight must have the shape '
                + str(A_lstsq.shape[:2])
                + ', but it has shape '
                + str(custom_weight.shape)
            )
        # Scaling a row of A and b by w scales its squared residual by w**2.
        A_lstsq = A_lstsq * custom_weight[:, :, None, None]
        b_lstsq = b_lstsq * custom_weight[:, :, None]
    A_lstsq = A_lstsq.reshape(-1, A_lstsq.shape[-1])
    b_lstsq = b_lstsq.reshape(-1)
    # Tikhonov regularization penalizing higher harmonics.
    lam = lam_tikhonov * jnp.diag(m_2_n_2)
    return safe_linear_solve(
        A=A_lstsq.T.dot(A_lstsq) + lam,
        b=A_lstsq.T.dot(b_lstsq),
    )
def gen_rot_matrix(theta):
    """Return the 3x3 matrix that rotates (x, y, z) by ``theta`` about the z axis.

    ``theta`` is in radians. The matrix acts on column vectors (``R @ v``);
    for an array of row-vector points use ``points @ R.T``.
    """
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array([
        [cos_t, -sin_t, 0],
        [sin_t, cos_t, 0],
        [0, 0, 1],
    ])
def gen_winding_surface_offset(
    plasma_gamma, d_expand,
    nfp, stellsym,
    unitnormal=None,
    mpol=10, ntor=10,
):
    """Fit an RZ-Fourier surface to a uniform normal offset of the plasma boundary.

    Every plasma boundary point is moved by ``d_expand`` along its unit
    normal, and an RZ-Fourier surface is least-squares fitted to the moved
    points without regularization. The poloidal angle assigned to a moved
    point is the uniform poloidal index ``j / ntheta`` of the point it came
    from. Its toroidal angle is its actual cylindrical angle.

    This is the simple generator with few intermediate quantities. It works
    best for large offset distances, where the center of each offset cross
    section (unweighted average of the points' R-Z coordinates) lies inside
    that cross section. The raw offset may self-intersect.

    Parameters
    ----------
    plasma_gamma : array of shape (nphi, ntheta, 3)
        Cartesian plasma boundary points. When ``unitnormal`` is None, the
        rows are assumed to cover one field period: the first cross section
        rotated by 2*pi/nfp is treated as the one following the last row.
    d_expand : float
        Offset distance along the unit normal.
    nfp : int
        Number of field periods.
    stellsym : bool
        If True, only the first half of the phi rows is fitted.
    unitnormal : array of shape (nphi, ntheta, 3), optional
        Unit normals of ``plasma_gamma``. If None they are approximated by
        finite differences. Passing them in makes the Jacobian larger and
        lengthens compile time.
    mpol, ntor : int
        Fourier resolution of the fitted surface.

    Returns
    -------
    The fitted dofs returned by ``fit_surfacerzfourier``.
    """
    if unitnormal is None:
        # Finite-difference normal. The cross section after the last phi row
        # is the first row rotated by one field period.
        rotation_matrix = gen_rot_matrix(2 * jnp.pi / nfp)
        xyz_rotated = plasma_gamma[0, :, :] @ rotation_matrix.T
        plasma_gamma_phi_rolled = jnp.append(
            plasma_gamma[1:, :, :], xyz_rotated[None, :, :], axis=0)
        delta_phi = plasma_gamma_phi_rolled - plasma_gamma
        delta_theta = jnp.roll(plasma_gamma, 1, axis=1) - plasma_gamma
        normal_approx = jnp.cross(delta_theta, delta_phi)
        unitnormal = normal_approx / jnp.linalg.norm(normal_approx, axis=-1)[:, :, None]

    if stellsym:
        # With stellarator symmetry, half a field period is enough for the fit.
        len_phi = plasma_gamma.shape[0] // 2
        plasma_gamma_expand = plasma_gamma[:len_phi] + unitnormal[:len_phi] * d_expand
    else:
        plasma_gamma_expand = plasma_gamma + unitnormal * d_expand

    # The uniform offset (may self-intersect). Tested to be differentiable.
    x_expand = plasma_gamma_expand[:, :, 0]
    y_expand = plasma_gamma_expand[:, :, 1]
    r_expand = jnp.sqrt(y_expand**2 + x_expand**2)
    z_expand = plasma_gamma_expand[:, :, 2]
    # Cylindrical toroidal angle in units of full turns, in [-0.5, 0.5].
    phi_expand = jnp.arctan2(y_expand, x_expand) / jnp.pi / 2
    # Poloidal angle: uniform index of the originating plasma point, repeated
    # for every phi row.
    theta_1d = jnp.linspace(0, 1, plasma_gamma.shape[1], endpoint=False)
    theta_expand = jnp.broadcast_to(theta_1d[None, :], phi_expand.shape)

    return fit_surfacerzfourier(
        mpol=mpol,
        ntor=ntor,
        theta_grid=theta_expand,
        phi_grid=phi_expand,
        r_fit=r_expand,
        z_fit=z_expand,
        nfp=nfp, stellsym=stellsym,
        lam_tikhonov=0.,
    )
def _get_line_intersection(p0, p1, p2, p3):
    """Test whether segment p0-p1 intersects segment p2-p3.

    Each point is an (x, y) pair indexed as ``p[0]``, ``p[1]``. Parallel
    (including collinear) segments, whose determinant is exactly zero, are
    reported as non-intersecting. Touching at an endpoint counts as an
    intersection, because the parameter bounds are inclusive.

    Returns a boolean JAX scalar.
    """
    # Direction vectors of the two segments.
    d1x = p1[0] - p0[0]
    d1y = p1[1] - p0[1]
    d2x = p3[0] - p2[0]
    d2y = p3[1] - p2[1]
    # Offset from the start of the second segment to the start of the first.
    ox = p0[0] - p2[0]
    oy = p0[1] - p2[1]
    det = d1x * d2y - d2x * d1y
    is_regular = det != 0
    # Avoid dividing by zero for parallel segments; masked by is_regular below.
    inv_det = jnp.where(is_regular, 1 / det, 0)
    # s parameterizes the second segment, t the first one.
    s = (d1x * oy - d1y * ox) * inv_det
    t = (d2x * oy - d2y * ox) * inv_det
    s_inside = (s >= 0) & (s <= 1)
    t_inside = (t >= 0) & (t <= 1)
    return s_inside & t_inside & is_regular
# @jit
def _polygon_self_intersection(r_pol, z_pol):
    """Mark the vertices of a closed planar polygon that lie in self-intersecting loops.

    The polygon has vertices (r_pol[i], z_pol[i]), i = 0..N-1, and edge i runs
    from vertex i to vertex (i+1) % N. Every edge is tested against every
    non-adjacent edge with ``_get_line_intersection``, which is O(N^2) work.
    Walking along the edges, a sign flag that starts at +1 is flipped at
    every edge that has at least one intersection. The edges between two
    crossing edges (a loop) therefore carry -1.

    Assumes the first vertex lies on the part of the polygon to keep.

    NOTE(review): heuristic. An edge crossing several other edges flips the
    flag only once, and a crossing exactly at a vertex is reported by both
    adjacent edges.

    Parameters
    ----------
    r_pol, z_pol : 1D arrays of length N
        Vertex coordinates.

    Returns
    -------
    weight : integer array of shape (N,)
        0 for vertices inside a self-intersecting loop, 1 otherwise.
    """
    len_theta = len(r_pol)
    # Vertices, shape: (len_theta, 2) with columns (r, z).
    p0_in = jnp.concatenate([r_pol[:, None], z_pol[:, None]], axis=-1)
    # End vertex of each edge; the roll closes the polygon.
    p1_in = jnp.roll(p0_in, -1, axis=0)
    # Edges, shape: (len_theta, 4) with columns (r0, z0, r1, z1).
    p0p1 = jnp.concatenate([p0_in, p1_in], axis=1)
    # Outer scan: one step per edge a. The carry is
    # (index of edge a, current sign flag +-1); x_outer is edge a.
    def outer_loop(carry_outer, x_outer):
        index_a, weight = carry_outer
        def inner_loop(carry, x):
            # carry is (index of p0p1, ([r0, z0, r1, z1]), index of p2p3)
            # x is ([r2, z2, r3, z3])
            index_a, r0z0r1z1, index_b = carry
            # Is the second line segment the same as, or adjacent to
            # (index one greater or lower than), the current segment?
            # Adjacent edges share a vertex, so _get_line_intersection
            # would report a false positive; those pairs are disregarded.
            is_overlapping = (
                (index_a == index_b)
                | (index_a == (index_b+1)%len_theta)
                | ((index_a+1)%len_theta == index_b)
            )
            p0_i = r0z0r1z1[:2]
            p1_i = r0z0r1z1[2:]
            p2_i = x[:2]
            p3_i = x[2:]
            # True when intersection is present
            is_intersect = _get_line_intersection(p0_i, p1_i, p2_i, p3_i)
            return (index_a, r0z0r1z1, index_b+1), is_intersect & jnp.logical_not(is_overlapping)
        # Test edge a against every edge b = 0..len_theta-1.
        _, is_intersect = scan(inner_loop, (index_a, x_outer, 0), p0p1)
        has_self_intersection = jnp.any(is_intersect)
        # Flip the sign of the flag if edge a crosses any other edge.
        # We assume that the first edge is not self-intersecting (weight=1).
        # This marks every edge from the first crossing edge up to (but not
        # including) the next crossing edge with weight = -1.
        weight = jnp.where(has_self_intersection, -weight, weight)
        # if has_self_intersection:
        #     plt.plot(p0p1[:, 0], p0p1[:, 1])
        #     plt.scatter(p0p1[:, 0], p0p1[:, 1], alpha = is_intersect)
        #     plt.show()
        # Emit the flag after the flip as the per-edge output.
        return (index_a + 1, weight), weight
    _, weight = scan(outer_loop, (0, 1), p0p1)
    # Convert the per-edge flag +-1 to 0 and 1.
    weight = (weight+1)/2
    # Shift by one vertex: weight[i] becomes 0 iff edge i-1 was flagged.
    # This replaces, rather than combines with, the per-edge flags. For a
    # loop bounded by crossing edges k and m, edges k..m-1 are flagged, so
    # vertices k+1..m (those between the two crossing points) get weight 0.
    weight = jnp.where(jnp.roll(weight, 1)==0, 0, 1)
    return(weight)
def _graham_scan(r_expand, z_expand):
    """Mark the points of a planar point set that are convex-hull vertices.

    A Graham scan written with fixed-size buffers only, so it works under
    ``jax.jit`` and ``jax.vmap`` (no data-dependent array shapes). Among
    points at the same polar angle from the pivot, only the farthest is kept,
    so collinear points in the middle of a hull edge are not marked.

    Parameters
    ----------
    r_expand, z_expand : 1D arrays of equal length N
        Point coordinates (e.g. one poloidal cross section).

    Returns
    -------
    is_on_hull : bool array of shape (N,)
        True for points that are vertices of the convex hull.
    """
    n_pts = r_expand.shape[0]
    point_ids = jnp.arange(n_pts)

    # Step 1: pivot P0 = lowest z, ties broken by smallest r.
    min_idx = jnp.lexsort((r_expand, z_expand))[0]
    delta_r = r_expand - r_expand[min_idx]
    delta_z = z_expand - z_expand[min_idx]

    # Step 2: polar angle and squared distance relative to P0. All other
    # points have delta_z >= 0, so their angles lie in [0, pi]. The pivot's
    # own angle (arctan2(0, 0) = 0) is forced to -1 so that it always sorts
    # first. Otherwise a point with the same z further right (angle 0,
    # larger distance) would win the tie and P0 would be dropped.
    angles = jnp.where(point_ids == min_idx, -1.0, jnp.arctan2(delta_z, delta_r))
    dists = delta_r**2 + delta_z**2

    # Step 3: sort by angle, farthest first among equal angles.
    sort_idx = jnp.lexsort((-dists, angles))
    angles_sorted = angles[sort_idx]

    # Step 4: keep the first (farthest) point of every run of equal angles,
    # compacted to the front of a length-N buffer. Only the first n_kept
    # entries of kept_idx are meaningful.
    is_new = jnp.concatenate(
        [jnp.ones(1, dtype=bool), angles_sorted[1:] != angles_sorted[:-1]])
    n_kept = jnp.sum(is_new, dtype=jnp.int32)
    slot = jnp.cumsum(is_new, dtype=jnp.int32) - 1
    kept_idx = jnp.zeros(n_pts, dtype=sort_idx.dtype).at[
        jnp.where(is_new, slot, n_pts)].set(sort_idx, mode='drop')

    # Step 5: coordinates in scan order.
    r_sorted = r_expand[kept_idx]
    z_sorted = z_expand[kept_idx]

    def ccw(i, j, k):
        # > 0 for a counter-clockwise turn i -> j -> k.
        return ((r_sorted[j] - r_sorted[i]) * (z_sorted[k] - z_sorted[i])
                - (r_sorted[k] - r_sorted[i]) * (z_sorted[j] - z_sorted[i]))

    # Step 6: Graham scan over a fixed-size stack; stack[:top] is valid.
    def outer_cond(state):
        i, _, _ = state
        return i < n_kept

    def outer_body(state):
        i, top, stack = state

        def pop_cond(inner):
            t, s = inner
            return jnp.logical_and(t > 1, ccw(s[t - 2], s[t - 1], i) <= 0)

        def pop_body(inner):
            t, s = inner
            return t - 1, s

        top, stack = lax.while_loop(pop_cond, pop_body, (top, stack))
        stack = stack.at[top].set(i)
        return i + 1, top + 1, stack

    # Stack positions 0 and 1 start as points 0 and 1. With fewer than two
    # unique angles, only the first entry is valid.
    stack0 = jnp.arange(n_pts, dtype=jnp.int32)
    top0 = jnp.minimum(2, n_kept)
    i0 = jnp.array(2, dtype=jnp.int32)
    _, n_hull, stack = lax.while_loop(outer_cond, outer_body, (i0, top0, stack0))

    # Step 7: map hull positions back to original indices. Invalid stack
    # slots are sent out of bounds and dropped.
    hull_idx = jnp.where(point_ids < n_hull, kept_idx[stack], n_pts)
    return jnp.zeros(n_pts, dtype=bool).at[hull_idx].set(True, mode='drop')
@partial(jit, static_argnames=[
    'nfp',
    'stellsym',
    'mpol',
    'ntor',
    'pol_interp',
    'tor_interp',
    # 'lam_tikhonov'
])
def gen_winding_surface_atan(
    plasma_gamma, d_expand,
    nfp, stellsym,
    unitnormal=None,
    mpol=5, ntor=5,
    pol_interp=2,
    tor_interp=2,
    lam_tikhonov=1e-5,
):
    """Winding-surface dofs from a uniform plasma offset, re-parameterized by atan.

    1. Fit an RZ-Fourier surface to the uniform normal offset of the plasma
       boundary (``gen_winding_surface_offset``); it may self-intersect.
    2. Resample it with ``tor_interp`` times as many phi points over one
       field period (endpoint included) and ``pol_interp`` times as many
       theta points.
    3. Give zero weight to points inside self-intersecting loops of each
       R-Z cross section (``_polygon_self_intersection``).
    4. Assign each point the geometric poloidal angle
       ``atan2(Z - Z_c, R - R_c) / (2 pi)`` in (0, 1], measured around the
       center (R_c, Z_c) of the plasma cross section at the same toroidal
       angle, and refit with Tikhonov weight ``lam_tikhonov``.

    Parameters
    ----------
    plasma_gamma : array of shape (nphi, ntheta, 3)
        Cartesian plasma boundary points. Assumed to sample one field period
        at toroidal angles k / (nphi * nfp), k = 0..nphi-1.
    d_expand : float
        Offset distance along the plasma normal.
    nfp, stellsym : int, bool
        Field periods and stellarator symmetry (static).
    unitnormal : array of shape (nphi, ntheta, 3), optional
        Plasma unit normals; approximated by finite differences if None.
    mpol, ntor : int
        Fourier resolution of both fits (static).
    pol_interp, tor_interp : int
        Resampling factors in theta and phi (static).
    lam_tikhonov : float
        Tikhonov regularization weight of the final fit.

    Returns
    -------
    The fitted dofs returned by ``fit_surfacerzfourier``.
    """
    # 1. Uniform offset.
    uniform_offset_dofs = gen_winding_surface_offset(
        plasma_gamma, d_expand,
        nfp, stellsym,
        unitnormal=unitnormal,
        mpol=mpol, ntor=ntor,
    )

    # 2. Resample on a finer grid to get smooth poloidal cross sections.
    phi_expand = jnp.linspace(0, 1/nfp, plasma_gamma.shape[0] * tor_interp)
    uniform_offset_surface_jax = SurfaceRZFourierJAX(
        nfp=nfp, stellsym=stellsym,
        mpol=mpol, ntor=ntor,
        quadpoints_phi=phi_expand,
        quadpoints_theta=jnp.linspace(0, 1, plasma_gamma.shape[1] * pol_interp, endpoint=False),
        dofs=uniform_offset_dofs
    )
    gamma_uniform = uniform_offset_surface_jax.gamma()
    if stellsym:
        # With stellarator symmetry only half a field period is fitted.
        len_phi = len(phi_expand) // 2
        gamma_uniform = gamma_uniform[:len_phi]
        phi_expand = phi_expand[:len_phi]

    # Center of each plasma cross section (average R, Z over theta),
    # interpolated periodically onto phi_expand. The row counts and
    # toroidal angles then match r_expand for any tor_interp.
    nphi_plasma = plasma_gamma.shape[0]
    phi_plasma = jnp.arange(nphi_plasma) / (nphi_plasma * nfp)
    r_plasma = jnp.sqrt(plasma_gamma[:, :, 1]**2 + plasma_gamma[:, :, 0]**2)
    z_plasma = plasma_gamma[:, :, 2]
    r_center = jnp.interp(
        phi_expand, phi_plasma, jnp.average(r_plasma, axis=-1), period=1/nfp)
    z_center = jnp.interp(
        phi_expand, phi_plasma, jnp.average(z_plasma, axis=-1), period=1/nfp)

    # The resampled uniform offset (may self-intersect). Tested to be differentiable.
    r_expand = jnp.sqrt(gamma_uniform[:, :, 1]**2 + gamma_uniform[:, :, 0]**2)
    z_expand = gamma_uniform[:, :, 2]

    # 3. Zero weight inside self-intersecting loops, one phi row at a time.
    weight_remove_invalid = vmap(_polygon_self_intersection, in_axes=0)(r_expand, z_expand)

    # 4. Geometric poloidal angle around the center, mapped to (0, 1].
    theta_atan = jnp.arctan2(z_expand - z_center[:, None], r_expand - r_center[:, None]) / jnp.pi / 2
    theta_atan = jnp.where(theta_atan > 0, theta_atan, theta_atan + 1)
    phi_grid, theta_atan = jnp.broadcast_arrays(phi_expand[:, None], theta_atan)
    return fit_surfacerzfourier(
        mpol=mpol,
        ntor=ntor,
        theta_grid=theta_atan,
        phi_grid=phi_grid,
        r_fit=r_expand,
        z_fit=z_expand,
        nfp=nfp, stellsym=stellsym,
        lam_tikhonov=lam_tikhonov,
        custom_weight=weight_remove_invalid,
    )
@partial(jit, static_argnames=[
    'nfp',
    'stellsym',
    'mpol',
    'ntor',
    'pol_interp',
    'tor_interp',
    'rule',
])
def gen_winding_surface_arc(
    plasma_gamma, d_expand,
    nfp, stellsym,
    unitnormal=None,
    mpol=5, ntor=5,
    pol_interp=2,
    tor_interp=2,
    lam_tikhonov=1e-5,
    rule='self-intersection',
):
    """Winding-surface dofs from a uniform plasma offset, re-parameterized by arclength.

    1. Fit an RZ-Fourier surface to the uniform normal offset of the plasma
       boundary (``gen_winding_surface_offset``); it may self-intersect.
    2. Resample it with ``tor_interp`` times as many phi points over one
       field period (endpoint included) and ``pol_interp`` times as many
       theta points.
    3. Weight the points of each R-Z cross section according to ``rule``:
       ``'self-intersection'`` zeroes points inside self-intersecting loops
       (``_polygon_self_intersection``); ``'hull'`` keeps only convex-hull
       vertices (``_graham_scan``).
    4. Assign each point its normalized arclength along the closed cross
       section (0 at the first point) as the poloidal angle, and refit with
       Tikhonov weight ``lam_tikhonov``.

    Parameters
    ----------
    plasma_gamma : array of shape (nphi, ntheta, 3)
        Cartesian plasma boundary points.
    d_expand : float
        Offset distance along the plasma normal.
    nfp, stellsym : int, bool
        Field periods and stellarator symmetry (static).
    unitnormal : array of shape (nphi, ntheta, 3), optional
        Plasma unit normals; approximated by finite differences if None.
    mpol, ntor : int
        Fourier resolution of both fits (static).
    pol_interp, tor_interp : int
        Resampling factors in theta and phi (static).
    lam_tikhonov : float
        Tikhonov regularization weight of the final fit.
    rule : {'self-intersection', 'hull'}
        Point-removal rule (static).

    Returns
    -------
    The fitted dofs returned by ``fit_surfacerzfourier``.

    Raises
    ------
    ValueError
        If ``rule`` is not one of the supported values.
    """
    # Validate the (static) rule before doing any work.
    rules = {
        'self-intersection': _polygon_self_intersection,
        'hull': _graham_scan,
    }
    if rule not in rules:
        raise ValueError(
            "rule must be 'self-intersection' or 'hull'. "
            f"The current value is: {rule!r}"
        )
    rule_f = rules[rule]

    # 1. Uniform offset.
    uniform_offset_dofs = gen_winding_surface_offset(
        plasma_gamma, d_expand,
        nfp, stellsym,
        unitnormal=unitnormal,
        mpol=mpol, ntor=ntor,
    )

    # 2. Resample on a finer grid to get smooth poloidal cross sections.
    phi_expand = jnp.linspace(0, 1/nfp, plasma_gamma.shape[0] * tor_interp)
    uniform_offset_surface_jax = SurfaceRZFourierJAX(
        nfp=nfp, stellsym=stellsym,
        mpol=mpol, ntor=ntor,
        quadpoints_phi=phi_expand,
        quadpoints_theta=jnp.linspace(0, 1, plasma_gamma.shape[1] * pol_interp, endpoint=False),
        dofs=uniform_offset_dofs
    )
    gamma_uniform = uniform_offset_surface_jax.gamma()
    if stellsym:
        # With stellarator symmetry only half a field period is fitted.
        len_phi = len(phi_expand) // 2
        gamma_uniform = gamma_uniform[:len_phi]
        phi_expand = phi_expand[:len_phi]

    # The resampled uniform offset (may self-intersect). Tested to be differentiable.
    r_expand = jnp.sqrt(gamma_uniform[:, :, 1]**2 + gamma_uniform[:, :, 0]**2)
    z_expand = gamma_uniform[:, :, 2]

    # 3. Per-point weights, one phi row at a time.
    weight_remove_invalid = vmap(rule_f, in_axes=0)(r_expand, z_expand)

    # 4. Arclength parameterization. Segment j runs from point j to point
    # j+1; the last segment closes the loop back to point 0.
    dr = jnp.roll(r_expand, -1, axis=1) - r_expand
    dz = jnp.roll(z_expand, -1, axis=1) - z_expand
    segment_lengths = jnp.sqrt(dr**2 + dz**2)
    # Arclength from point 0 to point j is the sum of segments 0..j-1
    # (an exclusive cumulative sum), normalized by the perimeter.
    arc_to_point = jnp.cumsum(segment_lengths, axis=1) - segment_lengths
    perimeter = jnp.sum(segment_lengths, axis=1, keepdims=True)
    theta_arc = arc_to_point / perimeter
    phi_grid, theta_arc = jnp.broadcast_arrays(phi_expand[:, None], theta_arc)

    return fit_surfacerzfourier(
        mpol=mpol,
        ntor=ntor,
        theta_grid=theta_arc,
        phi_grid=phi_grid,
        r_fit=r_expand,
        z_fit=z_expand,
        nfp=nfp, stellsym=stellsym,
        lam_tikhonov=lam_tikhonov,
        custom_weight=weight_remove_invalid,
    )