mindsponge.metrics.within_residue_violations
- mindsponge.metrics.within_residue_violations(atom14_pred_positions, atom14_atom_exists, atom14_dists_lower_bound, atom14_dists_upper_bound, tighten_bounds_for_loss, dists_mask_i)
Loss that penalizes steric clashes within residues. It penalizes any steric violation or clash between non-bonded atoms in a given peptide, i.e. any intra-residue atom distance that falls outside the allowed lower and upper bounds.
- Parameters
atom14_pred_positions (Tensor) – predicted positions of atoms in the global prediction frame, shape \((N_{res}, 14, 3)\).
atom14_atom_exists (Tensor) – mask denoting whether the atom at each position exists for the given amino acid type, shape \((N_{res}, 14)\).
atom14_dists_lower_bound (Tensor) – lower bound on allowed distances, shape \((N_{res}, 14, 14)\).
atom14_dists_upper_bound (Tensor) – upper bound on allowed distances, shape \((N_{res}, 14, 14)\).
tighten_bounds_for_loss (float) – Extra factor by which the distance bounds are tightened for the loss. Default: 0.0.
dists_mask_i (Tensor) – initial distance mask, shape \((14, 14)\).
- Returns
per_atom_loss_sum (Tensor) - sum of all clash losses per atom, shape \((N_{res}, 14)\).
per_atom_violations (Tensor) - violation per atom, shape \((N_{res}, 14)\).
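The two outputs can be written out per residue as follows. This is a sketch that assumes the function follows the within-residue violation term of AlphaFold2, which this page does not confirm. Here \(d_{ij}\) is the predicted distance between atom slots \(i\) and \(j\) of one residue, \(l_{ij}\) and \(u_{ij}\) are the entries of atom14_dists_lower_bound and atom14_dists_upper_bound, \(\tau\) is tighten_bounds_for_loss, \(e_i\) is atom14_atom_exists, and \(m_{ij} = (1 - M_{ij})\, e_i e_j\) with \(M\) = dists_mask_i (assumed to be the identity, so an atom is never paired with itself). Under this assumption, only the loss uses the tightened bounds; the violation flags use the raw bounds.

\[
\begin{aligned}
\ell_{ij} &= m_{ij}\left[\max(0,\ l_{ij} + \tau - d_{ij}) + \max(0,\ d_{ij} - (u_{ij} - \tau))\right] \\
\text{per\_atom\_loss\_sum}_i &= \sum_j \ell_{ji} + \sum_j \ell_{ij} \\
v_{ij} &= m_{ij}\,\mathbb{1}\left[d_{ij} < l_{ij} \ \lor\ d_{ij} > u_{ij}\right] \\
\text{per\_atom\_violations}_i &= \max\left(\max_j v_{ji},\ \max_j v_{ij}\right)
\end{aligned}
\]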
- Symbol:
\(N_{res}\), number of amino acids.
- Supported Platforms:
Ascend
GPU
Examples
>>> import mindspore as ms
>>> from mindspore import Tensor
>>> import numpy as np
>>> from mindsponge.metrics import within_residue_violations
>>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
>>> atom14_atom_exists = Tensor(np.random.random(size=(50, 14)), ms.float32)
>>> atom14_dists_lower_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
>>> atom14_dists_upper_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
>>> tighten_bounds_for_loss = 0.0
>>> dists_mask_i = Tensor(np.eye(14, 14), ms.int32)
>>> per_atom_loss_sum, per_atom_violations = within_residue_violations(atom14_pred_positions,
...                                                                     atom14_atom_exists,
...                                                                     atom14_dists_lower_bound,
...                                                                     atom14_dists_upper_bound,
...                                                                     tighten_bounds_for_loss,
...                                                                     dists_mask_i)
>>> print(per_atom_loss_sum.shape, per_atom_violations.shape)
(50, 14) (50, 14)
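To inspect values off-device, the NumPy sketch below re-implements the computation. It is written on the assumption that the function mirrors AlphaFold2's within-residue violation term, which this page does not confirm. It also assumes that dists_mask_i marks atom pairs to exclude, as the identity matrix in the example above would do for self-pairs. The name within_residue_violations_np is hypothetical and is not part of MindSPONGE.

```python
import numpy as np


def within_residue_violations_np(pred_positions, atom_exists, lower_bound,
                                 upper_bound, tighten_bounds_for_loss=0.0,
                                 dists_mask_i=None):
    """Hypothetical NumPy sketch of the within-residue violation loss.

    Assumption (AlphaFold2-style): pairs flagged in dists_mask_i are ignored,
    and only pairs where both atoms exist contribute.
    """
    if dists_mask_i is None:
        dists_mask_i = np.eye(14)
    # Pair mask (N_res, 14, 14): both atoms exist and the pair is not excluded.
    mask = (1.0 - dists_mask_i[None]) * (
        atom_exists[:, :, None] * atom_exists[:, None, :])
    # Pairwise distances inside each residue; the epsilon guards sqrt at zero.
    diff = pred_positions[:, :, None, :] - pred_positions[:, None, :, :]
    dists = np.sqrt(1e-10 + np.sum(diff ** 2, axis=-1))
    # Flat-bottom penalty: zero inside [lower + tau, upper - tau], linear outside.
    too_close = np.maximum(0.0, lower_bound + tighten_bounds_for_loss - dists)
    too_far = np.maximum(0.0, dists - (upper_bound - tighten_bounds_for_loss))
    loss = mask * (too_close + too_far)
    # Each atom accumulates the loss of every ordered pair it appears in.
    per_atom_loss_sum = np.sum(loss, axis=1) + np.sum(loss, axis=2)
    # Violation flags use the untightened bounds.
    violations = mask * ((dists < lower_bound) | (dists > upper_bound))
    per_atom_violations = np.maximum(np.max(violations, axis=1),
                                     np.max(violations, axis=2))
    return per_atom_loss_sum, per_atom_violations


# One residue with only atoms 0 and 1 present, 1.0 apart, allowed range [1.5, 2.0].
pos = np.zeros((1, 14, 3))
pos[0, 1, 0] = 1.0
exists = np.zeros((1, 14))
exists[0, :2] = 1.0
lower = np.full((1, 14, 14), 1.5)
upper = np.full((1, 14, 14), 2.0)
loss_sum, viol = within_residue_violations_np(pos, exists, lower, upper, 0.0, np.eye(14))
print(loss_sum[0, :3], viol[0, :3])
```

In this toy case, each of the two atoms receives a loss of about 1.0 (0.5 from each ordered pair) and is flagged as violating. With tighten_bounds_for_loss=0.25, the loss rises to about 1.5, while the violation flags stay the same.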