
hess

Utilities for approximating Hessian matrices.

Notes
  • This module is heavily based on the hessian module of geomeTRIC. See LICENSE-3RD-PARTY for license information.
References
  1. Schlegel, Theor. Chim. Acta, 66, 333 (1984)

Functions:

  • guess_hess_q

    Build a diagonal guess Hessian that roughly follows Schlegel's guidelines.

  • update_hess_q

    Update the approximate Hessian matrix using the most recent optimization steps.

guess_hess_q

guess_hess_q(
    coords: Tensor, ic_idxs: ICDict, atomic_nums: Tensor
) -> Tensor

Build a diagonal guess Hessian that roughly follows Schlegel's guidelines [1].

Parameters:

  • coords (Tensor) –

    The Cartesian coordinates [a0].

  • ic_idxs (ICDict) –

    The indices of atoms that define the internal coordinates.

  • atomic_nums (Tensor) –

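    The atomic numbers of the atoms in the molecule.

Returns:

  • Tensor

    A diagonal guess Hessian with one entry per internal coordinate, with the same dtype and device as coords.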

Source code in tico/hess.py
def guess_hess_q(
    coords: torch.Tensor, ic_idxs: tico.ic.ICDict, atomic_nums: torch.Tensor
) -> torch.Tensor:
    """Build a diagonal guess Hessian that roughly follows Schlegel's guidelines.

    Args:
        coords: The Cartesian coordinates [a0].
        ic_idxs: The indices of atoms that define the internal coordinates.
        atomic_nums: The atomic numbers of the atoms in the molecule.

    Returns:
        A diagonal guess Hessian with one entry per internal coordinate, in the
        iteration order of ``ic_idxs``.
    """

    hess_diag = []

    # Store each bond as a sorted index pair so lookups ignore atom order.
    bonds = {
        tuple(sorted([int(idx_a), int(idx_b)]))
        for idx_a, idx_b in ic_idxs[tico.ic.ICType.DISTANCE]
    }

    def is_bound(idx_a, idx_b):
        return tuple(sorted([int(idx_a), int(idx_b)])) in bonds

    # Append one diagonal force constant per internal coordinate, in the
    # iteration order of `ic_idxs`.
    for ic_type, idxs in ic_idxs.items():
        if ic_type == tico.ic.ICType.DISTANCE:
            hess_diag.extend(_guess_hess_distance(coords, atomic_nums, idxs))
        elif ic_type in {tico.ic.ICType.ANGLE, tico.ic.ICType.LINEAR}:
            hess_diag.extend(_guess_hess_angle(atomic_nums, idxs, is_bound))
        elif ic_type == tico.ic.ICType.DIHEDRAL:
            # Every dihedral receives the same constant guess.
            hess_diag.extend([0.023] * len(idxs))
        elif ic_type == tico.ic.ICType.OUT_OF_PLANE:
            hess_diag.extend(_guess_hess_out_of_plane(idxs, is_bound))
        else:
            raise NotImplementedError(f"unsupported internal coordinate type: {ic_type}")

    # No off-diagonal couplings are guessed, so the result is a diagonal matrix.
    return torch.diag(torch.tensor(hess_diag, dtype=coords.dtype, device=coords.device))
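The assembly step above can be illustrated without tico. The sketch below is a minimal, hedged stand-in, not tico's implementation: the string IC type keys ("distance", "angle", ...) and every force constant except the dihedral value 0.023 are hypothetical placeholders, because the real per-type values come from _guess_hess_distance, _guess_hess_angle and _guess_hess_out_of_plane, which are not shown here. It mirrors two ideas from guess_hess_q: bonds are stored as sorted index pairs so lookups ignore atom order, and one diagonal entry is appended per internal coordinate in iteration order.

```python
# Minimal sketch of assembling a diagonal guess Hessian in internal coordinates.
# Hypothetical: the string IC type keys and every force constant except the
# dihedral value 0.023 are placeholders, not tico's Schlegel-style formulas.


def make_is_bound(distance_idxs):
    """Return a predicate that checks for a bond regardless of atom order."""
    bonds = {tuple(sorted((int(a), int(b)))) for a, b in distance_idxs}

    def is_bound(a, b):
        return tuple(sorted((int(a), int(b)))) in bonds

    return is_bound


def guess_diag(ic_idxs):
    """Return one diagonal force constant per internal coordinate."""
    is_bound = make_is_bound(ic_idxs.get("distance", []))
    diag = []
    for ic_type, idxs in ic_idxs.items():
        if ic_type == "distance":
            diag.extend([0.5] * len(idxs))  # placeholder stretch constant
        elif ic_type in ("angle", "linear"):
            diag.extend([0.2] * len(idxs))  # placeholder bend constant
        elif ic_type == "dihedral":
            diag.extend([0.023] * len(idxs))  # constant used by guess_hess_q
        elif ic_type == "out_of_plane":
            # Placeholder rule: stiffer when the first atom is bonded to the
            # other three atoms, softer otherwise.
            diag.extend(
                0.05 if all(is_bound(i, j) for j in rest) else 0.01
                for i, *rest in idxs
            )
        else:
            raise NotImplementedError(f"unsupported IC type: {ic_type}")
    return diag


def diag_matrix(values):
    """Expand diagonal entries into a dense square matrix (list of lists)."""
    n = len(values)
    return [[values[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


ic_idxs = {
    "distance": [(0, 1), (2, 0), (3, 0)],
    "angle": [(1, 0, 2)],
    "dihedral": [(1, 0, 2, 3)],
    "out_of_plane": [(0, 1, 2, 3)],
}
print(guess_diag(ic_idxs))  # [0.5, 0.5, 0.5, 0.2, 0.023, 0.05]
```

Note that the bond (2, 0) is still found when queried as (0, 2), which is why the out-of-plane entry gets the stiffer placeholder value.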

update_hess_q

update_hess_q(
    hess_q: Tensor,
    history: list[Step],
    ic: IC,
    max_updates: int = 1,
) -> Tensor

Update the approximate Hessian matrix using the most recent optimization steps.
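Each correction applied by update_hess_q is a rank-two update with the form of the BFGS Hessian update. Written with the quantities the source computes, where Δq is ic.compute_dq(history[this_step].coords_x, history[prev_step].coords_x) and Δg is history[prev_step].grad_q - history[this_step].grad_q, it reads:

```latex
H_{k+1} = H_k
  + \frac{\Delta g \, \Delta g^{\mathsf{T}}}{\Delta g^{\mathsf{T}} \Delta q}
  - \frac{(H_k \Delta q)(H_k \Delta q)^{\mathsf{T}}}{\Delta q^{\mathsf{T}} H_k \Delta q}
```

Pairs with a norm of Δq or Δg below 1e-6 are skipped. Multiplying both sides by Δq shows that the updated matrix satisfies the secant condition H_{k+1} Δq = Δg.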

Parameters:

  • hess_q (Tensor) –

    The current Hessian matrix.

  • history (list[Step]) –

    The history of optimization steps.

  • ic (IC) –

    The internal coordinate system.

  • max_updates (int, default: 1 ) –

    The maximum number of updates to apply, each built from a pair of consecutive steps at the end of the history.

Returns:

  • Tensor

    The updated Hessian matrix.
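With the default max_updates=1, only the last two entries of history contribute. As a hedged illustration of the step-selection loop in the source, the sketch below lists which (previous, current) index pairs are used for a given history length. update_pairs is a hypothetical helper, not part of tico; it assumes max_updates >= 0 and reports non-negative indices where the source uses negative ones.

```python
# Sketch of which history entries update_hess_q pairs up.
# `update_pairs` is a hypothetical helper that only mirrors the indexing in the
# source; it assumes max_updates >= 0 and returns non-negative indices.


def update_pairs(n_history, max_updates=1):
    """Return (prev, this) index pairs, oldest first."""
    if n_history < 2:
        return []  # update_hess_q returns hess_q unchanged in this case
    n_steps = min(max_updates, n_history - 1)
    return [
        (n_history - n_steps + i - 1, n_history - n_steps + i)
        for i in range(n_steps)
    ]


print(update_pairs(5))      # [(3, 4)]
print(update_pairs(5, 3))   # [(1, 2), (2, 3), (3, 4)]
print(update_pairs(3, 10))  # [(0, 1), (1, 2)]
```

Because the pairs are applied oldest first, the most recent step always provides the final correction.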

Source code in tico/hess.py
def update_hess_q(
    hess_q: torch.Tensor,
    history: list["tico.opt.Step"],
    ic: tico.ic.IC,
    max_updates: int = 1,
) -> torch.Tensor:
    """Update the approximate Hessian matrix using the most recent optimization steps.

    Args:
        hess_q: The current Hessian matrix.
        history: The history of optimization steps.
        ic: The internal coordinate system.
        max_updates: The maximum number of updates to apply, each built from a
            pair of consecutive steps at the end of ``history``.

    Returns:
        The updated Hessian matrix.
    """
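    # At least two steps are needed to form a (dq, dg) pair.
    if len(history) < 2:
        return hess_q

    # Use at most `max_updates` consecutive pairs from the end of the history.
    n_steps = min(max_updates, len(history) - 1)

    # Apply the updates oldest first, so the most recent pair is applied last.
    for i in range(n_steps):
        this_step = -n_steps + i
        prev_step = -n_steps + i - 1

        dq = ic.compute_dq(
            history[this_step].coords_x, history[prev_step].coords_x
        ).unsqueeze(-1)
        dg = (history[prev_step].grad_q - history[this_step].grad_q).unsqueeze(-1)

        # Skip near-zero steps or gradient changes, which could make the
        # denominators below vanish.
        if torch.linalg.norm(dq) < 1e-6:
            continue
        if torch.linalg.norm(dg) < 1e-6:
            continue

        # BFGS-form rank-two update:
        # H + dg dg^T / (dg^T dq) - (H dq)(H dq)^T / (dq^T H dq)
        mat_1 = (dg @ dg.T) / (dg.T @ dq)
        mat_2 = ((hess_q @ dq) @ (hess_q @ dq).T) / torch.linalg.multi_dot(
            [dq.T, hess_q, dq]
        )
        hess_q = hess_q + mat_1 - mat_2

    return hess_q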
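To experiment with the update outside tico, the sketch below re-implements the same rank-two formula on plain Python lists. It is a minimal illustration, not tico's code: matvec, dot and bfgs_update are hypothetical helpers, tol stands in for the hard-coded 1e-6 thresholds, and vectors are flat lists rather than column tensors.

```python
# Minimal pure-Python sketch of the rank-two update applied in update_hess_q.
# Matrices are lists of lists; this is for illustration, not performance.
import math


def matvec(a, x):
    """Matrix-vector product a @ x."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def dot(x, y):
    """Inner product x . y."""
    return sum(x_i * y_i for x_i, y_i in zip(x, y))


def bfgs_update(hess, dq, dg, tol=1e-6):
    """Return hess + dg dg^T / (dg^T dq) - (H dq)(H dq)^T / (dq^T H dq).

    Like update_hess_q, the update is skipped (hess is returned unchanged)
    when dq or dg has a Euclidean norm below tol.
    """
    if math.sqrt(dot(dq, dq)) < tol or math.sqrt(dot(dg, dg)) < tol:
        return hess
    h_dq = matvec(hess, dq)
    denom_1 = dot(dg, dq)    # dg^T dq
    denom_2 = dot(dq, h_dq)  # dq^T H dq
    n = len(dq)
    return [
        [
            hess[i][j] + dg[i] * dg[j] / denom_1 - h_dq[i] * h_dq[j] / denom_2
            for j in range(n)
        ]
        for i in range(n)
    ]


hess = [[1.0, 0.0], [0.0, 1.0]]
new_hess = bfgs_update(hess, dq=[0.1, 0.0], dg=[0.3, 0.1])
print(matvec(new_hess, [0.1, 0.0]))  # close to [0.3, 0.1], i.e. dg
```

The updated matrix stays symmetric and maps dq onto dg (the secant condition) up to rounding. Like the source, the sketch does not check the curvature condition dg · dq > 0; in the standard BFGS setting, that condition is what keeps a positive-definite Hessian positive definite.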