
optim #

Custom parameter optimizers.

Classes:

  • LevenbergMarquardtConfig

    Configuration for the Levenberg-Marquardt optimizer.

Functions:

  • levenberg_marquardt

    Optimize a given set of parameters using the Levenberg-Marquardt algorithm.

LevenbergMarquardtConfig pydantic-model #

Bases: BaseModel

Configuration for the Levenberg-Marquardt optimizer.
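A minimal construction sketch. The field names below (mode, trust_radius, max_steps, error_tolerance) are those referenced in the source listing further down this page; the values shown, the import path, and any other fields or defaults are assumptions.

from descent.optim import LevenbergMarquardtConfig

# Illustrative only: field names are taken from the source listing on this page;
# the values, and whether the remaining fields have defaults, are assumptions.
config = LevenbergMarquardtConfig(
    mode="adaptive",         # trust-region mode; the accepted strings are an assumption
    trust_radius=0.25,       # initial trust radius (assumed value)
    max_steps=100,           # maximum number of optimization steps
    error_tolerance=1.0e-4,  # how much the loss may rise before a step is rejected
)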

levenberg_marquardt #

levenberg_marquardt(
    x: Tensor,
    config: LevenbergMarquardtConfig,
    closure_fn: ClosureFn,
    correct_fn: CorrectFn | None = None,
    report_fn: ReportFn | None = None,
) -> Tensor

Optimize a given set of parameters using the Levenberg-Marquardt algorithm.

Notes
  • This optimizer assumes a least-squares loss function.
  • This is a reimplementation of the Levenberg-Marquardt optimizer from the ForceBalance package, and so may differ from a standard implementation.
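For orientation, a textbook Levenberg-Marquardt update solves a damped Newton system. The sketch below is illustrative only; the actual update, including the ForceBalance-style damping and trust-radius handling, lives in the private _step helper, which is not reproduced on this page.

import torch

def textbook_lm_step(
    gradient: torch.Tensor, hessian: torch.Tensor, damping: float
) -> torch.Tensor:
    """Solve the damped Newton system (H + damping * I) dx = -g for the step dx."""
    identity = torch.eye(len(gradient), dtype=hessian.dtype)
    return torch.linalg.solve(hessian + damping * identity, -gradient)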

Parameters:

  • x (Tensor) –

    The initial guess of the parameters with shape=(n,).

  • config (LevenbergMarquardtConfig) –

    The optimizer config.

  • closure_fn (ClosureFn) –

    A function that computes the loss (shape=()), its gradient (shape=(n,)), and hessian (shape=(n, n)). It should accept as arguments the current parameter tensor, and two booleans indicating whether the gradient and hessian are required.

  • correct_fn (CorrectFn | None, default: None ) –

    A function that can be used to correct the parameters after each step is taken and before the new loss is computed. This may include, for example, ensuring that vdW parameters are all positive. It should accept as arguments the current parameter tensor and return the corrected parameter tensor.

  • report_fn (ReportFn | None, default: None ) –

    An optional function that should be called at the end of every step. This can be used to report the current state of the optimization. It should accept as arguments the step number, the current parameter tensor, the loss, gradient and hessian, the step 'quality', and a bool indicating whether the step was accepted or rejected.

Returns:

  • Tensor

    The parameters that minimize the loss.
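A minimal end-to-end sketch showing how the callbacks and config fit together. It assumes both names can be imported from descent.optim, as this page suggests, uses a toy least-squares target in place of a real fitting problem, and sets only max_steps on the config, assuming the remaining fields have defaults.

import torch

from descent.optim import LevenbergMarquardtConfig, levenberg_marquardt

reference = torch.tensor([1.0, 2.0, 3.0])


def closure_fn(x: torch.Tensor, compute_gradient: bool, compute_hessian: bool):
    """Return the least-squares loss and, when requested, its gradient and hessian."""

    def loss_fn(y: torch.Tensor) -> torch.Tensor:
        return ((y - reference) ** 2).sum()

    loss = loss_fn(x)
    gradient = (
        torch.autograd.grad(loss, x)[0].detach() if compute_gradient else None
    )
    hessian = (
        torch.autograd.functional.hessian(loss_fn, x.detach())
        if compute_hessian
        else None
    )
    return loss.detach(), gradient, hessian


def correct_fn(x: torch.Tensor) -> torch.Tensor:
    # e.g. keep parameters that must stay positive above a small floor
    return x.clamp(min=1.0e-6)


x_0 = torch.zeros(3, requires_grad=True)

x_final = levenberg_marquardt(
    x_0,
    LevenbergMarquardtConfig(max_steps=25),  # other fields assumed to have defaults
    closure_fn,
    correct_fn,
)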

Source code in descent/optim/_lm.py
@torch.no_grad()
def levenberg_marquardt(
    x: torch.Tensor,
    config: LevenbergMarquardtConfig,
    closure_fn: ClosureFn,
    correct_fn: CorrectFn | None = None,
    report_fn: ReportFn | None = None,
) -> torch.Tensor:
    """Optimize a given set of parameters using the Levenberg-Marquardt algorithm.

    Notes:
        * This optimizer assumes a least-squares loss function.
        * This is a reimplementation of the Levenberg-Marquardt optimizer from the
          ForceBalance package, and so may differ from a standard implementation.

    Args:
        x: The initial guess of the parameters with ``shape=(n,)``.
        config: The optimizer config.
        closure_fn: A function that computes the loss (``shape=()``), its
            gradient (``shape=(n,)``), and hessian (``shape=(n, n)``). It should
            accept as arguments the current parameter tensor, and two booleans
            indicating whether the gradient and hessian are required.
        correct_fn: A function that can be used to correct the parameters after
            each step is taken and before the new loss is computed. This may
            include, for example, ensuring that vdW parameters are all positive.
            It should accept as arguments the current parameter tensor and return
            the corrected parameter tensor.
        report_fn: An optional function that should be called at the end of every
            step. This can be used to report the current state of the optimization.
            It should accept as arguments the step number, the current parameter
            tensor, the loss, gradient and hessian, the step 'quality', and a bool
            indicating whether the step was accepted or rejected.

    Returns:
        The parameters that minimize the loss.
    """

    x = x.clone().detach().requires_grad_(x.requires_grad)

    correct_fn = correct_fn if correct_fn is not None else lambda y: y
    closure_fn = torch.enable_grad()(closure_fn)

    report_fn = report_fn if report_fn is not None else lambda *_, **__: None

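    # evaluate the loss, gradient and hessian at the initial guess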
    closure_prev = closure_fn(x, True, True)
    trust_radius = torch.tensor(config.trust_radius).to(x.device)

    loss_history = []
    has_converged = False

    best_x, best_loss = x.clone(), closure_prev[0]

    for step in range(config.max_steps):
        loss_prev, gradient_prev, hessian_prev = closure_prev

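        # propose a damped (Levenberg-Marquardt) step constrained by the current trust radius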
        dx, expected_improvement, damping_adjusted, damping_factor = _step(
            gradient_prev, hessian_prev, trust_radius, config
        )

        if config.mode.lower() == _HESSIAN_SEARCH:
            dx, expected_improvement = _hessian_diagonal_search(
                x,
                closure_prev,
                closure_fn,
                correct_fn,
                damping_factor,
                trust_radius,
                config,
            )

        dx_norm = torch.linalg.norm(dx)
        _LOGGER.info(f"{config.mode} step found (length {dx_norm:.4e})")

        x_next = correct_fn(x + dx).requires_grad_(x.requires_grad)

        loss, gradient, hessian = closure_fn(x_next, True, True)
        loss_delta = loss - loss_prev

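        # ratio of the actual change in loss to the improvement predicted when the step was proposed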
        step_quality = loss_delta / expected_improvement
        accept_step = True

        if loss > (loss_prev + config.error_tolerance):
            # reject the 'bad' step and try again from where we were
            loss, gradient, hessian = (loss_prev, gradient_prev, hessian_prev)
            trust_radius = _reduce_trust_radius(dx_norm, config)

            accept_step = False
        elif config.mode.lower() == _ADAPTIVE:
            # this was a 'good' step - we can maybe increase the trust radius
            trust_radius = _update_trust_radius(
                dx_norm, step_quality, trust_radius, damping_adjusted, config
            )

        if accept_step:
            x.data.copy_(x_next.data)
            loss_history.append(loss.detach().cpu().clone())

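        # track the best parameters seen so far; these are what is ultimately returned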
        if loss < best_loss:
            best_x, best_loss = x.clone(), loss.detach().clone()

        closure_prev = (loss, gradient, hessian)

        report_fn(step, x, loss, gradient, hessian, step_quality, accept_step)
        _LOGGER.info(f"step={step} loss={loss.detach().cpu().item():.4e}")

        if _has_converged(dx, loss_history, gradient, step_quality, config):
            _LOGGER.info(f"optimization has converged after {step + 1} steps.")
            has_converged = True

            break

    if not has_converged:
        _LOGGER.info(f"optimization has not converged after {config.max_steps} steps.")

    return best_x