# Copyright (C) 2026 Jack S. Hale
#
# This file is part of DOLFINx (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later
"""High-level problem classes using native linear algebra objects.
Users with advanced requirements should use
:mod:`dolfinx.fem.petsc`.
"""
import typing
from collections.abc import Sequence

import numpy as np
import numpy.typing as npt

import ufl
from dolfinx import default_scalar_type
from dolfinx.fem import (
    DirichletBC,
    Form,
    Function,
    apply_lifting,
    assemble_matrix,
    assemble_vector,
    create_matrix,
    create_vector,
    form,
)
from dolfinx.la import InsertMode, MatrixCSR, Vector
from dolfinx.la.superlu_dist import superlu_dist_matrix, superlu_dist_solver
from dolfinx.mesh import EntityMap as EntityMap
# Public API of this module.
__all__ = ["LinearProblem"]
class LinearProblem:
    r"""High-level class for solving linear problems using SuperLU_DIST.

    Solves problems of the form :math:`a(u, v) = L(v) \; \forall v \in V`
    using :class:`dolfinx.la.superlu_dist.SuperLUDistSolver` as the
    linear solver.

    Note:
        DOLFINx must be built with SuperLU_DIST to use this class.
    """

    def __init__(
        self,
        a: ufl.Form,
        L: ufl.Form,
        bcs: Sequence[DirichletBC] | None = None,
        u: Function | None = None,
        dtype: npt.DTypeLike = default_scalar_type,
        superlu_dist_options: dict | None = None,
        form_compiler_options: dict | None = None,
        jit_options: dict | None = None,
        entity_maps: Sequence[EntityMap] | None = None,
    ) -> None:
        """Initialize SuperLU_DIST solver for a linear variational problem.

        Args:
            a: Bilinear UFL form, the left-hand side of the variational
                problem.
            L: Linear UFL form, the right-hand side of the variational
                problem.
            bcs: Dirichlet boundary conditions to apply to the variational
                problem. ``None`` means no boundary conditions.
            u: Solution function. Created if not provided.
            dtype: Scalar type for form compilation. Must match
                ``u.dtype`` if ``u`` is provided.
            superlu_dist_options: Options passed to the SuperLU_DIST
                solver. See the SuperLU_DIST User's Guide for
                available options and values.
            form_compiler_options: Options used in FFCx compilation of
                all forms. Run ``ffcx --help`` at the command line to see
                all available options.
            jit_options: Options used in CFFI JIT compilation of C code
                generated by FFCx. See :func:`dolfinx.jit.ffcx_jit` for
                all available options. Takes priority over all other
                option values.
            entity_maps: If any trial functions, test functions, or
                coefficients in the form are not defined over the same mesh
                as the integration domain, a corresponding
                :class:`EntityMap <dolfinx.mesh.EntityMap>` must be
                provided.

        Raises:
            ValueError: If ``u`` is provided and ``dtype`` does not match
                ``u.dtype``.

        Example::

            problem = LinearProblem(
                a, L, bcs=[bc], superlu_dist_options={"SymmetricMode": "YES"}
            )
            u_h = problem.solve()
        """
        if u is not None:
            _dtype = u.dtype
            # Compare normalised dtypes, not object identity: a scalar type
            # such as ``np.float64`` and the dtype instance
            # ``np.dtype("float64")`` describe the same type but are distinct
            # objects, so an ``is not`` test would wrongly reject them.
            if np.dtype(dtype) != np.dtype(u.dtype):
                raise ValueError(f"dtype ({dtype}) does not match u.dtype ({u.dtype}).")
        else:
            _dtype = dtype

        self._a = typing.cast(
            Form,
            form(
                a,
                _dtype,
                form_compiler_options=form_compiler_options,
                jit_options=jit_options,
                entity_maps=entity_maps,
            ),
        )
        self._L = typing.cast(
            Form,
            form(
                L,
                _dtype,
                form_compiler_options=form_compiler_options,
                jit_options=jit_options,
                entity_maps=entity_maps,
            ),
        )
        self._A = create_matrix(self._a)

        # The argument (test) space of L: the rhs, the solution vector and
        # the solution function are all created on it.
        V = L.arguments()[0].ufl_function_space()
        self._x = create_vector(V, dtype=_dtype)
        self._b = create_vector(V, dtype=_dtype)

        self._u: Function
        if u is None:
            self._u = Function(V, dtype=_dtype)
        else:
            self._u = u

        self.bcs = [] if bcs is None else bcs
        self._superlu_dist_options = superlu_dist_options

    def solve(self) -> Function:
        """Solve the problem.

        Reassembles the matrix and right-hand side, applies the Dirichlet
        boundary conditions, solves with SuperLU_DIST, and copies the
        result into the solution function ``u`` stored in the problem
        instance.

        Returns:
            The solution function.

        Raises:
            RuntimeError: If SuperLU_DIST returns a non-zero error code.
        """
        # Assemble lhs
        self.A.set_value(self.A.data.dtype.type(0.0))
        assemble_matrix(self.A, self.a, bcs=self.bcs)  # type: ignore[arg-type, misc]
        self.A.scatter_reverse()

        # SuperLU_DIST solves in-place, so a deep copy of A is required.
        A_superlu_dist = superlu_dist_matrix(self.A)
        solver = superlu_dist_solver(A_superlu_dist)
        if self._superlu_dist_options is not None:
            for option, value in self._superlu_dist_options.items():
                solver.set_option(option, value)

        # Assemble rhs
        self.b.array[:] = 0.0
        assemble_vector(self.b.array, self.L)  # type: ignore[arg-type]

        # Apply boundary conditions to the rhs: lift, accumulate ghost
        # contributions (InsertMode.add), then set the prescribed values.
        if self.bcs:
            apply_lifting(self.b.array, [self.a], bcs=[self.bcs])
        self.b.scatter_reverse(InsertMode.add)
        for bc in self.bcs:
            bc.set(self.b.array)

        # Solve linear system and update ghost values in the solution.
        error = solver.solve(self.b, self.x)
        # SuperLU_DIST signals failure with either sign (negative: illegal
        # argument; positive: singular factor or allocation failure), so any
        # non-zero code is an error.
        if error != 0:
            raise RuntimeError(f"SuperLU_DIST returned non-zero error code: {error}")
        self.x.scatter_forward()
        self.u.x.array[:] = self.x.array
        return self.u

    @property
    def L(self) -> Form:
        """The compiled linear form representing the right-hand side."""
        return self._L

    @property
    def a(self) -> Form:
        """The compiled bilinear form representing the left-hand side."""
        return self._a

    @property
    def A(self) -> MatrixCSR:
        """Left-hand side matrix."""
        return self._A

    @property
    def b(self) -> Vector:
        """Right-hand side vector."""
        return self._b

    @property
    def x(self) -> Vector:
        """Solution vector.

        Note:
            The vector does not share memory with the solution
            function ``u``.
        """
        return self._x

    @property
    def u(self) -> Function:
        """Solution function.

        Note:
            The function does not share memory with the solution
            vector ``x``.
        """
        return self._u