Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use single LinearConstraintProjection in ProximalProjection #1409

Open
wants to merge 42 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
9b3f1bf
initial commit
YigitElma Nov 27, 2024
24e638f
add update_constraint_target method to LienarConstraintProjection, ma…
YigitElma Jan 30, 2025
bb72369
fix typo
YigitElma Jan 30, 2025
67e64b9
some fixes
YigitElma Jan 30, 2025
4a79f12
Merge branch 'master' into yge/less-factorize
YigitElma Jan 30, 2025
77be65d
fix some problems
YigitElma Jan 30, 2025
9bab19f
fix the objective jac compute issues in perturb
YigitElma Jan 30, 2025
9db8332
update changelog and fix perturb errors caused by argument order
YigitElma Jan 30, 2025
ffb62bf
fix missing factorize_linear_constraints
YigitElma Jan 30, 2025
1a361fc
Merge branch 'master' into yge/less-factorize
YigitElma Jan 31, 2025
bb6a003
revert ordering change
YigitElma Jan 31, 2025
b36e74d
remore changelog part
YigitElma Jan 31, 2025
2a70a77
update benchmerks to profile eq solve and perturb
YigitElma Jan 31, 2025
2ded4fd
update benchmarks
YigitElma Jan 31, 2025
93b2d4b
update benchmarks
YigitElma Jan 31, 2025
9253ba4
instead add a new benchmark
YigitElma Jan 31, 2025
3bae0d4
also update gpu benchmarks
YigitElma Jan 31, 2025
dc967bf
re-factor factorize_linear_constraints to prevent repeated code
YigitElma Jan 31, 2025
cdd68bd
update Ainv based on new D
YigitElma Jan 31, 2025
7ef2cb6
fix benchmarks round thing, don't update the constraints if the eq is…
YigitElma Feb 1, 2025
844bd2b
fix round thing for other test
YigitElma Feb 1, 2025
db37e2d
fix the truth value of an array issue
YigitElma Feb 1, 2025
5b60d2b
address Rory's comments and some clean-up
YigitElma Feb 2, 2025
00f033f
update the nullspace too and some clean up
YigitElma Feb 3, 2025
ca8bc02
Merge branch 'master' into yge/less-factorize
YigitElma Feb 3, 2025
901a776
update the nullspace too and some clean up
YigitElma Feb 3, 2025
f571a79
add x_scale to updater
YigitElma Feb 3, 2025
9634646
add x_scale to updater
YigitElma Feb 3, 2025
7381eba
Merge branch 'yge/less-factorize' of github.com:PlasmaControl/DESC in…
YigitElma Feb 3, 2025
523ab2e
dummy fix
YigitElma Feb 3, 2025
dbbe988
Merge branch 'master' into yge/less-factorize
YigitElma Feb 4, 2025
758689d
remove linear constraint projection from docs
YigitElma Feb 4, 2025
e12a3a6
add update constraint checker
YigitElma Feb 4, 2025
e634152
Merge branch 'yge/less-factorize' of github.com:PlasmaControl/DESC in…
YigitElma Feb 4, 2025
ccd6fda
use assigned name for printing linear constraint projection build, us…
YigitElma Feb 4, 2025
c1fe741
fixes for qr
YigitElma Feb 4, 2025
a47d1e0
revert qr change, print tolerances shown in message
YigitElma Feb 5, 2025
33ef2ae
Merge remote-tracking branch 'origin/master' into yge/less-factorize
YigitElma Feb 5, 2025
8c10022
revert changes to the test
YigitElma Feb 5, 2025
3d39b2c
remove x_scale and add comment
YigitElma Feb 5, 2025
1ad8029
Merge branch 'master' into yge/less-factorize
YigitElma Feb 5, 2025
ffab8b7
Merge branch 'master' into yge/less-factorize
YigitElma Feb 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ New Features
- See GitHub pull requests [#1003](https://github.com/PlasmaControl/DESC/pull/1003), [#1042](https://github.com/PlasmaControl/DESC/pull/1042), [#1119](https://github.com/PlasmaControl/DESC/pull/1119), and [#1290](https://github.com/PlasmaControl/DESC/pull/1290) for more details.
- Many new compute quantities for partial derivatives in different coordinate systems.
- Adds a new profile class ``PowerProfile`` for raising profiles to a power.
- Add ``desc.objectives.LinkingCurrentConsistency`` for ensuring that coils in a stage 2 or single stage optimization provide the required linking current for a given equilibrium.
- Adds ``desc.objectives.LinkingCurrentConsistency`` for ensuring that coils in a stage 2 or single stage optimization provide the required linking current for a given equilibrium.
- Adds an option ``scaled_termination`` (defaults to True) to all the desc optimizers to measure the norms for ``xtol`` and ``gtol`` in the scaled norm provided by ``x_scale`` (which defaults to using an adaptive scaling based on the Jacobian or Hessian). This should make things more robust when optimizing parameters with widely different magnitudes. The old behavior can be recovered by passing ``options={"scaled_termination": False}``.
- ``desc.objectives.Omnigenity`` is now vectorized and able to optimize multiple surfaces at the same time. Previously it was required to use a different objective for each surface.
- Adds a new objective ``desc.objectives.MirrorRatio`` for targeting a particular mirror ratio on each flux surface, for either an ``Equilibrium`` or ``OmnigenousField``.
- Adds the output quantities ``wb`` and ``wp`` to ``VMECIO.save``.
- Change implementation of Dommaschk potentials to use recursive algorithm and symbolic integration.
- Changes implementation of Dommaschk potentials to use recursive algorithm and symbolic integration.
- Changes hessian computation to use chunked ``jacfwd`` and ``jacrev``, allowing ``jac_chunk_size`` to now reduce hessian memory usage as well.
- Adds an option to ``VMECIO.save`` to specify the grid resolution in real space.
- Adds a new objective ``desc.objectives.CoilIntegratedCurvature`` for targeting convex coils.
- `eq.solve` and `eq.perturb` now accept `LinearConstraintProjection` as objective. This option must be used without any constraints.
- Adds batching feature to singular integrals.
- ``desc.objectives.CoilSetMinDistance`` and ``desc.objectives.PlasmaCoilSetMinDistance`` now include the option to use a softmin which can give smoother gradients. They also both now have a ``dist_chunk_size`` option to break up the distance calculation into smaller pieces to save memory
- Adds a new function ``desc.coils.initialize_helical_coils`` for creating an initial guess for stage 2 helical coil optimization.
Expand All @@ -35,6 +36,10 @@ Bug Fixes
- Sets ``os.environ["JAX_PLATFORMS"] = "cpu"`` instead of ``os.environ["JAX_PLATFORM_NAME"] = "cpu"`` when doing ``set_device("cpu")``.


Performance Improvements

- `proximal-` optimizers use a single `LinearConstraintProjection` and this makes the optimization faster for high resolution cases where taking the SVD (for null-space and inverse) of constraint matrix takes significant time.

v0.13.0
-------

Expand Down
28 changes: 21 additions & 7 deletions desc/equilibrium/equilibrium.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
get_fixed_boundary_constraints,
)
from desc.optimizable import Optimizable, optimizable_parameter
from desc.optimize import Optimizer
from desc.optimize import LinearConstraintProjection, Optimizer
from desc.perturbations import perturb
from desc.profiles import HermiteSplineProfile, PowerSeriesProfile, SplineProfile
from desc.transform import Transform
Expand Down Expand Up @@ -2140,6 +2140,7 @@
----------
objective : {"force", "forces", "energy"}
Objective function to solve. Default = force balance on unified grid.
ObjectiveFunction can also be passed.
constraints : Tuple
set of constraints to enforce. Default = fixed boundary/profiles
optimizer : str or Optimizer (optional)
Expand Down Expand Up @@ -2177,13 +2178,19 @@
`OptimizeResult` for a description of other attributes.

"""
if constraints is None:
is_linear_proj = isinstance(objective, LinearConstraintProjection)
if is_linear_proj and constraints is not None:
raise ValueError(

Check warning on line 2183 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L2183

Added line #L2183 was not covered by tests
"If a LinearConstraintProjection is passed, "
"no constraints should be passed."
)
if constraints is None and not is_linear_proj:
constraints = get_fixed_boundary_constraints(eq=self)
if not isinstance(objective, ObjectiveFunction):
objective = get_equilibrium_objective(eq=self, mode=objective)
if not isinstance(optimizer, Optimizer):
optimizer = Optimizer(optimizer)
if not isinstance(constraints, (list, tuple)):
if not isinstance(constraints, (list, tuple)) and not is_linear_proj:
constraints = tuple([constraints])

warnif(
Expand Down Expand Up @@ -2351,19 +2358,26 @@
Perturbed equilibrium.

"""
is_linear_proj = isinstance(objective, LinearConstraintProjection)
if is_linear_proj and constraints is not None:
raise ValueError(

Check warning on line 2363 in desc/equilibrium/equilibrium.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/equilibrium.py#L2363

Added line #L2363 was not covered by tests
"If a LinearConstraintProjection is passed, "
"no constraints should be passed. Passed constraints:"
f"{constraints}."
)
if objective is None:
objective = get_equilibrium_objective(eq=self)
if constraints is None:
if constraints is None and not is_linear_proj:
if "Ra_n" in deltas or "Za_n" in deltas:
constraints = get_fixed_axis_constraints(eq=self)
else:
constraints = get_fixed_boundary_constraints(eq=self)

eq = perturb(
self,
objective,
constraints,
deltas,
objective=objective,
constraints=constraints,
deltas=deltas,
order=order,
tr_ratio=tr_ratio,
weight=weight,
Expand Down
146 changes: 100 additions & 46 deletions desc/objectives/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,63 +100,25 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
A_augmented = np.hstack([A, np.reshape(b, (A.shape[0], 1))])
YigitElma marked this conversation as resolved.
Show resolved Hide resolved

# Find unique rows of A_augmented
unique_rows, unique_indices = np.unique(A_augmented, axis=0, return_index=True)
_, unique_indices = np.unique(A_augmented, axis=0, return_index=True)

# Sort the indices to preserve the order of appearance
unique_indices = np.sort(unique_indices)
# Find the indices of the degenerate rows
degenerate_idx = np.setdiff1d(np.arange(A_augmented.shape[0]), unique_indices)

# Extract the unique rows
A_augmented = A_augmented[unique_indices]
A = A_augmented[:, :-1]
b = np.atleast_1d(A_augmented[:, -1].squeeze())

# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])

while len(np.where(np.count_nonzero(A, axis=1) == 1)[0]):
# fixed just means there is a single element in A, so A_ij*x_j = b_i
fixed_rows = np.where(np.count_nonzero(A, axis=1) == 1)[0]
# indices of x that are fixed = cols of A where rows have 1 nonzero val.
_, fixed_idx = np.where(A[fixed_rows])
unfixed_rows = np.setdiff1d(np.arange(A.shape[0]), fixed_rows)
unfixed_idx = np.setdiff1d(np.arange(A.shape[1]), fixed_idx)

# find the global index of the fixed variables of this iteration
global_fixed_idx = indices_idx[fixed_idx]
# find the global index of the unfixed variables by removing the fixed variables
# from the indices arrays.
indices_idx = np.delete(indices_idx, fixed_idx) # fixed indices are removed
indices_row = np.delete(indices_row, fixed_rows) # fixed rows are removed

if len(fixed_rows):
# something like 0.5 x1 = 2 is the same as x1 = 4
b = put(b, fixed_rows, b[fixed_rows] / np.sum(A[fixed_rows], axis=1))
A = put(
A,
Index[fixed_rows, :],
A[fixed_rows] / np.sum(A[fixed_rows], axis=1)[:, None],
)
xp = put(xp, global_fixed_idx, b[fixed_rows])
# Some values might be fixed, but they still show up in other constraints
# this is where the fixed cols have >1 nonzero val.
# For fixed variables, we delete that row and col of A, but that means
# we need to subtract the fixed value from b so that the equation is
# balanced.
# e.g., 2 x1 + 3 x2 + 1 x3 = 4 ; 4 x1 = 2
# combining gives 3 x2 + 1 x3 = 3, with x1 now removed
b = put(
b,
unfixed_rows,
b[unfixed_rows] - A[unfixed_rows][:, fixed_idx] @ b[fixed_rows],
)
A = A[unfixed_rows][:, unfixed_idx]
b = b[unfixed_rows]
A_nondegenerate = A.copy()

unfixed_idx = indices_idx
fixed_idx = np.delete(np.arange(xp.size), unfixed_idx)
# remove fixed parameters from A and b
A, b, xp, unfixed_idx, fixed_idx = remove_fixed_parameters(A, b, xp)

# compute x_scale if not provided
# Note: this x_scale is not the same as the x_scale as in solve_options["x_scale"]
if x_scale == "auto":
x_scale = objective.x(*objective.things)
errorif(
Expand Down Expand Up @@ -229,7 +191,19 @@ def factorize_linear_constraints(objective, constraint, x_scale="auto"): # noqa
"or be due to floating point error.",
)

return xp, A, b, Z, D, unfixed_idx, project, recover
return (
xp,
A,
b,
Z,
D,
unfixed_idx,
project,
recover,
A_inv,
A_nondegenerate,
degenerate_idx,
)


class _Project(IOAble):
Expand Down Expand Up @@ -419,3 +393,83 @@ def check_if_points_are_inside_perimeter(R, Z, Rcheck, Zcheck):
# assign negative distance
pt_sign = jnp.where(jnp.isclose(pt_sign, 0), 1, -1)
return pt_sign


def remove_fixed_parameters(A, b, xp):
"""Remove fixed parameters from the linear constraint matrix and RHS vector.

Given a linear constraint matrix A and RHS vector b, remove fixed parameters from A
and b. Fixed parameters are those that have only a single nonzero value in A, so
that the equation is already balanced. This function will remove the fixed
parameters from A and b, will also update the correcponding sections of the
particular solution xp.

Parameters
----------
A : ndarray
Constraint matrix.
b : ndarray
RHS vector.
xp : ndarray
Particular solution vector for the constraint Ax=b.

Returns
-------
A : ndarray
Constraint matrix with fixed parameters removed.
b : ndarray
RHS vector with fixed parameters removed.
xp : ndarray
Particular solution with fixed parameters updated.
unfixed_idx : ndarray
Indices of the unfixed parameters.
fixed_idx : ndarray
Indices of the fixed parameters
"""
# will store the global index of the unfixed rows, idx
indices_row = np.arange(A.shape[0])
indices_idx = np.arange(A.shape[1])

while len(np.where(np.count_nonzero(A, axis=1) == 1)[0]):
# fixed just means there is a single element in A, so A_ij*x_j = b_i
fixed_rows = np.where(np.count_nonzero(A, axis=1) == 1)[0]
# indices of x that are fixed = cols of A where rows have 1 nonzero val.
_, fixed_idx = np.where(A[fixed_rows])
unfixed_rows = np.setdiff1d(np.arange(A.shape[0]), fixed_rows)
unfixed_idx = np.setdiff1d(np.arange(A.shape[1]), fixed_idx)

# find the global index of the fixed variables of this iteration
global_fixed_idx = indices_idx[fixed_idx]
# find the global index of the unfixed variables by removing the fixed variables
# from the indices arrays.
indices_idx = np.delete(indices_idx, fixed_idx) # fixed indices are removed
indices_row = np.delete(indices_row, fixed_rows) # fixed rows are removed

if len(fixed_rows):
# something like 0.5 x1 = 2 is the same as x1 = 4
b = put(b, fixed_rows, b[fixed_rows] / np.sum(A[fixed_rows], axis=1))
A = put(
A,
Index[fixed_rows, :],
A[fixed_rows] / np.sum(A[fixed_rows], axis=1)[:, None],
)
xp = put(xp, global_fixed_idx, b[fixed_rows])
# Some values might be fixed, but they still show up in other constraints
# this is where the fixed cols have >1 nonzero val.
# For fixed variables, we delete that row and col of A, but that means
# we need to subtract the fixed value from b so that the equation is
# balanced.
# e.g., 2 x1 + 3 x2 + 1 x3 = 4 ; 4 x1 = 2
# combining gives 3 x2 + 1 x3 = 3, with x1 now removed
b = put(
b,
unfixed_rows,
b[unfixed_rows] - A[unfixed_rows][:, fixed_idx] @ b[fixed_rows],
)
A = A[unfixed_rows][:, unfixed_idx]
b = b[unfixed_rows]

unfixed_idx = indices_idx
fixed_idx = np.delete(np.arange(xp.size), unfixed_idx)

return A, b, xp, unfixed_idx, fixed_idx
Loading
Loading