# optim

Custom parameter optimizers.

**Classes:**

- `LevenbergMarquardtConfig` – Configuration for the Levenberg-Marquardt optimizer.

**Functions:**

- `levenberg_marquardt` – Optimize a given set of parameters using the Levenberg-Marquardt algorithm.
## LevenbergMarquardtConfig (pydantic model)

Bases: `BaseModel`

Configuration for the Levenberg-Marquardt optimizer.
## levenberg_marquardt

```python
levenberg_marquardt(
    x: Tensor,
    config: LevenbergMarquardtConfig,
    closure_fn: ClosureFn,
    correct_fn: CorrectFn | None = None,
    report_fn: ReportFn | None = None,
) -> Tensor
```
Optimize a given set of parameters using the Levenberg-Marquardt algorithm.
Notes:

- This optimizer assumes a least-squares loss function.
- This is a reimplementation of the Levenberg-Marquardt optimizer from the ForceBalance package, and so may differ from a standard implementation.
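Because the loss is assumed to be a least-squares objective, a closure typically returns the sum of squared residuals together with the Gauss-Newton gradient and Hessian approximation built from the residual Jacobian. The sketch below illustrates the expected return shapes using numpy and hypothetical data; the real `ClosureFn` operates on torch Tensors.

```python
import numpy as np

# Hypothetical toy data for fitting y ~ a * x + b. The real closure
# would wrap whatever residuals the fit actually targets.
x_data = np.linspace(0.0, 1.0, 20)
y_data = 2.0 * x_data + 1.0

def closure_fn(theta, compute_gradient=True, compute_hessian=True):
    """Return the loss (shape=()), gradient (shape=(n,)) and a
    Gauss-Newton approximation to the hessian (shape=(n, n))."""
    a, b = theta
    residuals = (a * x_data + b) - y_data                          # shape=(m,)
    jacobian = np.stack([x_data, np.ones_like(x_data)], axis=-1)   # shape=(m, n)

    loss = residuals @ residuals                                   # sum of squares
    gradient = 2.0 * jacobian.T @ residuals if compute_gradient else None
    hessian = 2.0 * jacobian.T @ jacobian if compute_hessian else None
    return loss, gradient, hessian
```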
Parameters:

- `x` (`Tensor`) – The initial guess of the parameters with `shape=(n,)`.
- `config` (`LevenbergMarquardtConfig`) – The optimizer config.
- `closure_fn` (`ClosureFn`) – A function that computes the loss (`shape=()`), its gradient (`shape=(n,)`), and hessian (`shape=(n, n)`). It should accept as arguments the current parameter tensor and two booleans indicating whether the gradient and hessian are required.
- `correct_fn` (`CorrectFn | None`, default: `None`) – A function that can be used to correct the parameters after each step is taken and before the new loss is computed. This may include, for example, ensuring that vdW parameters are all positive. It should accept the current parameter tensor as an argument and return the corrected parameter tensor.
- `report_fn` (`ReportFn | None`, default: `None`) – An optional function that should be called at the end of every step. This can be used to report the current state of the optimization. It should accept as arguments the step number, the current parameter tensor, the loss, gradient and hessian, the step 'quality', and a bool indicating whether the step was accepted or rejected.
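A `correct_fn` is simply a projection applied to the trial parameters before the new loss is evaluated. A minimal hypothetical example, clamping parameters to be non-negative (e.g. to keep vdW parameters physical), might look like:

```python
import numpy as np

def correct_fn(theta):
    # Hypothetical correction: clamp all parameters to be non-negative.
    # The real CorrectFn would take and return a torch Tensor.
    return np.maximum(theta, 0.0)
```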
Returns:

- `Tensor` – The parameters that minimize the loss.
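For intuition, the core damped step and accept/reject logic of a Levenberg-Marquardt loop can be sketched as below. This is a minimal standalone numpy illustration of the algorithm, not the descent implementation: the names (`lm_minimize`, `toy_closure`, `target`) and the damping schedule are hypothetical, and the package's ForceBalance-derived variant may differ in its step control.

```python
import numpy as np

def lm_minimize(theta, closure_fn, n_steps=50, damping=1.0):
    """Minimal damped LM loop: solve (H + lambda * I) dx = -g, accept
    the step if it lowers the loss, otherwise increase the damping."""
    loss, grad, hess = closure_fn(theta)
    for _ in range(n_steps):
        step = np.linalg.solve(hess + damping * np.eye(len(theta)), -grad)
        new_theta = theta + step
        new_loss, new_grad, new_hess = closure_fn(new_theta)
        if new_loss < loss:
            # Accepted step: move and relax the damping.
            theta, loss, grad, hess = new_theta, new_loss, new_grad, new_hess
            damping *= 0.5
        else:
            # Rejected step: stay put and damp more strongly.
            damping *= 2.0
    return theta

# A trivially solvable least-squares closure: residuals = theta - target.
target = np.array([2.0, -1.0])

def toy_closure(theta):
    residuals = theta - target
    jacobian = np.eye(len(theta))
    loss = residuals @ residuals
    return loss, 2.0 * jacobian.T @ residuals, 2.0 * jacobian.T @ jacobian
```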
Source code in descent/optim/_lm.py