
utils

General utility functions

Functions:

  • find_exclusions

    Find all excluded interaction pairs and their associated scaling factor.

  • ones_like

    Create a tensor of ones with the same device and type as another tensor.

  • zeros_like

    Create a tensor of zeros with the same device and type as another tensor.

  • tensor_like

    Create a tensor with the same device and type as another tensor.

  • arange_like

    Create a range tensor [0, end) with the same device and type as another tensor.

  • logsumexp

    Compute the log of the sum of the exponential of the input elements, optionally with each element multiplied by a scaling factor.

  • to_upper_tri_idx

    Converts pairs of 2D indices to 1D indices in an upper triangular matrix that excludes the diagonal, unless include_diag is True.

  • geometric_mean

    Computes the geometric mean of two values 'safely'.

Attributes:

  • EPSILON

    A small epsilon value used to prevent divide by zero errors.

EPSILON module-attribute

EPSILON = 1e-06

A small epsilon value used to prevent divide by zero errors.

find_exclusions

find_exclusions(
    topology: Topology, v_sites: Optional[VSiteMap] = None
) -> dict[tuple[int, int], ExclusionType]

Find all excluded interaction pairs and their associated scaling factor.

Parameters:

  • topology (Topology) –

    The topology to find the interaction pairs of.

  • v_sites (Optional[VSiteMap], default: None ) –

    Virtual sites that will be added to the topology.

Returns:

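  • dict[tuple[int, int], ExclusionType]

    A dictionary of the form {(atom_idx_1, atom_idx_2): scale}, where each pair is sorted so that atom_idx_1 < atom_idx_2 and scale names the exclusion type, e.g. "scale_13" for atoms separated by two bonds.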

Source code in smee/utils.py
def find_exclusions(
    topology: openff.toolkit.Topology,
    v_sites: typing.Optional["smee.VSiteMap"] = None,
) -> dict[tuple[int, int], ExclusionType]:
    """Find all excluded interaction pairs and their associated scaling factor.

    Args:
        topology: The topology to find the interaction pairs of.
        v_sites: Virtual sites that will be added to the topology.

    Returns:
        A dictionary of the form ``{(atom_idx_1, atom_idx_2): scale}``.
    """

    graph = networkx.from_edgelist(
        tuple(
            sorted((topology.atom_index(bond.atom1), topology.atom_index(bond.atom2)))
        )
        for bond in topology.bonds
    )

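    # Bond each virtual site to its parent atom and to the parent's neighbours, so
    # that it has the same exclusions as the parent plus a 1-2 exclusion with it.
    if v_sites is not None: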
        for v_site_key in v_sites.keys:
            v_site_idx = v_sites.key_to_idx[v_site_key]
            parent_idx = v_site_key.orientation_atom_indices[0]

            for neighbour_idx in graph.neighbors(parent_idx):
                graph.add_edge(v_site_idx, neighbour_idx)

            graph.add_edge(v_site_idx, parent_idx)

    distances = dict(networkx.all_pairs_shortest_path_length(graph, cutoff=5))
    distance_to_scale = {1: "scale_12", 2: "scale_13", 3: "scale_14", 4: "scale_15"}

    exclusions = {}

    for idx_a in distances:
        for idx_b, distance in distances[idx_a].items():
            pair = tuple(sorted((idx_a, idx_b)))
            scale = distance_to_scale.get(distance)

            if scale is None:
                continue

            assert pair not in exclusions or exclusions[pair] == scale
            exclusions[pair] = scale

    return exclusions
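The graph-distance logic above can be exercised without an OpenFF topology. The sketch below uses a hypothetical stand-alone helper, `exclusions_from_bonds` (not part of smee), that takes plain bond tuples and a hypothetical `{v_site_idx: parent_idx}` mapping in place of `Topology` and `VSiteMap`. Otherwise it follows the same steps: build the bond graph, bond each virtual site to its parent and the parent's neighbours, and map shortest-path distances of 1 to 4 bonds onto the `scale_1n` names.

```python
from typing import Optional

import networkx

# Shortest-path distance (in bonds) -> exclusion type, as in find_exclusions.
DISTANCE_TO_SCALE = {1: "scale_12", 2: "scale_13", 3: "scale_14", 4: "scale_15"}


def exclusions_from_bonds(
    bonds: list[tuple[int, int]], v_site_parents: Optional[dict[int, int]] = None
) -> dict[tuple[int, int], str]:
    """Hypothetical helper: find excluded pairs from a plain list of bonds."""
    graph = networkx.from_edgelist(tuple(sorted(bond)) for bond in bonds)

    # Each virtual site is bonded to its parent and to the parent's neighbours.
    for v_site_idx, parent_idx in (v_site_parents or {}).items():
        for neighbour_idx in list(graph.neighbors(parent_idx)):
            graph.add_edge(v_site_idx, neighbour_idx)
        graph.add_edge(v_site_idx, parent_idx)

    # Pairs more than 4 bonds apart are never excluded, so stop searching there.
    distances = dict(networkx.all_pairs_shortest_path_length(graph, cutoff=4))

    exclusions = {}
    for idx_a, row in distances.items():
        for idx_b, distance in row.items():
            scale = DISTANCE_TO_SCALE.get(distance)
            if scale is not None:  # distance 0 (an atom with itself) is skipped
                exclusions[tuple(sorted((idx_a, idx_b)))] = scale
    return exclusions


# A linear six-atom chain 0-1-2-3-4-5.
chain = exclusions_from_bonds([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)])
print(chain[(0, 3)], chain[(0, 4)], (0, 5) in chain)  # scale_14 scale_15 False

# A water-like molecule (O=0, H=1, H=2) with a virtual site 3 on the oxygen.
water = exclusions_from_bonds([(0, 1), (0, 2)], v_site_parents={3: 0})
print(water[(1, 3)], water[(1, 2)])  # scale_12 scale_13
```

The sketch passes `cutoff=4` rather than 5; the result is the same, because a distance of 5 has no entry in the scale mapping and would be skipped anyway.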

ones_like

ones_like(size: _size, other: Tensor) -> Tensor

Create a tensor of ones with the same device and type as another tensor.

Source code in smee/utils.py
def ones_like(size: _size, other: torch.Tensor) -> torch.Tensor:
    """Create a tensor of ones with the same device and type as another tensor."""
    return torch.ones(size, dtype=other.dtype, device=other.device)

zeros_like

zeros_like(size: _size, other: Tensor) -> Tensor

Create a tensor of zeros with the same device and type as another tensor.

Source code in smee/utils.py
def zeros_like(size: _size, other: torch.Tensor) -> torch.Tensor:
    """Create a tensor of zeros with the same device and type as another tensor."""
    return torch.zeros(size, dtype=other.dtype, device=other.device)
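Unlike `torch.ones_like` and `torch.zeros_like`, which also copy the shape of their argument, `ones_like(size, other)` and `zeros_like(size, other)` take an explicit `size` and borrow only the dtype and device of `other`. A minimal illustration of the difference using the plain PyTorch calls these helpers wrap (variable names are illustrative):

```python
import torch

other = torch.empty(2, dtype=torch.float64)

# torch.ones_like copies the shape of `other` as well as its dtype and device.
same_shape = torch.ones_like(other)

# Passing dtype/device explicitly, as ones_like/zeros_like above do, lets the
# shape be chosen independently of `other`.
ones = torch.ones((3, 4), dtype=other.dtype, device=other.device)
zeros = torch.zeros((3,), dtype=other.dtype, device=other.device)

print(same_shape.shape, ones.shape, ones.dtype)
```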

tensor_like

tensor_like(data: Any, other: Tensor) -> Tensor

Create a tensor with the same device and type as another tensor.

Source code in smee/utils.py
def tensor_like(data: typing.Any, other: torch.Tensor) -> torch.Tensor:
    """Create a tensor with the same device and type as another tensor."""

    if isinstance(data, torch.Tensor):
        return data.clone().detach().to(other.device, other.dtype)

    return torch.tensor(data, dtype=other.dtype, device=other.device)
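The tensor branch uses `clone().detach()` rather than `torch.tensor(data)`, which is the form the PyTorch documentation recommends for copying an existing tensor: the copy owns its own memory and carries no autograd history. The snippet below reproduces both branches with plain PyTorch calls to show the effect; the variable names are illustrative only.

```python
import torch

reference = torch.zeros(1, dtype=torch.float64)
source = torch.tensor([1.0, 2.0], requires_grad=True)  # float32, tracks gradients

# Tensor input (first branch): copy, drop autograd history, match device/dtype.
copied = source.clone().detach().to(reference.device, reference.dtype)

# Any other input, e.g. a list (second branch): build on reference's device/dtype.
built = torch.tensor([1, 2], dtype=reference.dtype, device=reference.device)

print(copied.dtype, copied.requires_grad, torch.equal(copied, built))
```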

arange_like

arange_like(end: int, other: Tensor) -> Tensor

Create a range tensor [0, end) with the same device and type as another tensor.

Source code in smee/utils.py
def arange_like(end: int, other: torch.Tensor) -> torch.Tensor:
    """Create a range tensor ``[0, end)`` with the same device and type as ``other``."""
    return torch.arange(end, dtype=other.dtype, device=other.device)
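Because the dtype comes from `other`, the values are floating point whenever `other` is. That is convenient for arithmetic with `other`, but a range meant for indexing needs an integer `other`. A plain PyTorch illustration of both cases (tensor names are illustrative):

```python
import torch

positions = torch.zeros((4, 3), dtype=torch.float32)
counts = torch.zeros(1, dtype=torch.long)

# Equivalent to arange_like(4, positions): float32 values 0.0, 1.0, 2.0, 3.0.
float_range = torch.arange(4, dtype=positions.dtype, device=positions.device)

# Equivalent to arange_like(4, counts): int64 values, usable as indices.
index_range = torch.arange(4, dtype=counts.dtype, device=counts.device)
rows = positions[index_range]
```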

logsumexp

logsumexp(
    a: Tensor,
    dim: int,
    keepdim: bool = False,
    b: Tensor | None = None,
    return_sign: bool = False,
) -> Tensor | tuple[Tensor, Tensor]

Compute the log of the sum of the exponential of the input elements, optionally with each element multiplied by a scaling factor.

Notes

This should be removed if torch.logsumexp is updated to support scaling factors.

Parameters:

  • a (Tensor) –

    The elements that should be summed over.

  • dim (int) –

    The dimension to sum over.

  • keepdim (bool, default: False ) –

    Whether to keep the summed dimension.

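  • b (Tensor | None, default: None ) –

    The scaling factor to multiply each element by.

  • return_sign (bool, default: False ) –

    Whether to also return the sign of the (scaled) sum, which allows b to contain negative values. Requires b to be provided.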

Returns:

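  • Tensor | tuple[Tensor, Tensor]

    The log of the (optionally scaled) sum of exponentials of a. If return_sign is True, a tuple of the log of the absolute value of the sum and the sign of the sum.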

Source code in smee/utils.py
def logsumexp(
    a: torch.Tensor,
    dim: int,
    keepdim: bool = False,
    b: torch.Tensor | None = None,
    return_sign: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """Compute the log of the sum of the exponential of the input elements, optionally
    with each element multiplied by a scaling factor.

    Notes:
        This should be removed if torch.logsumexp is updated to support scaling factors.

    Args:
        a: The elements that should be summed over.
        dim: The dimension to sum over.
        keepdim: Whether to keep the summed dimension.
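        b: The scaling factor to multiply each element by.
        return_sign: Whether to also return the sign of the (scaled) sum, which
            allows ``b`` to contain negative values. Requires ``b`` to be provided.

    Returns:
        The log of the (optionally scaled) sum of exponentials of ``a``. If
        ``return_sign`` is true, a tuple of the log of the absolute value of the
        sum and the sign of the sum.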
    """
    a_type = a.dtype

    if b is None:
        assert return_sign is False
        return torch.logsumexp(a, dim, keepdim)

    a = a.double()
    b = b.double()

    a, b = torch.broadcast_tensors(a, b)

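    # Elements with a zero scaling factor contribute nothing to the sum. Mask them
    # out of place, since ``a`` may be a broadcast view or the caller's own tensor.
    a = a.masked_fill(b == 0, -torch.inf)

    # The shift only improves numerical stability and does not change the result,
    # so it is computed without tracking gradients. Non-finite maxima (e.g. when
    # every element is masked) are replaced by zero.
    a_max = torch.amax(a.detach(), dim=dim, keepdim=True)
    a_max = torch.where(torch.isfinite(a_max), a_max, torch.zeros_like(a_max))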

    exp_sum = torch.sum(b * torch.exp(a - a_max), dim=dim, keepdim=keepdim)
    sign = None

    if return_sign:
        sign = torch.sign(exp_sum)
        exp_sum = exp_sum * sign

    ln_exp_sum = torch.log(exp_sum)

    if not keepdim:
        a_max = torch.squeeze(a_max, dim=dim)

    ln_exp_sum += a_max
    ln_exp_sum = ln_exp_sum.to(a_type)

    if return_sign:
        return ln_exp_sum, sign.to(a_type)
    else:
        return ln_exp_sum
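The core of the scaled computation is the usual max-shift identity, log(sum(b * exp(a))) = a_max + log(sum(b * exp(a - a_max))). The self-contained sketch below reproduces that trick. `scaled_logsumexp` is a hypothetical name, not a smee function, and unlike `logsumexp` it always returns the sign and does not cast back to the input dtype. The last example shows why zero weights are masked before the maximum is taken: otherwise the large masked value would set the shift and every remaining term would underflow to zero.

```python
import math

import torch


def scaled_logsumexp(
    a: torch.Tensor, b: torch.Tensor, dim: int, keepdim: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: return log|sum(b * exp(a))| and the sign of the sum."""
    a, b = torch.broadcast_tensors(a.double(), b.double())

    # Zero weights contribute nothing, so they must not set the maximum.
    a = a.masked_fill(b == 0, -torch.inf)

    # Shift by the (gradient-free) maximum so exp() cannot overflow.
    a_max = torch.amax(a.detach(), dim=dim, keepdim=True)
    a_max = torch.where(torch.isfinite(a_max), a_max, torch.zeros_like(a_max))

    exp_sum = torch.sum(b * torch.exp(a - a_max), dim=dim, keepdim=keepdim)
    sign = torch.sign(exp_sum)

    if not keepdim:
        a_max = a_max.squeeze(dim)

    return torch.log(exp_sum * sign) + a_max, sign


# Large exponents: a naive log(sum(exp(a))) overflows to inf here.
big, _ = scaled_logsumexp(torch.tensor([1000.0, 1000.0]), torch.ones(2), dim=0)

# A negative weight: log|1 - e| = log(e - 1), with sign -1.
neg, neg_sign = scaled_logsumexp(
    torch.tensor([0.0, 1.0]), torch.tensor([1.0, -1.0]), dim=0
)

# A zero weight on a huge exponent must not dominate the shift.
masked, _ = scaled_logsumexp(torch.tensor([1000.0, 1.0]), torch.tensor([0.0, 2.0]), dim=0)

print(big.item(), neg.item(), neg_sign.item(), masked.item())
```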

to_upper_tri_idx

to_upper_tri_idx(
    i: Tensor, j: Tensor, n: int, include_diag: bool = False
) -> Tensor

Converts pairs of 2D indices to 1D indices in an upper triangular matrix that excludes the diagonal, unless include_diag is True.

Parameters:

  • i (Tensor) –

    A tensor of the indices along the first axis with shape=(n_pairs,).

  • j (Tensor) –

    A tensor of the indices along the second axis with shape=(n_pairs,).

  • n (int) –

    The size of the matrix.

  • include_diag (bool, default: False ) –

    Whether the diagonal is included in the upper triangular matrix.

Returns:

  • Tensor

    A tensor of the 1D indices with shape=(n_pairs,). Indices lie in [0, n * (n - 1) // 2), or in [0, n * (n + 1) // 2) if include_diag is True.

Source code in smee/utils.py
def to_upper_tri_idx(
    i: torch.Tensor, j: torch.Tensor, n: int, include_diag: bool = False
) -> torch.Tensor:
    """Converts pairs of 2D indices to 1D indices in an upper triangular matrix that
    excludes the diagonal, unless ``include_diag`` is true.

    Args:
        i: A tensor of the indices along the first axis with ``shape=(n_pairs,)``.
        j: A tensor of the indices along the second axis with ``shape=(n_pairs,)``.
        n: The size of the matrix.
        include_diag: Whether the diagonal is included in the upper triangular matrix.

    Returns:
        A tensor of the 1D indices with ``shape=(n_pairs,)``. Indices lie in
        ``[0, n * (n - 1) // 2)``, or in ``[0, n * (n + 1) // 2)`` if ``include_diag``.
    """

    if not include_diag:
        assert (i < j).all(), "i must be less than j"
        return (i * (2 * n - i - 1)) // 2 + j - i - 1

    assert (i <= j).all(), "i must be less than or equal to j"
    return (i * (2 * n - i + 1)) // 2 + j - i
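Both closed forms count the entries in the rows above row i and then add the offset of column j within row i. Without the diagonal, row k holds n - k - 1 entries, so the rows before i hold i * (2n - i - 1) / 2 entries in total. A quick check is to compare against `torch.triu_indices`, which enumerates upper-triangular coordinates in the same row-major order. `upper_tri_idx` below is a hypothetical stand-alone copy of the two expressions, without the input assertions.

```python
import torch


def upper_tri_idx(
    i: torch.Tensor, j: torch.Tensor, n: int, include_diag: bool = False
) -> torch.Tensor:
    """Hypothetical copy of the to_upper_tri_idx formulas (no input checks)."""
    if not include_diag:
        return (i * (2 * n - i - 1)) // 2 + j - i - 1
    return (i * (2 * n - i + 1)) // 2 + j - i


n = 5

# offset=1 skips the diagonal; the flattened indices should be 0, 1, ..., 9.
i, j = torch.triu_indices(n, n, offset=1)
print(upper_tri_idx(i, j, n))

# offset=0 keeps the diagonal; the flattened indices should be 0, 1, ..., 14.
i_diag, j_diag = torch.triu_indices(n, n, offset=0)
print(upper_tri_idx(i_diag, j_diag, n, include_diag=True))
```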

geometric_mean

geometric_mean(eps_a: Tensor, eps_b: Tensor) -> Tensor

Computes the geometric mean of two values 'safely'.

A small epsilon (smee.utils.EPSILON) is added when computing the gradient in cases where the mean is zero to prevent divide by zero errors.

Parameters:

  • eps_a (Tensor) –

    The first value.

  • eps_b (Tensor) –

    The second value.

Returns:

  • Tensor

    The geometric mean of the two values.

Source code in smee/utils.py
def geometric_mean(eps_a: torch.Tensor, eps_b: torch.Tensor) -> torch.Tensor:
    """Computes the geometric mean of two values 'safely'.

    A small epsilon (``smee.utils.EPSILON``) is added when computing the gradient in
    cases where the mean is zero to prevent divide by zero errors.

    Args:
        eps_a: The first value.
        eps_b: The second value.

    Returns:
        The geometric mean of the two values.
    """

    return _SafeGeometricMean.apply(eps_a, eps_b)
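`_SafeGeometricMean` itself is not shown on this page. The sketch below is a hypothetical `torch.autograd.Function` matching the description above, assuming the guard is applied only where the mean is exactly zero: the forward pass computes sqrt(eps_a * eps_b), and the backward pass divides by the mean plus `EPSILON` at those points. Without such a guard, autograd through `torch.sqrt(eps_a * eps_b)` yields non-finite gradients when the product is zero.

```python
import torch

EPSILON = 1e-6  # same value as smee.utils.EPSILON


class SafeGeometricMeanSketch(torch.autograd.Function):
    """Hypothetical sqrt(eps_a * eps_b) with a guarded backward pass."""

    @staticmethod
    def forward(ctx, eps_a: torch.Tensor, eps_b: torch.Tensor) -> torch.Tensor:
        mean = torch.sqrt(eps_a * eps_b)
        ctx.save_for_backward(eps_a, eps_b, mean)
        return mean

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        eps_a, eps_b, mean = ctx.saved_tensors
        # d/da sqrt(a * b) = b / (2 * sqrt(a * b)); add EPSILON only where mean is 0.
        denominator = 2.0 * torch.where(mean == 0.0, mean + EPSILON, mean)
        return grad_output * eps_b / denominator, grad_output * eps_a / denominator


eps_a = torch.tensor([4.0, 0.0], dtype=torch.float64, requires_grad=True)
eps_b = torch.tensor([9.0, 0.0], dtype=torch.float64, requires_grad=True)

mean = SafeGeometricMeanSketch.apply(eps_a, eps_b)
mean.sum().backward()
print(mean.tolist(), eps_a.grad.tolist(), eps_b.grad.tolist())
```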