mindsponge.metrics.between_residue_clash

mindsponge.metrics.between_residue_clash(atom14_pred_positions, atom14_atom_exists, atom14_atom_radius, residue_index, c_one_hot, n_one_hot, overlap_tolerance_soft, overlap_tolerance_hard, cys_sg_idx)

This loss penalizes steric clashes, that is, non-bonded atoms in different residues that come too close to each other.
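Written out per atom pair, the penalty can be sketched as below. This assumes the function follows the between-residue clash term of AlphaFold 2's structural-violation loss. The symbols are introduced here for illustration: \(x_i^a\) is the position of atom \(a\) in residue \(i\), \(r_i^a\) its van der Waals radius, \(m_{ij}^{ab}\) the pair mask, and \(\tau_{soft}\), \(\tau_{hard}\) stand for overlap_tolerance_soft and overlap_tolerance_hard.

\[
d_{ij}^{ab} = \lVert x_i^{a} - x_j^{b} \rVert, \qquad
\ell_{ij}^{ab} = m_{ij}^{ab}\,\max\bigl(0,\; r_i^{a} + r_j^{b} - \tau_{soft} - d_{ij}^{ab}\bigr)
\]
\[
\mathcal{L}_{\text{mean}} = \frac{\sum_{i,j,a,b} \ell_{ij}^{ab}}{\epsilon + \sum_{i,j,a,b} m_{ij}^{ab}}, \qquad
\text{clash}_{ij}^{ab} = m_{ij}^{ab}\,\bigl[\, d_{ij}^{ab} < r_i^{a} + r_j^{b} - \tau_{hard} \,\bigr]
\]

Here \(m_{ij}^{ab} = 1\) only when both atoms exist, residue \(i\) comes before residue \(j\), the pair is not the C-N peptide bond between consecutive residues, and the pair is not two atoms in the cysteine SG slot (a possible disulfide bond); \(\epsilon\) is a small constant that avoids division by zero.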

Parameters
  • atom14_pred_positions (Tensor) – Predicted positions of atoms in the global prediction frame. Shape is \((N_{res}, 14, 3)\).

  • atom14_atom_exists (Tensor) – Mask denoting whether the atom at each position exists for the given amino acid type. Shape is \((N_{res}, 14)\).

  • atom14_atom_radius (Tensor) – Van der Waals radius of each atom. Shape is \((N_{res}, 14)\).

  • residue_index (Tensor) – Residue index of each amino acid. Shape is \((N_{res}, )\), with values ranging from 1 to \(N_{res}\).

  • c_one_hot (Tensor) – One-hot encoding of the C atom in the atom14 representation. Shape is (14, ).

  • n_one_hot (Tensor) – One-hot encoding of the N atom in the atom14 representation. Shape is (14, ).

  • overlap_tolerance_soft (float) – Soft tolerance factor. Default: 12.0.

  • overlap_tolerance_hard (float) – Hard tolerance factor. Default: 1.5.

  • cys_sg_idx (Tensor) – Index of the SG atom of cysteine (CYS) in the atom14 representation. Default: 5. See mindsponge.common.residue_constants for details. Shape is ().
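Because the backbone atoms N, CA, C and O occupy the same slots (0 to 3) for every residue type in the atom14 layout, the one-hot vectors and SG index used in the example can be derived from atom names instead of being typed by hand. The sketch below assumes the atom14 ordering for cysteine is N, CA, C, O, CB, SG followed by empty slots, as in AlphaFold-style residue constants. `CYS_ATOM14_NAMES` and `atom14_one_hot` are hypothetical helpers; check `mindsponge.common.residue_constants` for the authoritative table.

```python
import numpy as np

# ASSUMPTION: hand-copied atom14 slot names for cysteine (AlphaFold-style ordering).
# Backbone slots N, CA, C, O are shared by all residue types.
CYS_ATOM14_NAMES = ["N", "CA", "C", "O", "CB", "SG"] + [""] * 8


def atom14_one_hot(name, names=CYS_ATOM14_NAMES):
    """Hypothetical helper: length-14 one-hot vector marking the slot of `name`."""
    vec = np.zeros(14, dtype=np.int32)
    vec[names.index(name)] = 1
    return vec


c_one_hot = atom14_one_hot("C")             # backbone carbonyl carbon, slot 2
n_one_hot = atom14_one_hot("N")             # backbone amide nitrogen, slot 0
cys_sg_idx = CYS_ATOM14_NAMES.index("SG")   # slot 5, the documented default
```

Wrapping these arrays in `Tensor(..., ms.int32)` reproduces the `c_one_hot`, `n_one_hot` and `cys_sg_idx` values passed in the example.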

Returns

  • Tensor, mean_loss, the average clash loss. Shape is ().

  • Tensor, per_atom_loss_sum, the sum of all clash losses per atom. Shape is \((N_{res}, 14)\).

  • Tensor, per_atom_clash_mask, a mask indicating whether each atom clashes with any other atom. Shape is \((N_{res}, 14)\).
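To show how the three returned tensors relate to each other, the following NumPy sketch recomputes them from scratch. It assumes the function mirrors the between-residue clash term of AlphaFold 2's structural-violation loss (pair mask, C-N and SG-SG exclusions, ReLU on the radius overlap). `between_residue_clash_np` is a hypothetical name, and the epsilons or reduction details in MindSponge's actual implementation may differ.

```python
import numpy as np


def between_residue_clash_np(pos, exists, radius, residue_index,
                             c_one_hot, n_one_hot, tol_soft, tol_hard, cys_sg_idx):
    """Hypothetical NumPy re-implementation, assumed to mirror AlphaFold 2.

    pos: (N, 14, 3); exists, radius: (N, 14); residue_index: (N,);
    c_one_hot, n_one_hot: (14,); cys_sg_idx: int.
    """
    # d[i, j, a, b]: distance between atom a of residue i and atom b of residue j.
    diff = pos[:, None, :, None, :] - pos[None, :, None, :, :]
    dists = np.sqrt(1e-10 + np.sum(diff ** 2, axis=-1))

    # Keep pairs where both atoms exist, counting each residue pair once (i < j).
    mask = exists[:, None, :, None] * exists[None, :, None, :]
    mask = mask * (residue_index[:, None, None, None] < residue_index[None, :, None, None])

    # Exclude the peptide bond: C of residue i with N of residue i + 1.
    neighbour = (residue_index[:, None, None, None] + 1) == residue_index[None, :, None, None]
    c_n_bonds = neighbour * c_one_hot[None, None, :, None] * n_one_hot[None, None, None, :]
    mask = mask * (1.0 - c_n_bonds)

    # Exclude possible disulfide bonds: both atoms in the CYS SG slot.
    sg = np.eye(14)[cys_sg_idx]
    mask = mask * (1.0 - sg[None, None, :, None] * sg[None, None, None, :])

    # Lower distance bound from the van der Waals radii; penalize the overlap.
    lower = mask * (radius[:, None, :, None] + radius[None, :, None, :])
    error = mask * np.maximum(0.0, lower - tol_soft - dists)

    mean_loss = np.sum(error) / (1e-6 + np.sum(mask))
    # Each pair's error is credited to both of its atoms.
    per_atom_loss_sum = np.sum(error, axis=(0, 2)) + np.sum(error, axis=(1, 3))
    clash = mask * (dists < lower - tol_hard)
    per_atom_clash_mask = np.maximum(clash.max(axis=(0, 2)), clash.max(axis=(1, 3)))
    return mean_loss, per_atom_loss_sum, per_atom_clash_mask


# Toy case: two residues whose CA atoms (slot 1) sit 1.0 apart.
pos = np.zeros((2, 14, 3))
pos[1, 1] = [1.0, 0.0, 0.0]
exists = np.zeros((2, 14))
exists[:, 1] = 1.0
radius = np.full((2, 14), 1.7)
mean_loss, per_atom_loss_sum, per_atom_clash_mask = between_residue_clash_np(
    pos, exists, radius, np.array([1, 2]), np.eye(14)[2], np.eye(14)[0], 1.5, 1.5, 5)
```

In the toy case the two CA atoms are 1.0 apart while their radii sum to 3.4, so with both tolerances at 1.5 the pair contributes 3.4 - 1.5 - 1.0 = 0.9 to the loss, and both atoms are flagged in per_atom_clash_mask.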

Symbol:

\(N_{res}\), number of amino acids.

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> import numpy as np
>>> from mindsponge.metrics import between_residue_clash
>>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
>>> atom14_atom_exists = Tensor(np.random.randint(2, size=(50, 14)), ms.float32)
>>> atom14_atom_radius = Tensor(np.random.random(size=(50, 14)), ms.float32)
>>> residue_index = Tensor(np.array(range(50)), ms.int32)
>>> c_one_hot = Tensor(np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
>>> n_one_hot = Tensor(np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
>>> overlap_tolerance_soft = 12.0
>>> overlap_tolerance_hard = 1.5
>>> cys_sg_idx = Tensor(5, ms.int32)
>>> mean_loss, per_atom_loss_sum, per_atom_clash_mask = between_residue_clash(atom14_pred_positions,
...                                                                           atom14_atom_exists,
...                                                                           atom14_atom_radius,
...                                                                           residue_index,
...                                                                           c_one_hot,
...                                                                           n_one_hot,
...                                                                           overlap_tolerance_soft,
...                                                                           overlap_tolerance_hard,
...                                                                           cys_sg_idx)
>>> print(mean_loss.shape, per_atom_loss_sum.shape, per_atom_clash_mask.shape)
() (50, 14) (50, 14)