train #

Helpers for training parameters.

Classes:

  • AttributeConfig

    Configuration for how a potential's attributes should be trained.

  • ParameterConfig

    Configuration for how a potential's parameters should be trained.

  • Trainable

    A convenient wrapper around a tensor force field that gives greater control over how parameters should be trained.

AttributeConfig pydantic-model #

Bases: BaseModel

Configuration for how a potential's attributes should be trained.

Fields:

  • cols (list[str])
  • scales (dict[str, float])
  • limits (dict[str, tuple[float | None, float | None]])

cols pydantic-field #

cols: list[str]

The parameters to train, e.g. 'k', 'length', 'epsilon'.

scales pydantic-field #

scales: dict[str, float] = {}

The scales to apply to each parameter, e.g. 'k': 1.0, 'length': 1.0, 'epsilon': 1.0.

limits pydantic-field #

limits: dict[str, tuple[float | None, float | None]] = {}

The min and max values to clamp each parameter within, e.g. 'k': (0.0, None), 'angle': (0.0, pi), 'epsilon': (0.0, None), where None indicates no constraint on that side.
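
As a minimal sketch, reusing the example column names above: a config that trains a single 'epsilon' column, scales it by 10 before it reaches the optimizer, and keeps it non-negative:

import descent.train

config = descent.train.AttributeConfig(
    cols=["epsilon"],
    # The optimizer works with epsilon * 10 (``to_values`` multiplies by
    # the scale), so its steps are effectively finer in the raw units.
    scales={"epsilon": 10.0},
    # Lower bound at zero; None leaves the upper bound unconstrained.
    limits={"epsilon": (0.0, None)},
)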

ParameterConfig pydantic-model #

Bases: AttributeConfig

Configuration for how a potential's parameters should be trained.

Fields:

  • cols (list[str])
  • scales (dict[str, float])
  • limits (dict[str, tuple[float | None, float | None]])
  • include (PotentialKeyList | None)
  • exclude (PotentialKeyList | None)

cols pydantic-field #

cols: list[str]

The parameters to train, e.g. 'k', 'length', 'epsilon'.

scales pydantic-field #

scales: dict[str, float] = {}

The scales to apply to each parameter, e.g. 'k': 1.0, 'length': 1.0, 'epsilon': 1.0.

limits pydantic-field #

limits: dict[str, tuple[float | None, float | None]] = {}

The min and max values to clamp each parameter within, e.g. 'k': (0.0, None), 'angle': (0.0, pi), 'epsilon': (0.0, None), where None indicates no constraint on that side.

include pydantic-field #

include: PotentialKeyList | None = None

The keys (see smee.TensorPotential.parameter_keys for details) corresponding to specific parameters to be trained. If None, all parameters will be trained.

exclude pydantic-field #

exclude: PotentialKeyList | None = None

The keys (see smee.TensorPotential.parameter_keys for details) corresponding to specific parameters to be excluded from training. If None, no parameters will be excluded.
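
As a rough sketch, using the example column names above: a bond-style config that rescales 'k' so both columns reach the optimizer at similar magnitudes and keeps both non-negative:

import descent.train

bond_config = descent.train.ParameterConfig(
    cols=["k", "length"],
    # Down-scale 'k' so its optimizer steps match 'length' in magnitude.
    scales={"k": 0.01, "length": 1.0},
    # Force constants and lengths stay non-negative; None = unbounded.
    limits={"k": (0.0, None), "length": (0.0, None)},
    # include / exclude default to None: train every parameter row. When
    # set, they take keys from smee.TensorPotential.parameter_keys.
)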

Trainable #

Trainable(
    force_field: TensorForceField,
    parameters: dict[str, ParameterConfig],
    attributes: dict[str, AttributeConfig],
)

A convenient wrapper around a tensor force field that gives greater control over how parameters should be trained.

This includes imposing limits on the values of parameters, scaling the values so parameters passed to the optimizer have similar magnitudes, and freezing parameters so they are not updated during training.

Parameters:

  • force_field (TensorForceField) –

    The force field to wrap.

  • parameters (dict[str, ParameterConfig]) –

    Configure which parameters to train.

  • attributes (dict[str, AttributeConfig]) –

    Configure which attributes to train.

Methods:

  • to_values

    Returns unfrozen parameter and attribute values as a flat tensor.

  • to_force_field

    Returns a force field with the parameters and attributes set to the given values.

  • clamp

    Clamps the given values to the configured min and max values.

Source code in descent/train.py
def __init__(
    self,
    force_field: smee.TensorForceField,
    parameters: dict[str, ParameterConfig],
    attributes: dict[str, AttributeConfig],
):
    """

    Args:
        force_field: The force field to wrap.
        parameters: Configure which parameters to train.
        attributes: Configure which attributes to train.
    """
    self._force_field = force_field

    (
        self._param_types,
        param_values,
        self._param_shapes,
        param_unfrozen_idxs,
        param_scales,
        param_clamp_lower,
        param_clamp_upper,
    ) = self._prepare(force_field, parameters, "parameters")
    (
        self._attr_types,
        attr_values,
        self._attr_shapes,
        attr_unfrozen_idxs,
        attr_scales,
        attr_clamp_lower,
        attr_clamp_upper,
    ) = self._prepare(force_field, attributes, "attributes")

    self._values = torch.cat([param_values, attr_values])

    self._unfrozen_idxs = torch.cat(
        [param_unfrozen_idxs, attr_unfrozen_idxs + len(param_scales)]
    ).long()

    self._scales = torch.cat([param_scales, attr_scales])[self._unfrozen_idxs]

    self._clamp_lower = torch.cat([param_clamp_lower, attr_clamp_lower])[
        self._unfrozen_idxs
    ]
    self._clamp_upper = torch.cat([param_clamp_upper, attr_clamp_upper])[
        self._unfrozen_idxs
    ]
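
Putting it together, a minimal training-loop sketch. The potential type key 'Bonds', the bond_config, and compute_loss are illustrative placeholders, and force_field is assumed to be an existing smee.TensorForceField:

import torch

import descent.train

trainable = descent.train.Trainable(
    force_field,  # assumed: an existing smee.TensorForceField
    parameters={"Bonds": bond_config},  # hypothetical type key / config
    attributes={},
)

values = trainable.to_values()
optimizer = torch.optim.Adam([values], lr=0.01)

for _ in range(100):
    optimizer.zero_grad()

    ff = trainable.to_force_field(values)
    loss = compute_loss(ff)  # hypothetical user-supplied loss

    loss.backward()
    optimizer.step()

    # Re-impose the configured limits after the unconstrained step.
    values.data = trainable.clamp(values)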

to_values #

to_values() -> Tensor

Returns unfrozen parameter and attribute values as a flat tensor.

Source code in descent/train.py
@torch.no_grad()
def to_values(self) -> torch.Tensor:
    """Returns unfrozen parameter and attribute values as a flat tensor."""
    values_flat = self.clamp(self._values[self._unfrozen_idxs] * self._scales)
    return values_flat.detach().clone().requires_grad_()

to_force_field #

to_force_field(values_flat: Tensor) -> TensorForceField

Returns a force field with the parameters and attributes set to the given values.

Parameters:

  • values_flat (Tensor) –

    A flat tensor of parameter and attribute values. See to_values for the expected shape and ordering.

Source code in descent/train.py
def to_force_field(self, values_flat: torch.Tensor) -> smee.TensorForceField:
    """Returns a force field with the parameters and attributes set to the given
    values.

    Args:
        values_flat: A flat tensor of parameter and attribute values. See
            ``to_values`` for the expected shape and ordering.
    """
    potentials = self._force_field.potentials_by_type

    values = self._values.detach().clone()
    values[self._unfrozen_idxs] = (values_flat / self._scales).clamp(
        min=self._clamp_lower, max=self._clamp_upper
    )
    values = _unflatten_tensors(values, self._param_shapes + self._attr_shapes)

    params = values[: len(self._param_shapes)]

    for potential_type, param in zip(self._param_types, params, strict=True):
        potentials[potential_type].parameters = param

    attrs = values[len(self._param_shapes) :]

    for potential_type, attr in zip(self._attr_types, attrs, strict=True):
        potentials[potential_type].attributes = attr

    return self._force_field
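
Note that the method writes into and returns the wrapped force field itself, and the assigned tensors keep their autograd history, so a loss computed from the returned object back-propagates to values_flat. A small sketch (values_a and values_b are hypothetical flat tensors):

ff_a = trainable.to_force_field(values_a)
ff_b = trainable.to_force_field(values_b)
assert ff_a is ff_b  # one shared, in-place-updated force field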

clamp #

clamp(values_flat: Tensor) -> Tensor

Clamps the given values to the configured min and max values.

Source code in descent/train.py
@torch.no_grad()
def clamp(self, values_flat: torch.Tensor) -> torch.Tensor:
    """Clamps the given values to the configured min and max values."""
    return (values_flat / self._scales).clamp(
        min=self._clamp_lower, max=self._clamp_upper
    ) * self._scales
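
As with to_values, the input is expected in the scaled space: the values are divided by the configured scales, clamped against the raw limits, then rescaled. A typical use is re-imposing the limits after a gradient step; a minimal sketch:

values = trainable.to_values()
# ... optimizer.step() ...
# clamp runs under torch.no_grad, so replacing .data leaves the autograd
# graph untouched.
values.data = trainable.clamp(values)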