mindsponge.metrics.between_residue_clash
- mindsponge.metrics.between_residue_clash(atom14_pred_positions, atom14_atom_exists, atom14_atom_radius, residue_index, c_one_hot, n_one_hot, overlap_tolerance_soft, overlap_tolerance_hard, cys_sg_idx)
Loss that penalizes steric clashes, i.e. non-bonded atoms in different residues that come too close to each other.
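For a single pair of non-bonded atoms, the penalty can be pictured as a hinge on their distance. The sketch below assumes an AlphaFold-style term, relu(r_a + r_b - tolerance - d), where r_a and r_b are the two van der Waals radii and d is the distance; the helper name pair_clash_penalty is hypothetical and is not part of MindSponge.

```python
import math


def pair_clash_penalty(pos_a, pos_b, radius_a, radius_b, tolerance):
    """Hypothetical helper: hinge penalty for one pair of non-bonded atoms.

    Assumes the AlphaFold-style term relu(r_a + r_b - tolerance - d):
    zero while the atoms are at least (r_a + r_b - tolerance) apart,
    then growing linearly as they move closer.
    """
    dist = math.dist(pos_a, pos_b)  # Euclidean distance (Python 3.8+)
    lower_bound = radius_a + radius_b
    return max(0.0, lower_bound - tolerance - dist)


# Two atoms of radius 1.7 with a tolerance of 1.5 may come as close as 1.9.
print(pair_clash_penalty((0.0, 0.0, 0.0), (2.0, 0.0, 0.0), 1.7, 1.7, 1.5))  # 0.0
print(pair_clash_penalty((0.0, 0.0, 0.0), (1.0, 0.0, 0.0), 1.7, 1.7, 1.5))  # roughly 0.9
```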
- Parameters
atom14_pred_positions (Tensor) – Predicted positions of atoms in the global prediction frame. Shape is \((N_{res}, 14, 3)\).
atom14_atom_exists (Tensor) – Mask denoting whether the atom at each atom14 position exists for the given amino acid type. Shape is \((N_{res}, 14)\).
atom14_atom_radius (Tensor) – Van der Waals radius of each atom. Shape is \((N_{res}, 14)\).
residue_index (Tensor) – Residue index of each amino acid, ranging from 1 to \(N_{res}\). Shape is \((N_{res}, )\).
c_one_hot (Tensor) – One-hot encoding marking the C atom in the atom14 representation. Shape is \((14, )\).
n_one_hot (Tensor) – One-hot encoding marking the N atom in the atom14 representation. Shape is \((14, )\).
overlap_tolerance_soft (float) – Soft tolerance factor. Default: 12.0.
overlap_tolerance_hard (float) – Hard tolerance factor. Default: 1.5.
cys_sg_idx (Tensor) – Index of the cysteine SG atom in the atom14 representation. Default: 5. See mindsponge.common.residue_constants for more details. Shape: ().
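The index-type parameters (residue_index, c_one_hot, n_one_hot, cys_sg_idx) decide which atom pairs are scored at all. The sketch below assumes an AlphaFold-style masking rule: each residue pair is counted once (residue_index[i] < residue_index[j]), the C(i)-N(i+1) peptide bond between consecutive residues is skipped, and pairs where both atoms sit in slot cys_sg_idx are skipped as possible disulfide bridges. The helper pair_is_scored is hypothetical, and it does not check that both residues are actually cysteines.

```python
def pair_is_scored(res_idx_i, res_idx_j, atom_a, atom_b,
                   exists_a, exists_b, c_one_hot, n_one_hot, cys_sg_idx):
    """Hypothetical helper: should atom `atom_a` of residue i and atom `atom_b`
    of residue j contribute to the clash loss? (AlphaFold-style masking.)"""
    if not (exists_a and exists_b):
        return False  # empty slot in the atom14 layout
    if res_idx_i >= res_idx_j:
        return False  # count each residue pair once; never within one residue
    if res_idx_i + 1 == res_idx_j and c_one_hot[atom_a] and n_one_hot[atom_b]:
        return False  # C(i)-N(i+1) peptide bond is a real covalent bond
    if atom_a == cys_sg_idx and atom_b == cys_sg_idx:
        return False  # SG-SG pair: possible disulfide bridge
    return True


# In atom14, N is slot 0 and C is slot 2.
C_ONE_HOT = [0, 0, 1] + [0] * 11
N_ONE_HOT = [1] + [0] * 13
print(pair_is_scored(3, 4, 2, 0, True, True, C_ONE_HOT, N_ONE_HOT, 5))  # False
print(pair_is_scored(3, 5, 2, 0, True, True, C_ONE_HOT, N_ONE_HOT, 5))  # True
```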
- Returns
Tensor, mean_loss, the average clash loss. Shape is ().
Tensor, per_atom_loss_sum, the sum of all clash losses for each atom. Shape is \((N_{res}, 14)\).
Tensor, per_atom_clash_mask, mask indicating whether each atom clashes with any other atom. Shape is \((N_{res}, 14)\).
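How the three return values relate can be illustrated with a slow, loop-based pure-Python sketch. It assumes an AlphaFold-style formulation: every scored pair contributes relu(r_a + r_b - overlap_tolerance_soft - d); mean_loss divides the total by the number of scored pairs (plus 1e-6); each pair's loss is added to both atoms in per_atom_loss_sum; and an atom is flagged in per_atom_clash_mask when any of its pairs is closer than r_a + r_b - overlap_tolerance_hard. The name between_residue_clash_sketch is hypothetical, and this is not the MindSponge implementation, which operates on batched tensors.

```python
import math


def between_residue_clash_sketch(positions, exists, radius, residue_index,
                                 c_one_hot, n_one_hot,
                                 tol_soft, tol_hard, cys_sg_idx):
    """Hypothetical pure-Python sketch (AlphaFold-style assumptions).

    positions: [N_res][14] of (x, y, z); exists, radius: [N_res][14].
    Returns (mean_loss, per_atom_loss_sum, per_atom_clash_mask).
    """
    n_res = len(positions)
    per_atom_loss_sum = [[0.0] * 14 for _ in range(n_res)]
    per_atom_clash_mask = [[0.0] * 14 for _ in range(n_res)]
    total_loss = 0.0
    num_pairs = 0.0
    for i in range(n_res):
        for j in range(n_res):
            if residue_index[i] >= residue_index[j]:
                continue  # each residue pair once, never within a residue
            consecutive = residue_index[i] + 1 == residue_index[j]
            for a in range(14):
                for b in range(14):
                    if not (exists[i][a] and exists[j][b]):
                        continue
                    if consecutive and c_one_hot[a] and n_one_hot[b]:
                        continue  # peptide bond C(i)-N(i+1)
                    if a == cys_sg_idx and b == cys_sg_idx:
                        continue  # possible disulfide bridge
                    dist = math.dist(positions[i][a], positions[j][b])
                    lower = radius[i][a] + radius[j][b]
                    loss = max(0.0, lower - tol_soft - dist)
                    total_loss += loss
                    num_pairs += 1.0
                    # Each pair's loss is credited to both atoms involved.
                    per_atom_loss_sum[i][a] += loss
                    per_atom_loss_sum[j][b] += loss
                    if dist < lower - tol_hard:
                        per_atom_clash_mask[i][a] = 1.0
                        per_atom_clash_mask[j][b] = 1.0
    mean_loss = total_loss / (1e-6 + num_pairs)
    return mean_loss, per_atom_loss_sum, per_atom_clash_mask


# Two residues whose CA atoms (slot 1) are only 1.0 apart.
C_ONE_HOT = [0, 0, 1] + [0] * 11
N_ONE_HOT = [1] + [0] * 13
pos = [[(0.0, 0.0, 0.0)] * 14 for _ in range(2)]
pos[1][1] = (1.0, 0.0, 0.0)
ex = [[0] * 14 for _ in range(2)]
ex[0][1] = ex[1][1] = 1
rad = [[1.7] * 14 for _ in range(2)]
mean_loss, loss_sum, clash_mask = between_residue_clash_sketch(
    pos, ex, rad, [1, 2], C_ONE_HOT, N_ONE_HOT, 1.5, 1.5, 5)
print(mean_loss, loss_sum[0][1], clash_mask[0][1])
```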
- Symbol:
\(N_{res}\), number of amino acids.
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> import numpy as np
>>> from mindsponge.metrics import between_residue_clash
>>> # Random inputs for a chain of 50 residues in the atom14 layout
>>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
>>> atom14_atom_exists = Tensor(np.random.randint(2, size=(50, 14)))
>>> atom14_atom_radius = Tensor(np.random.random(size=(50, 14)), ms.float32)
>>> residue_index = Tensor(np.array(range(50)), ms.int32)
>>> # In atom14, slot 2 is the C atom and slot 0 is the N atom
>>> c_one_hot = Tensor(np.array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
>>> n_one_hot = Tensor(np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), ms.int32)
>>> overlap_tolerance_soft = 12.0
>>> overlap_tolerance_hard = 1.5
>>> cys_sg_idx = Tensor(5, ms.int32)
>>> mean_loss, per_atom_loss_sum, per_atom_clash_mask = between_residue_clash(atom14_pred_positions,
...                                                                           atom14_atom_exists,
...                                                                           atom14_atom_radius,
...                                                                           residue_index,
...                                                                           c_one_hot,
...                                                                           n_one_hot,
...                                                                           overlap_tolerance_soft,
...                                                                           overlap_tolerance_hard,
...                                                                           cys_sg_idx)
>>> print(mean_loss.shape, per_atom_loss_sum.shape, per_atom_clash_mask.shape)
() (50, 14) (50, 14)