Skip to content

dimers #

Train against dimer energies.

Classes:

  • Dimer

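    Represents a single dimer data point: the SMILES of both monomers, the coordinates and reference energies of each conformer, and the data source.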

Functions:

  • create_dataset

    Create a dataset from a list of existing dimers.

  • create_from_des

    Create a dataset from a DESXXX dimer set.

  • extract_smiles

    Return a list of unique SMILES strings in the dataset.

  • compute_dimer_energy

    Compute the energy of a dimer in a series of conformers.

  • predict

    Predict the energies of each dimer in the dataset.

  • default_closure

    Return a default closure function for training against dimer energies.

  • report

    Generate a report comparing the predicted and reference energies of each dimer.

Dimer #

Bases: TypedDict

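Represents a single dimer data point: the SMILES of both monomers, the coordinates and reference energies of each conformer, and the data source.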

create_dataset #

create_dataset(dimers: list[Dimer]) -> Dataset

Create a dataset from a list of existing dimers.

Parameters:

  • dimers (list[Dimer]) –

    The dimers to create the dataset from.

Returns:

  • Dataset

    The created dataset.

Source code in descent/targets/dimers.py
def create_dataset(dimers: list[Dimer]) -> datasets.Dataset:
    """Create a dataset from a list of existing dimers.

    Args:
        dimers: The dimers to create the dataset from. The ``coords`` and ``energy``
            entries of each dimer may be tensors or (nested) sequences of numbers;
            both are flattened before being stored.

    Returns:
        The created dataset, formatted to return torch tensors.
    """

    def _flatten(value) -> list:
        # ``as_tensor`` accepts existing tensors (e.g. those produced by
        # ``create_from_des``) without copy-constructing them, which
        # ``torch.tensor`` would do while emitting a UserWarning.
        return torch.as_tensor(value).flatten().tolist()

    table = pyarrow.Table.from_pylist(
        [
            {
                "smiles_a": dimer["smiles_a"],
                "smiles_b": dimer["smiles_b"],
                "coords": _flatten(dimer["coords"]),
                "energy": _flatten(dimer["energy"]),
                "source": dimer["source"],
            }
            for dimer in dimers
        ],
        schema=DATA_SCHEMA,
    )
    # TODO: validate rows
    dataset = datasets.Dataset(datasets.table.InMemoryTable(table))
    dataset.set_format("torch")

    return dataset

create_from_des #

create_from_des(
    data_dir: Path, energy_fn: EnergyFn
) -> Dataset

Create a dataset from a DESXXX dimer set.

Parameters:

  • data_dir (Path) –

    The path to the DESXXX directory.

  • energy_fn (EnergyFn) –

    A function which computes the reference energy of a dimer. This should take as input a pandas DataFrame containing the metadata for a given group, a tuple of geometry IDs, and a tensor of coordinates with shape=(n_dimers, n_atoms, 3). It should return a tensor of energies with shape=(n_dimers,) and units of [kcal/mol].

Returns:

  • Dataset

    The created dataset.

Source code in descent/targets/dimers.py
def create_from_des(
    data_dir: pathlib.Path,
    energy_fn: EnergyFn,
) -> datasets.Dataset:
    """Create a dataset from a DESXXX dimer set.

    The directory is expected to contain a ``<dir name>.csv`` metadata file with
    (at least) ``system_id``, ``group_id``, ``group_orig`` and ``geom_id`` columns,
    and one ``geometries/<system_id>/DES<group_orig>_<geom_id>.mol`` file per
    geometry. One dimer entry is created per (system, group) pair.

    Args:
        data_dir: The path to the DESXXX directory.
        energy_fn: A function which computes the reference energy of a dimer. This
            should take as input a pandas DataFrame containing the metadata for a
            given group, a tuple of geometry IDs, and a tensor of coordinates with
            ``shape=(n_dimers, n_atoms, 3)``. It should return a tensor of energies
            with ``shape=(n_dimers,)`` and units of [kcal/mol].

    Returns:
        The created dataset.

    Raises:
        ValueError: If a geometry file cannot be parsed by RDKit.
    """
    import pandas
    from rdkit import Chem, RDLogger

    def _load_mol(system_id, group_orig, geometry_id):
        """Parse a single dimer geometry, keeping explicit hydrogens."""
        path = (
            data_dir
            / "geometries"
            / str(system_id)
            / f"DES{group_orig}_{geometry_id}.mol"
        )
        mol = Chem.MolFromMolFile(str(path), removeHs=False)

        if mol is None:
            # RDKit signals parse failures by returning ``None`` rather than raising.
            raise ValueError(f"failed to parse dimer geometry: {path}")

        return mol

    RDLogger.DisableLog("rdApp.*")

    # Re-enable RDKit logging even if loading fails part way through, so that a
    # bad input file does not silence RDKit for the rest of the process.
    try:
        metadata = pandas.read_csv(data_dir / f"{data_dir.name}.csv", index_col=False)

        system_ids = metadata["system_id"].unique()
        dimers: list[Dimer] = []

        for system_id in tqdm.tqdm(system_ids, desc="loading dimers"):
            system_data = metadata[metadata["system_id"] == system_id]

            for group_id in system_data["group_id"].unique():
                group_data = system_data[system_data["group_id"] == group_id]
                group_orig = group_data["group_orig"].unique()[0]

                geometry_ids = tuple(group_data["geom_id"].values)

                # Each geometry file is parsed exactly once; the first one is also
                # used to split the dimer into its two monomers.
                mols = [
                    _load_mol(system_id, group_orig, geometry_id)
                    for geometry_id in geometry_ids
                ]
                mol_a, mol_b = Chem.GetMolFrags(mols[0], asMols=True)

                smiles_a = descent.utils.molecule.mol_to_smiles(mol_a, False)
                smiles_b = descent.utils.molecule.mol_to_smiles(mol_b, False)

                source = (
                    f"{data_dir.name} system={system_id} orig={group_orig} "
                    f"group={group_id}"
                )

                coords = torch.tensor(
                    [mol.GetConformer().GetPositions().tolist() for mol in mols]
                )
                energy = energy_fn(group_data, geometry_ids, coords)

                dimers.append(
                    {
                        "smiles_a": smiles_a,
                        "smiles_b": smiles_b,
                        "coords": coords,
                        "energy": energy,
                        "source": source,
                    }
                )
    finally:
        RDLogger.EnableLog("rdApp.*")

    return create_dataset(dimers)

extract_smiles #

extract_smiles(dataset: Dataset) -> list[str]

Return a list of unique SMILES strings in the dataset.

Parameters:

  • dataset (Dataset) –

    The dataset to extract the SMILES strings from.

Returns:

  • list[str]

    The list of unique SMILES strings.

Source code in descent/targets/dimers.py
def extract_smiles(dataset: datasets.Dataset) -> list[str]:
    """Collect the SMILES of every monomer appearing in the dataset.

    Both the ``smiles_a`` and ``smiles_b`` columns are considered, and each SMILES
    string is reported once regardless of how many dimers it appears in.

    Args:
        dataset: The dataset to extract the SMILES strings from.

    Returns:
        The unique SMILES strings, in sorted order.
    """
    unique_smiles: set[str] = set()

    for column in ("smiles_a", "smiles_b"):
        unique_smiles.update(dataset.unique(column))

    return sorted(unique_smiles)

compute_dimer_energy #

compute_dimer_energy(
    topology_a: TensorTopology,
    topology_b: TensorTopology,
    force_field: TensorForceField,
    coords: Tensor,
) -> Tensor

Compute the energy of a dimer in a series of conformers.

Parameters:

  • topology_a (TensorTopology) –

    The topology of the first monomer.

  • topology_b (TensorTopology) –

    The topology of the second monomer.

  • force_field (TensorForceField) –

    The force field to use.

  • coords (Tensor) –

    The coordinates of the dimer with shape=(n_dimers, n_atoms, 3).

Returns:

  • Tensor

    The energy [kcal/mol] of the dimer in each conformer.

Source code in descent/targets/dimers.py
def compute_dimer_energy(
    topology_a: smee.TensorTopology,
    topology_b: smee.TensorTopology,
    force_field: smee.TensorForceField,
    coords: torch.Tensor,
) -> torch.Tensor:
    """Compute the interaction energy of a dimer across a batch of conformers.

    The interaction energy is the energy of the full dimer minus the energies of
    each isolated monomer evaluated at the same geometry.

    Args:
        topology_a: The topology of the first monomer.
        topology_b: The topology of the second monomer.
        force_field: The force field to use.
        coords: The coordinates of the dimer with ``shape=(n_dimers, n_atoms, 3)``,
            where the atoms of the first monomer precede those of the second.

    Returns:
        The energy [kcal/mol] of the dimer in each conformer.
    """
    system = smee.TensorSystem([topology_a, topology_b], [1, 1], False)

    def _with_v_sites(topology, monomer_xyz):
        # Monomers that define virtual sites need their positions appended
        # before the energy can be evaluated.
        if topology.v_sites is None:
            return monomer_xyz
        return smee.geometry.add_v_site_coords(
            topology.v_sites, monomer_xyz, force_field
        )

    split = topology_a.n_atoms

    xyz_a = _with_v_sites(topology_a, coords[:, :split, :])
    xyz_b = _with_v_sites(topology_b, coords[:, split:, :])

    xyz_dimer = torch.cat([xyz_a, xyz_b], dim=1)

    e_dimer = smee.compute_energy(system, force_field, xyz_dimer)
    e_a = smee.compute_energy(topology_a, force_field, xyz_a)
    e_b = smee.compute_energy(topology_b, force_field, xyz_b)

    return e_dimer - e_a - e_b

predict #

predict(
    dataset: Dataset,
    force_field: TensorForceField,
    topologies: dict[str, TensorTopology],
) -> tuple[Tensor, Tensor]

Predict the energies of each dimer in the dataset.

Parameters:

  • dataset (Dataset) –

    The dataset to predict the energies of.

  • force_field (TensorForceField) –

    The force field to use.

  • topologies (dict[str, TensorTopology]) –

    The topologies of each monomer. Each key should be a fully mapped SMILES string.

Returns:

  • tuple[Tensor, Tensor]

    The reference and predicted energies [kcal/mol] of each dimer, each with shape=(n_dimers * n_conf_per_dimer,).

Source code in descent/targets/dimers.py
def predict(
    dataset: datasets.Dataset,
    force_field: smee.TensorForceField,
    topologies: dict[str, smee.TensorTopology],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Predict the energies of each dimer in the dataset.

    Args:
        dataset: The dataset to predict the energies of.
        force_field: The force field to use.
        topologies: The topologies of each monomer. Each key should be a fully
            mapped SMILES string.

    Returns:
        The reference and predicted energies [kcal/mol] of each dimer, each with
        ``shape=(n_dimers * n_conf_per_dimer,)``.

    Raises:
        ValueError: If the dataset does not contain any dimers.
    """
    reference: list[torch.Tensor] = []
    predicted: list[torch.Tensor] = []

    for dimer in descent.utils.dataset.iter_dataset(dataset):
        dimer_reference, dimer_predicted = _predict(dimer, force_field, topologies)

        reference.append(dimer_reference)
        predicted.append(dimer_predicted)

    if not reference:
        # ``torch.cat`` cannot concatenate an empty list, so fail early with a
        # message that names the actual problem.
        raise ValueError("the dataset does not contain any dimers")

    return torch.cat(reference), torch.cat(predicted)

default_closure #

default_closure(
    trainable: Trainable,
    topologies: dict[str, TensorTopology],
    dataset: Dataset,
)

Return a default closure function for training against dimer energies.

Parameters:

  • trainable (Trainable) –

    The wrapper around trainable parameters.

  • topologies (dict[str, TensorTopology]) –

    The topologies of the molecules present in the dataset, with keys of mapped SMILES patterns.

  • dataset (Dataset) –

    The dataset to train against.

Returns:

  • The default closure function.

Source code in descent/targets/dimers.py
def default_closure(
    trainable: "descent.train.Trainable",
    topologies: dict[str, smee.TensorTopology],
    dataset: datasets.Dataset,
):
    """Build the default closure used to train against dimer energies.

    The loss is the sum of squared differences between the predicted and
    reference energies of every conformer of every dimer in the dataset.

    Args:
        trainable: The wrapper around trainable parameters.
        topologies: The topologies of the molecules present in the dataset, with keys
            of mapped SMILES patterns.
        dataset: The dataset to train against.

    Returns:
        The default closure function.
    """

    def _sum_squared_error(x: torch.Tensor) -> torch.Tensor:
        candidate = trainable.to_force_field(x)
        reference, predicted = descent.targets.dimers.predict(
            dataset, candidate, topologies
        )
        residuals = predicted - reference
        return (residuals**2).sum()

    return descent.utils.loss.to_closure(_sum_squared_error)

report #

report(
    dataset: Dataset,
    force_fields: dict[str, TensorForceField],
    topologies: dict[str, dict[str, TensorTopology]],
    output_path: Path,
)

Generate a report comparing the predicted and reference energies of each dimer.

Parameters:

  • dataset (Dataset) –

    The dataset to generate the report for.

  • force_fields (dict[str, TensorForceField]) –

    The force fields to use to predict the energies.

  • topologies (dict[str, dict[str, TensorTopology]]) –

    The topologies of each monomer for each force field. The outer keys are force field names, which must match the keys of force_fields; each inner dictionary is keyed by fully mapped SMILES strings.

  • output_path (Path) –

    The path to write the report to.

Source code in descent/targets/dimers.py
def report(
    dataset: datasets.Dataset,
    force_fields: dict[str, smee.TensorForceField],
    topologies: dict[str, dict[str, smee.TensorTopology]],
    output_path: pathlib.Path,
):
    """Generate a report comparing the predicted and reference energies of each dimer.

    The report contains a summary table with the RMSE of each force field over all
    conformers in the dataset, followed by a per-dimer table with a rendering of
    the two monomers, a plot of the reference and predicted energies, the RMSE of
    each force field for that dimer, and the dimer's source.

    Args:
        dataset: The dataset to generate the report for.
        force_fields: The force fields to use to predict the energies.
        topologies: The topologies of each monomer for each force field. The outer
            keys are force field names, which must match the keys of
            ``force_fields``; each inner dict is keyed by fully mapped SMILES.
        output_path: The path to write the report to. Missing parent directories
            are created.
    """
    import math

    import pandas

    rows = []

    # Squared errors are accumulated as Python floats (double precision) via
    # ``.item()``, which also keeps the running totals detached from any autograd
    # graph and independent of the device the energies live on.
    delta_sqr_total = {force_field_name: 0.0 for force_field_name in force_fields}
    delta_sqr_count = 0

    for dimer in descent.utils.dataset.iter_dataset(dataset):
        energies = {"ref": dimer["energy"]}
        energies.update(
            (
                force_field_name,
                _predict(dimer, force_field, topologies[force_field_name])[1],
            )
            for force_field_name, force_field in force_fields.items()
        )

        mol_img = descent.utils.reporting.mols_to_img(
            dimer["smiles_a"], dimer["smiles_b"]
        )
        data_row = {"Dimer": mol_img, "Energy [kcal/mol]": _plot_energies(energies)}

        n_conformers = len(energies["ref"])

        for force_field_name in force_fields:
            delta_sqr = ((energies["ref"] - energies[force_field_name]) ** 2).sum()
            delta_sqr_total[force_field_name] += delta_sqr.item()

            rmse = torch.sqrt(delta_sqr / n_conformers)
            data_row[f"RMSE {force_field_name}"] = rmse.item()

        data_row["Source"] = dimer["source"]

        delta_sqr_count += n_conformers

        rows.append(data_row)

    def _total_rmse(force_field_name: str) -> float:
        """Return the RMSE of one force field over every conformer seen."""
        if delta_sqr_count == 0:
            # No conformers: report NaN explicitly rather than dividing by zero.
            return float("nan")
        return math.sqrt(delta_sqr_total[force_field_name] / delta_sqr_count)

    rmse_total_rows = [
        {
            "Force Field": force_field_name,
            "RMSE [kcal/mol]": _total_rmse(force_field_name),
        }
        for force_field_name in force_fields
    ]

    import bokeh.models.widgets.tables
    import panel

    data_full = pandas.DataFrame(rows)
    data_stats = pandas.DataFrame(rmse_total_rows)

    rmse_format = bokeh.models.widgets.tables.NumberFormatter(format="0.0000")

    formatters_stats = {
        col: rmse_format for col in data_stats.columns if col.startswith("RMSE")
    }
    formatters_full = {
        # These columns hold pre-rendered markup and must not be escaped.
        **{col: "html" for col in ["Dimer", "Energy [kcal/mol]"]},
        **{col: rmse_format for col in data_full.columns if col.startswith("RMSE")},
    }

    layout = panel.Column(
        "## Statistics",
        panel.widgets.Tabulator(
            data_stats,
            show_index=False,
            selectable=False,
            disabled=True,
            formatters=formatters_stats,
            configuration={"columnDefaults": {"headerSort": False}},
        ),
        "## Energies",
        panel.widgets.Tabulator(
            data_full,
            show_index=False,
            selectable=False,
            disabled=True,
            formatters=formatters_full,
            configuration={"rowHeight": 400},
        ),
        sizing_mode="stretch_width",
        scroll=True,
    )

    output_path.parent.mkdir(parents=True, exist_ok=True)
    layout.save(output_path, title="Dimers", embed=True)