mindsponge.metrics.within_residue_violations

mindsponge.metrics.within_residue_violations(atom14_pred_positions, atom14_atom_exists, atom14_dists_lower_bound, atom14_dists_upper_bound, tighten_bounds_for_loss, dists_mask_i)

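Loss to penalize steric clashes within residues. It penalizes steric violations, such as clashes between non-bonded atoms, by scoring every pair of atoms in the same residue whose predicted distance falls outside the allowed range given by atom14_dists_lower_bound and atom14_dists_upper_bound.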
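A minimal NumPy sketch of the per-pair penalty, assuming a flat-bottom (hinge) formulation like the one used in AlphaFold's within-residue violation check: a distance inside the allowed interval costs nothing, and the cost grows linearly outside it. The helper name `bound_penalty` is hypothetical and not part of MindSPONGE.

```python
import numpy as np

def bound_penalty(dist, lower, upper, tighten=0.0):
    """Hypothetical helper: flat-bottom hinge penalty on a distance.

    Zero inside [lower + tighten, upper - tighten]; grows linearly outside it.
    """
    too_close = np.maximum(lower + tighten - dist, 0.0)  # below the lower bound
    too_far = np.maximum(dist - (upper - tighten), 0.0)  # above the upper bound
    return too_close + too_far

# Allowed interval [1.0, 2.0]; the values are illustrative only.
print(bound_penalty(np.array([0.7, 1.5, 2.4]), 1.0, 2.0))  # approximately [0.3, 0.0, 0.4]
```

A positive `tighten` shrinks the allowed interval from both ends, so distances that sit just inside the original bounds also start to be penalized.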

Parameters
  • atom14_pred_positions (Tensor) – predicted positions of atoms in the global prediction frame, shape \((N_{res}, 14, 3)\).

  • atom14_atom_exists (Tensor) – mask denoting whether the atom at each of the 14 positions exists for the given amino acid type, shape \((N_{res}, 14)\).

  • atom14_dists_lower_bound (Tensor) – lower bound on allowed distances between atom pairs, shape \((N_{res}, 14, 14)\).

  • atom14_dists_upper_bound (Tensor) – upper bound on allowed distances between atom pairs, shape \((N_{res}, 14, 14)\).

  • tighten_bounds_for_loss (float) – extra margin by which the allowed distance bounds are tightened in the loss. Default: 0.0.

  • dists_mask_i (Tensor) – initial distance mask, shape \((14, 14)\).
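The following NumPy sketch shows how the parameters above could combine into a pairwise distance matrix and a pair mask. It assumes, without confirmation from the MindSPONGE source, that the pair mask is `(1 - dists_mask_i)` multiplied by the existence of both atoms, so that passing the identity matrix (as in the Examples section) excludes self-pairs. The helper `pair_distances_and_mask` is hypothetical.

```python
import numpy as np

def pair_distances_and_mask(pred_positions, atom_exists, dists_mask_i):
    """Hypothetical helper: intra-residue pairwise distances and pair mask.

    pred_positions: (N_res, 14, 3); atom_exists: (N_res, 14); dists_mask_i: (14, 14).
    """
    # Broadcast (N_res, 14, 1, 3) - (N_res, 1, 14, 3) -> (N_res, 14, 14, 3).
    diff = pred_positions[:, :, None, :] - pred_positions[:, None, :, :]
    # Small epsilon keeps the gradient of sqrt finite for coincident atoms.
    dists = np.sqrt(1e-10 + np.sum(diff ** 2, axis=-1))
    # Assumption: dists_mask_i marks pairs to exclude (the diagonal for np.eye(14)).
    pair_mask = (1.0 - dists_mask_i[None]) * atom_exists[:, :, None] * atom_exists[:, None, :]
    return dists, pair_mask

rng = np.random.default_rng(0)
pos = rng.random((50, 14, 3))
exists = rng.integers(0, 2, size=(50, 14)).astype(np.float64)
dists, mask = pair_distances_and_mask(pos, exists, np.eye(14))
print(dists.shape, mask.shape)  # (50, 14, 14) (50, 14, 14)
```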

Returns

  • per_atom_loss_sum (Tensor) - sum of all clash losses per atom, shape \((N_{res}, 14)\) .

  • per_atom_violations (Tensor) - per-atom violation indicator, i.e. whether the atom takes part in any distance-bound violation, shape \((N_{res}, 14)\).
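The two outputs are per-atom reductions of per-pair quantities. A minimal NumPy sketch, assuming an AlphaFold-style reduction (sum over both pair axes for the loss, max over both axes for the violation flag); the helper `per_atom_reduction` and the toy inputs are hypothetical.

```python
import numpy as np

def per_atom_reduction(pair_loss, pair_violation):
    """Hypothetical helper: reduce (N_res, 14, 14) pair data to (N_res, 14).

    pair_loss: masked per-pair penalty; pair_violation: masked 0/1 violation flags.
    """
    # Summing over axis 1 credits loss[i, j] to atom j; over axis 2, to atom i.
    per_atom_loss_sum = pair_loss.sum(axis=1) + pair_loss.sum(axis=2)
    # An atom is flagged if it takes part in at least one violating pair.
    per_atom_violations = np.maximum(pair_violation.max(axis=1), pair_violation.max(axis=2))
    return per_atom_loss_sum, per_atom_violations

# One residue in which only the pair (atom 0, atom 1) violates a bound.
pair_loss = np.zeros((1, 14, 14))
pair_loss[0, 0, 1] = pair_loss[0, 1, 0] = 0.25
pair_violation = (pair_loss > 0).astype(np.float64)
loss_sum, viol = per_atom_reduction(pair_loss, pair_violation)
print(loss_sum[0, :3], viol[0, :3])  # atoms 0 and 1: loss 0.5 and flag 1; atom 2: 0 and 0
```

Because the pair matrix is symmetric, both the (i, j) and (j, i) entries reach each atom, so a single violating pair adds its penalty twice to each of its two atoms.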

Symbol:

\(N_{res}\), number of amino acids.

Supported Platforms:

Ascend GPU

Examples

>>> import mindspore as ms
>>> from mindspore import Tensor
>>> import numpy as np
>>> from mindsponge.metrics import within_residue_violations
>>> atom14_pred_positions = Tensor(np.random.random(size=(50, 14, 3)), ms.float32)
>>> atom14_atom_exists = Tensor(np.random.randint(0, 2, size=(50, 14)), ms.float32)
>>> atom14_dists_lower_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
>>> atom14_dists_upper_bound = Tensor(np.random.random(size=(50, 14, 14)), ms.float32)
>>> tighten_bounds_for_loss = 0.0
>>> dists_mask_i = Tensor(np.eye(14, 14), ms.int32)
>>> per_atom_loss_sum, per_atom_violations = within_residue_violations(atom14_pred_positions,
...                                                                   atom14_atom_exists,
...                                                                   atom14_dists_lower_bound,
...                                                                   atom14_dists_upper_bound,
...                                                                   tighten_bounds_for_loss,
...                                                                   dists_mask_i)
>>> print(per_atom_loss_sum.shape, per_atom_violations.shape)
(50, 14) (50, 14)