energy #

Train against relative energies and forces.


Classes:

  • Entry

    Represents a set of reference energies and forces.

Functions:

  • create_dataset

    Create a dataset from a list of existing entries.

  • extract_smiles

    Return a list of unique SMILES strings in the dataset.

  • predict

    Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.

Entry #

Bases: TypedDict

Represents a set of reference energies and forces.


smiles instance-attribute #

smiles: str

The indexed SMILES description of the molecule the energies and forces were computed for.

coords instance-attribute #

coords: Tensor

The coordinates [Å] the energies and forces were evaluated at with shape=(n_confs, n_particles, 3).

energy instance-attribute #

energy: Tensor

The reference energies [kcal/mol] with shape=(n_confs,).

forces instance-attribute #

forces: Tensor

The reference forces [kcal/mol/Å] with shape=(n_confs, n_particles, 3).

create_dataset #

create_dataset(entries: list[Entry]) -> Dataset

Create a dataset from a list of existing entries.


Parameters:

  • entries (list[Entry]) –

    The entries to create the dataset from.


Returns:

  • Dataset

    The created dataset.

def create_dataset(entries: list[Entry]) -> datasets.Dataset:
    """Create a dataset from a list of existing entries.

    Args:
        entries: The entries to create the dataset from.

    Returns:
        The created dataset.
    """
    # Flatten every tensor into a 1D list so that each entry becomes one table row.
    table = pyarrow.Table.from_pylist(
        [
            {
                "smiles": entry["smiles"],
                "coords": torch.tensor(entry["coords"]).flatten().tolist(),
                "energy": torch.tensor(entry["energy"]).flatten().tolist(),
                "forces": torch.tensor(entry["forces"]).flatten().tolist(),
            }
            for entry in entries
        ]
    )
    # TODO: validate rows
    dataset = datasets.Dataset(datasets.table.InMemoryTable(table))

    return dataset
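
The flattening performed by `create_dataset`, and the reshape that `predict` later applies to recover per-conformer arrays, can be sketched without any dependencies. This is only an illustration: nested Python lists stand in for tensors, and `flatten` and `unflatten` are hypothetical helpers that are not part of the library. The length check hints at the kind of row validation the `TODO` refers to.

```python
def flatten(nested):
    """Recursively flatten nested lists into one flat list of numbers."""
    if not isinstance(nested, list):
        return [nested]
    return [value for item in nested for value in flatten(item)]


def unflatten(flat, n_confs):
    """Rebuild shape (n_confs, n_particles, 3), inferring n_particles."""
    if len(flat) % (n_confs * 3) != 0:
        raise ValueError("length is not a multiple of n_confs * 3")
    n_particles = len(flat) // (n_confs * 3)
    values = iter(flat)
    return [
        [[next(values) for _ in range(3)] for _ in range(n_particles)]
        for _ in range(n_confs)
    ]


# Two conformers of a two-particle molecule, i.e. shape (2, 2, 3).
coords = [
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
    [[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]],
]
energy = [0.5, 0.7]

# One flattened row, analogous to a row of the dataset.
row = {"coords": flatten(coords), "energy": flatten(energy)}

# The number of conformers is recovered from the length of the energy column.
restored = unflatten(row["coords"], n_confs=len(row["energy"]))
```

Because only flat lists are stored, the number of particles is never saved explicitly; it is implied by the ratio of the coordinate and energy column lengths.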

extract_smiles #

extract_smiles(dataset: Dataset) -> list[str]

Return a list of unique SMILES strings in the dataset.


Parameters:

  • dataset (Dataset) –

    The dataset to extract the SMILES strings from.


Returns:

  • list[str]

    The list of unique SMILES strings.

def extract_smiles(dataset: datasets.Dataset) -> list[str]:
    """Return a list of unique SMILES strings in the dataset.

    Args:
        dataset: The dataset to extract the SMILES strings from.

    Returns:
        The list of unique SMILES strings.
    """
    return sorted({*dataset.unique("smiles")})
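
Since `predict` looks up a topology for each entry by its SMILES string, the unique SMILES are a convenient way to check that a `topologies` dictionary is complete before predicting. A minimal sketch, assuming a plain list in place of the dataset column and placeholder strings in place of real topologies; `missing_topologies` is a hypothetical helper, not a library function.

```python
def missing_topologies(smiles_column, topologies):
    """Return the sorted unique SMILES that have no entry in ``topologies``."""
    unique = sorted({*smiles_column})  # same de-duplication as extract_smiles
    return [smiles for smiles in unique if smiles not in topologies]


# Stand-in for the "smiles" column of a dataset (water appears twice).
smiles_column = [
    "[C:1]([H:2])([H:3])([H:4])[H:5]",
    "[O:1]([H:2])[H:3]",
    "[O:1]([H:2])[H:3]",
]
# Placeholder value; a real dictionary would map to topology objects.
topologies = {"[O:1]([H:2])[H:3]": "placeholder topology"}

missing = missing_topologies(smiles_column, topologies)
```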

predict #

predict(
    dataset: Dataset,
    force_field: TensorForceField,
    topologies: dict[str, TensorTopology],
    reference: Literal["mean", "min"] = "mean",
    normalize: bool = True,
) -> tuple[Tensor, Tensor, Tensor, Tensor]

Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.


Parameters:

  • dataset (Dataset) –

    The dataset to predict the energies and forces of.

  • force_field (TensorForceField) –

    The force field to use to predict the energies and forces.

  • topologies (dict[str, TensorTopology]) –

    The topologies of the molecules in the dataset. Each key should be a fully indexed SMILES string.

  • reference (Literal['mean', 'min'], default: 'mean' ) –

    The reference energy to compute the relative energies with respect to. This should be either the "mean" energy of all conformers, or the energy of the conformer with the lowest reference energy ("min").

  • normalize (bool, default: True ) –

    Whether to scale the relative energies by 1/sqrt(n_confs_i) and the forces by 1/sqrt(n_confs_i * n_atoms_per_conf_i * 3). This is useful when computing the MSE per entry.
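
The effect of `normalize` can be checked numerically: scaling the n residuals of one entry by 1/sqrt(n) makes their sum of squares equal to that entry's mean squared error. The following is a plain-Python sketch of that arithmetic with made-up residuals; it does not call the library.

```python
import math

# Made-up residuals (predicted minus reference) for a single entry.
residuals = [1.0, -2.0, 3.0, 0.5]
n = len(residuals)

# Scale factor applied when normalize=True (n plays the role of n_confs_i).
scale = 1.0 / math.sqrt(n)
scaled = [scale * r for r in residuals]

# Sum of squared scaled residuals versus the ordinary per-entry MSE.
sum_sq_scaled = sum(s * s for s in scaled)
mse = sum(r * r for r in residuals) / n
```

Summing squared differences between the scaled predicted and reference values over a whole dataset therefore gives a sum of per-entry MSEs, so entries with many conformers do not dominate the loss.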


Returns:

  • tuple[Tensor, Tensor, Tensor, Tensor]

    The predicted and reference relative energies [kcal/mol] with shape=(n_confs,), and predicted and reference forces [kcal/mol/Å] with shape=(n_confs * n_atoms_per_conf, 3).
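
For a dataset with several entries, the sizes of the returned tensors can be estimated ahead of time. The sketch below assumes the per-entry results are concatenated in dataset order, so the counts add up across entries; `expected_rows` is a hypothetical helper used only for illustration.

```python
def expected_rows(entry_sizes):
    """Estimate output sizes from (n_confs, n_atoms_per_conf) pairs per entry."""
    n_energies = sum(n_confs for n_confs, _ in entry_sizes)
    n_force_rows = sum(n_confs * n_atoms for n_confs, n_atoms in entry_sizes)
    return n_energies, n_force_rows


# One entry with 10 conformers of 5 atoms, another with 4 conformers of 3 atoms.
n_energies, n_force_rows = expected_rows([(10, 5), (4, 3)])
```

Under that assumption the energy tensors would have 14 elements and the force tensors 62 rows of 3 components each.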

def predict(
    dataset: datasets.Dataset,
    force_field: smee.TensorForceField,
    topologies: dict[str, smee.TensorTopology],
    reference: typing.Literal["mean", "min"] = "mean",
    normalize: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Predict the relative energies [kcal/mol] and forces [kcal/mol/Å] of a dataset.

    Args:
        dataset: The dataset to predict the energies and forces of.
        force_field: The force field to use to predict the energies and forces.
        topologies: The topologies of the molecules in the dataset. Each key should be
            a fully indexed SMILES string.
        reference: The reference energy to compute the relative energies with respect
            to. This should be either the "mean" energy of all conformers, or the
            energy of the conformer with the lowest reference energy ("min").
        normalize: Whether to scale the relative energies by ``1/sqrt(n_confs_i)``
            and the forces by ``1/sqrt(n_confs_i * n_atoms_per_conf_i * 3)``. This
            is useful when wanting to compute the MSE per entry.

    Returns:
        The predicted and reference relative energies [kcal/mol] with
        ``shape=(n_confs,)``, and predicted and reference forces [kcal/mol/Å] with
        ``shape=(n_confs * n_atoms_per_conf, 3)``.
    """
    energy_ref_all, energy_pred_all = [], []
    forces_ref_all, forces_pred_all = [], []

    for entry in dataset:
        smiles = entry["smiles"]

        energy_ref = entry["energy"]
        forces_ref = entry["forces"].reshape(len(energy_ref), -1, 3)

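        # Assumption: the coordinates come from entry["coords"] and are detached
        # so that gradients can be taken with respect to them.
        coords = (
            entry["coords"]
            .reshape(len(energy_ref), -1, 3)
            .detach()
            .requires_grad_(True)
        )
        topology = topologies[smiles]

        energy_pred = smee.compute_energy(topology, force_field, coords)
        # Forces are the negative gradient of the energy w.r.t. the coordinates.
        # Assumption: create_graph=True keeps them differentiable for training.
        forces_pred = -torch.autograd.grad(
            energy_pred.sum(), coords, create_graph=True
        )[0]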

        if reference.lower() == "mean":
            energy_ref_0 = energy_ref.mean()
            energy_pred_0 = energy_pred.mean()
        elif reference.lower() == "min":
            min_idx = energy_ref.argmin()

            energy_ref_0 = energy_ref[min_idx]
            energy_pred_0 = energy_pred[min_idx]
        else:
            raise NotImplementedError(f"invalid reference energy {reference}")

        scale_energy, scale_forces = 1.0, 1.0

        if normalize:
            scale_energy = 1.0 / torch.sqrt(torch.tensor(energy_pred.numel()))
            scale_forces = 1.0 / torch.sqrt(torch.tensor(forces_pred.numel()))

        energy_ref_all.append(scale_energy * (energy_ref - energy_ref_0))
        forces_ref_all.append(scale_forces * forces_ref.reshape(-1, 3))

        energy_pred_all.append(scale_energy * (energy_pred - energy_pred_0))
        forces_pred_all.append(scale_forces * forces_pred.reshape(-1, 3))

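    # Assumption: per-entry results are concatenated and returned reference-first,
    # following the order in which the lists are declared above.
    return (
        torch.cat(energy_ref_all),
        torch.cat(energy_pred_all),
        torch.cat(forces_ref_all),
        torch.cat(forces_pred_all),
    )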
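
The two `reference` modes used in the loop above can be sketched in plain Python. Note that in "min" mode the predicted energies are offset by the predicted energy of the conformer with the lowest reference energy, not by the lowest predicted energy. `relative_energies` is a hypothetical helper operating on lists rather than tensors, with made-up values.

```python
def relative_energies(energy_ref, energy_pred, reference="mean"):
    """Offset reference and predicted energies by their chosen zero point."""
    if reference == "mean":
        ref_0 = sum(energy_ref) / len(energy_ref)
        pred_0 = sum(energy_pred) / len(energy_pred)
    elif reference == "min":
        # Index of the conformer with the lowest *reference* energy.
        min_idx = min(range(len(energy_ref)), key=energy_ref.__getitem__)
        ref_0 = energy_ref[min_idx]
        pred_0 = energy_pred[min_idx]
    else:
        raise NotImplementedError(f"invalid reference energy {reference}")

    rel_ref = [e - ref_0 for e in energy_ref]
    rel_pred = [e - pred_0 for e in energy_pred]
    return rel_ref, rel_pred


energy_ref = [3.0, 1.0, 2.0]   # made-up reference energies
energy_pred = [2.5, 1.5, 1.0]  # made-up predicted energies

rel_ref_min, rel_pred_min = relative_energies(energy_ref, energy_pred, "min")
rel_ref_mean, _ = relative_energies(energy_ref, energy_pred, "mean")
```

With these numbers the "min" mode anchors both sets at the second conformer, so a predicted relative energy can be negative when the model ranks a different conformer as the lowest.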