Source code for mindsponge.common.utils

# Copyright 2021 The AIMM Group at Shenzhen Bay Laboratory & Peking University & Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""utils module"""

import numpy as np
from Bio import Align
from Bio.Align import substitution_matrices
import mindspore as ms
from mindspore import nn, ops, Tensor
import mindspore.numpy as mnp
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import lazy_inline, get_context
from . import geometry
from . import residue_constants, protein


def null_decorator(func):
    return func


if get_context("device_target") == "Ascend":
    cus_lazy_inline = lazy_inline
else:
    cus_lazy_inline = null_decorator

STACK_NAME = [
    'msa_act.msa_row_attention_with_pair_bias.query_norm_gammas',
    'msa_act.msa_row_attention_with_pair_bias.query_norm_betas',
    'msa_act.msa_row_attention_with_pair_bias.feat_2d_norm_gammas',
    'msa_act.msa_row_attention_with_pair_bias.feat_2d_norm_betas',
    'msa_act.msa_row_attention_with_pair_bias.feat_2d_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.linear_q_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.linear_k_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.linear_v_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.linear_output_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.o_biases',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.linear_gating_weights',
    'msa_act.msa_row_attention_with_pair_bias.attn_mod.gating_biases',
    'msa_act.msa_transition.input_layer_norm_gammas',
    'msa_act.msa_transition.input_layer_norm_betas',
    'msa_act.msa_transition.transition1_weights',
    'msa_act.msa_transition.transition1_biases',
    'msa_act.msa_transition.transition2_weights',
    'msa_act.msa_transition.transition2_biases',
    'msa_act.attn_mod.query_norm_gammas',
    'msa_act.attn_mod.query_norm_betas',
    'msa_act.attn_mod.attn_mod.linear_q_weights',
    'msa_act.attn_mod.attn_mod.linear_k_weights',
    'msa_act.attn_mod.attn_mod.linear_v_weights',
    'msa_act.attn_mod.attn_mod.linear_output_weights',
    'msa_act.attn_mod.attn_mod.o_biases',
    'msa_act.attn_mod.attn_mod.linear_gating_weights',
    'msa_act.attn_mod.attn_mod.gating_biases',
    'pair_act.outer_product_mean.layer_norm_input_gammas',
    'pair_act.outer_product_mean.layer_norm_input_betas',
    'pair_act.outer_product_mean.left_projection_weights',
    'pair_act.outer_product_mean.left_projection_biases',
    'pair_act.outer_product_mean.right_projection_weights',
    'pair_act.outer_product_mean.right_projection_biases',
    'pair_act.outer_product_mean.linear_output_weights',
    'pair_act.outer_product_mean.o_biases',
    'pair_act.triangle_attention_starting_node.query_norm_gammas',
    'pair_act.triangle_attention_starting_node.query_norm_betas',
    'pair_act.triangle_attention_starting_node.feat_2d_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.linear_q_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.linear_k_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.linear_v_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.linear_output_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.o_biases',
    'pair_act.triangle_attention_starting_node.attn_mod.linear_gating_weights',
    'pair_act.triangle_attention_starting_node.attn_mod.gating_biases',
    'pair_act.triangle_attention_ending_node.query_norm_gammas',
    'pair_act.triangle_attention_ending_node.query_norm_betas',
    'pair_act.triangle_attention_ending_node.feat_2d_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.linear_q_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.linear_k_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.linear_v_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.linear_output_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.o_biases',
    'pair_act.triangle_attention_ending_node.attn_mod.linear_gating_weights',
    'pair_act.triangle_attention_ending_node.attn_mod.gating_biases',
    'pair_act.pair_transition.input_layer_norm_gammas',
    'pair_act.pair_transition.input_layer_norm_betas',
    'pair_act.pair_transition.transition1_weights',
    'pair_act.pair_transition.transition1_biases',
    'pair_act.pair_transition.transition2_weights',
    'pair_act.pair_transition.transition2_biases',
    'pair_act.triangle_multiplication_outgoing.layer_norm_input_gammas',
    'pair_act.triangle_multiplication_outgoing.layer_norm_input_betas',
    'pair_act.triangle_multiplication_outgoing.left_projection_weights',
    'pair_act.triangle_multiplication_outgoing.left_projection_biases',
    'pair_act.triangle_multiplication_outgoing.right_projection_weights',
    'pair_act.triangle_multiplication_outgoing.right_projection_biases',
    'pair_act.triangle_multiplication_outgoing.left_gate_weights',
    'pair_act.triangle_multiplication_outgoing.left_gate_biases',
    'pair_act.triangle_multiplication_outgoing.right_gate_weights',
    'pair_act.triangle_multiplication_outgoing.right_gate_biases',
    'pair_act.triangle_multiplication_outgoing.center_layer_norm_gammas',
    'pair_act.triangle_multiplication_outgoing.center_layer_norm_betas',
    'pair_act.triangle_multiplication_outgoing.output_projection_weights',
    'pair_act.triangle_multiplication_outgoing.output_projection_biases',
    'pair_act.triangle_multiplication_outgoing.gating_linear_weights',
    'pair_act.triangle_multiplication_outgoing.gating_linear_biases',
    'pair_act.triangle_multiplication_incoming.layer_norm_input_gammas',
    'pair_act.triangle_multiplication_incoming.layer_norm_input_betas',
    'pair_act.triangle_multiplication_incoming.left_projection_weights',
    'pair_act.triangle_multiplication_incoming.left_projection_biases',
    'pair_act.triangle_multiplication_incoming.right_projection_weights',
    'pair_act.triangle_multiplication_incoming.right_projection_biases',
    'pair_act.triangle_multiplication_incoming.left_gate_weights',
    'pair_act.triangle_multiplication_incoming.left_gate_biases',
    'pair_act.triangle_multiplication_incoming.right_gate_weights',
    'pair_act.triangle_multiplication_incoming.right_gate_biases',
    'pair_act.triangle_multiplication_incoming.center_layer_norm_gammas',
    'pair_act.triangle_multiplication_incoming.center_layer_norm_betas',
    'pair_act.triangle_multiplication_incoming.output_projection_weights',
    'pair_act.triangle_multiplication_incoming.output_projection_biases',
    'pair_act.triangle_multiplication_incoming.gating_linear_weights',
    'pair_act.triangle_multiplication_incoming.gating_linear_biases']

def get_predict_checkpoint(train_ckpt, msa_layers, predict_ckpt):
    """convert megafold checkpoint: from training checkpoint to predict checkpoint.

    Args:
        train_ckpt(str): Path of the training checkpoint.
        msa_layers(int): Number of Msa stack layers.
        predict_ckpt(str): Save path of the predict checkpoint.
    """
    parameters = ms.load_checkpoint(train_ckpt)
    resave_parameters = []
    for key, value in parameters.items():
        if key.replace("msa_stack.", "") in STACK_NAME:
            continue
        resave_parameters.append({"name": key, "data": value})

    for name in STACK_NAME:
        predict_value = []
        for i in range(msa_layers):
            predict_name = f"msa_stack.{i}." + name
            predict_value_part = parameters[predict_name].asnumpy()
            predict_value.append(predict_value_part)
        predict_value = Tensor(np.array(predict_value))
        resave_parameters.append({"name": "msa_stack." + name, "data": predict_value})
    ms.save_checkpoint(resave_parameters, predict_ckpt)

def get_train_checkpoint(train_ckpt, msa_layers, predict_ckpt):
    """convert megafold checkpoint: from predict checkpoint to training checkpoint.

    Args:
        train_ckpt(str): Save path of the training checkpoint.
        msa_layers(int): Number of Msa stack layers.
        predict_ckpt(str): Path of the training checkpoint.
    """
    predict_name_list = []
    for name in STACK_NAME:
        for i in range(msa_layers):
            predict_name = f"msa_stack.{i}." + name
            predict_name_list.append(predict_name)
    parameters = ms.load_checkpoint(predict_ckpt)
    resave_parameters = []
    for key, value in parameters.items():
        if key in predict_name_list:
            continue
        resave_parameters.append({"name": key, "data": value})

    for name in STACK_NAME:
        save_value = parameters["msa_stack." + name]
        for i in range(msa_layers):
            train_name = f"msa_stack.{i}." + name
            train_value = save_value[i]
            resave_parameters.append({"name": train_name, "data": train_value})
    ms.save_checkpoint(resave_parameters, train_ckpt)

def _memory_reduce(body, batched_inputs, nonbatched_inputs, slice_num, dim=0):
    """memory reduce function"""
    if slice_num <= 1:
        inputs = batched_inputs + nonbatched_inputs
        return body(*inputs)
    inner_batched_inputs = []
    for val in batched_inputs:
        inner_val = P.Split(dim, slice_num)(val)
        inner_batched_inputs.append(inner_val)
    # for depend
    inner_split_batched_inputs = ()
    for j in range(len(inner_batched_inputs)):
        inner_split_batched_inputs = inner_split_batched_inputs + (inner_batched_inputs[j][0],)
    inner_split_inputs = inner_split_batched_inputs + nonbatched_inputs
    inner_split_res = body(*inner_split_inputs)
    res = (inner_split_res,)
    for i in range(1, slice_num):
        inner_split_batched_inputs = ()
        for j in range(len(inner_batched_inputs)):
            inner_split_batched_inputs = inner_split_batched_inputs + (inner_batched_inputs[j][i],)
        inner_split_inputs = inner_split_batched_inputs + nonbatched_inputs
        inner_split_inputs = F.depend(inner_split_inputs, res[-1])
        inner_split_res = body(*inner_split_inputs)
        res = res + (inner_split_res,)
    res = P.Concat()(res)
    return res


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    """Create pseudo beta features."""

    is_gly = mnp.equal(aatype, residue_constants.restype_order['G'])
    ca_idx = residue_constants.atom_order['CA']
    cb_idx = residue_constants.atom_order['CB']
    pseudo_beta = mnp.where(
        mnp.tile(is_gly[..., None], [1,] * len(is_gly.shape) + [3,]),
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :])
    if all_atom_masks is not None:
        pseudo_beta_mask = mnp.where(is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
        pseudo_beta_mask = pseudo_beta_mask.astype(mnp.float32)
        return pseudo_beta, pseudo_beta_mask
    return pseudo_beta


def dgram_from_positions(positions, num_bins, min_bin, max_bin, ret_type):
    """Compute distogram from amino acid positions.

    Arguments:
    positions: [N_res, 3] Position coordinates.
    num_bins: The number of bins in the distogram.
    min_bin: The left edge of the first bin.
    max_bin: The left edge of the final bin. The final bin catches
    everything larger than `max_bin`.

    Returns:
    Distogram with the specified number of bins.
    """

    def squared_difference(x, y):
        return mnp.square(x - y)

    lower_breaks = ops.linspace(min_bin, max_bin, num_bins)
    lower_breaks = mnp.square(lower_breaks)
    upper_breaks = mnp.concatenate([lower_breaks[1:], mnp.array([1e8], dtype=mnp.float32)], axis=-1)
    dist2 = mnp.sum(squared_difference(mnp.expand_dims(positions, axis=-2),
                                       mnp.expand_dims(positions, axis=-3)), axis=-1, keepdims=True)
    dgram = ((dist2 > lower_breaks).astype(ret_type) * (dist2 < upper_breaks).astype(ret_type))
    return dgram


def atom37_to_torsion_angles(
        aatype,  # (B, N)
        all_atom_pos,  # (B, N, 37, 3)
        all_atom_mask,  # (B, N, 37)
        chi_atom_indices,
        chi_angles_mask,
        mirror_psi_mask,
        chi_pi_periodic,
        indices0,
        indices1
):
    """Computes the 7 torsion angles (in sin, cos encoding) for each residue.

    The 7 torsion angles are in the order
    '[pre_omega, phi, psi, chi_1, chi_2, chi_3, chi_4]',
    here pre_omega denotes the omega torsion angle between the given amino acid
    and the previous amino acid.

    Args:
    aatype: Amino acid type, given as array with integers.
    all_atom_pos: atom37 representation of all atom coordinates.
    all_atom_mask: atom37 representation of mask on all atom coordinates.
    placeholder_for_undefined: flag denoting whether to set masked torsion
    angles to zero.
    Returns:
    Dict containing:
    * 'torsion_angles_sin_cos': Array with shape (B, N, 7, 2) where the final
    2 dimensions denote sin and cos respectively
    * 'alt_torsion_angles_sin_cos': same as 'torsion_angles_sin_cos', but
    with the angle shifted by pi for all chi angles affected by the naming
    ambiguities.
    * 'torsion_angles_mask': Mask for which chi angles are present.
    """

    # Map aatype > 20 to 'Unknown' (20).
    aatype = mnp.minimum(aatype, 20)

    # Compute the backbone angles.
    num_batch, num_res = aatype.shape

    pad = mnp.zeros([num_batch, 1, 37, 3], mnp.float32)
    prev_all_atom_pos = mnp.concatenate([pad, all_atom_pos[:, :-1, :, :]], axis=1)

    pad = mnp.zeros([num_batch, 1, 37], mnp.float32)
    prev_all_atom_mask = mnp.concatenate([pad, all_atom_mask[:, :-1, :]], axis=1)

    # For each torsion angle collect the 4 atom positions that define this angle.
    # shape (B, N, atoms=4, xyz=3)
    pre_omega_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 1:3, :], all_atom_pos[:, :, 0:2, :]], axis=-2)
    phi_atom_pos = mnp.concatenate([prev_all_atom_pos[:, :, 2:3, :], all_atom_pos[:, :, 0:3, :]], axis=-2)
    psi_atom_pos = mnp.concatenate([all_atom_pos[:, :, 0:3, :], all_atom_pos[:, :, 4:5, :]], axis=-2)
    # # Collect the masks from these atoms.
    # # Shape [batch, num_res]
    # ERROR NO PROD
    pre_omega_mask = (P.ReduceProd()(prev_all_atom_mask[:, :, 1:3], -1)  # prev CA, C
                      * P.ReduceProd()(all_atom_mask[:, :, 0:2], -1))  # this N, CA
    phi_mask = (prev_all_atom_mask[:, :, 2]  # prev C
                * P.ReduceProd()(all_atom_mask[:, :, 0:3], -1))  # this N, CA, C
    psi_mask = (P.ReduceProd()(all_atom_mask[:, :, 0:3], -1) *  # this N, CA, C
                all_atom_mask[:, :, 4])  # this O
    # Collect the atoms for the chi-angles.
    # Compute the table of chi angle indices. Shape: [restypes, chis=4, atoms=4].
    # Select atoms to compute chis. Shape: [batch, num_res, chis=4, atoms=4].
    atom_indices = mnp.take(chi_atom_indices, aatype, axis=0)

    # # Gather atom positions Batch Gather. Shape: [batch, num_res, chis=4, atoms=4, xyz=3].

    # 4 seq_length 4 4  batch, sequence length, chis, atoms
    seq_length = all_atom_pos.shape[1]
    atom_indices = atom_indices.reshape((4, seq_length, 4, 4, 1)).astype("int32")
    new_indices = P.Concat(4)((indices0, indices1, atom_indices))  # 4, seq_length, 4, 4, 3
    chis_atom_pos = P.GatherNd()(all_atom_pos, new_indices)
    chis_mask = mnp.take(chi_angles_mask, aatype, axis=0)
    chi_angle_atoms_mask = P.GatherNd()(all_atom_mask, new_indices)

    # Check if all 4 chi angle atoms were set. Shape: [batch, num_res, chis=4].
    chi_angle_atoms_mask = P.ReduceProd()(chi_angle_atoms_mask, -1)
    chis_mask = chis_mask * (chi_angle_atoms_mask).astype(mnp.float32)

    # Stack all torsion angle atom positions.
    # Shape (B, N, torsions=7, atoms=4, xyz=3)ls
    torsions_atom_pos = mnp.concatenate([pre_omega_atom_pos[:, :, None, :, :],
                                         phi_atom_pos[:, :, None, :, :],
                                         psi_atom_pos[:, :, None, :, :],
                                         chis_atom_pos], axis=2)
    # Stack up masks for all torsion angles.
    # shape (B, N, torsions=7)
    torsion_angles_mask = mnp.concatenate([pre_omega_mask[:, :, None],
                                           phi_mask[:, :, None],
                                           psi_mask[:, :, None],
                                           chis_mask], axis=2)

    torsion_rigid = geometry.rigids_from_3_points(
        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 1, :]),
        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 2, :]),
        geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 0, :]))
    inv_torsion_rigid = geometry.invert_rigids(torsion_rigid)
    forth_atom_rel_pos = geometry.rigids_mul_vecs(inv_torsion_rigid,
                                                  geometry.vecs_from_tensor(torsions_atom_pos[:, :, :, 3, :]))
    # Compute the position of the forth atom in this frame (y and z coordinate
    torsion_angles_sin_cos = mnp.stack([forth_atom_rel_pos[2], forth_atom_rel_pos[1]], axis=-1)
    torsion_angles_sin_cos /= mnp.sqrt(mnp.sum(mnp.square(torsion_angles_sin_cos), axis=-1, keepdims=True) + 1e-8)
    # Mirror psi, because we computed it from the Oxygen-atom.
    torsion_angles_sin_cos *= mirror_psi_mask
    chi_is_ambiguous = mnp.take(chi_pi_periodic, aatype, axis=0)
    mirror_torsion_angles = mnp.concatenate([mnp.ones([num_batch, num_res, 3]), 1.0 - 2.0 * chi_is_ambiguous], axis=-1)
    alt_torsion_angles_sin_cos = (torsion_angles_sin_cos * mirror_torsion_angles[:, :, :, None])
    return torsion_angles_sin_cos, alt_torsion_angles_sin_cos, torsion_angles_mask


def rigids_from_tensor4x4(m):
    """Construct Rigids object from an 4x4 array.

    Here the 4x4 is representing the transformation in homogeneous coordinates.

    Args:
    m: Array representing transformations in homogeneous coordinates.
    Returns:
    Rigids object corresponding to transformations m
    """
    rotation = (m[..., 0, 0], m[..., 0, 1], m[..., 0, 2],
                m[..., 1, 0], m[..., 1, 1], m[..., 1, 2],
                m[..., 2, 0], m[..., 2, 1], m[..., 2, 2])
    trans = (m[..., 0, 3], m[..., 1, 3], m[..., 2, 3])
    rigid = (rotation, trans)
    return rigid


def frames_and_literature_positions_to_atom14_pos(aatype, all_frames_to_global, restype_atom14_to_rigid_group,
                                                  restype_atom14_rigid_group_positions, restype_atom14_mask):  # (N, 14)
    """Put atom literature positions (atom14 encoding) in each rigid group.

    Jumper et al. (2021) Suppl. Alg. 24 "computeAllAtomCoordinates" line 11

    Args:
    aatype: aatype for each residue.
    all_frames_to_global: All per residue coordinate frames.
    Returns:
    Positions of all atom coordinates in global frame.
    """

    # Pick the appropriate transform for every atom.
    residx_to_group_idx = P.Gather()(restype_atom14_to_rigid_group, aatype, 0)
    group_mask = nn.OneHot(depth=8, axis=-1)(residx_to_group_idx)

    # Rigids with shape (N, 14)
    map_atoms_to_global = map_atoms_to_global_func(all_frames_to_global, group_mask)

    # Gather the literature atom positions for each residue.
    # Vecs with shape (N, 14)
    lit_positions = geometry.vecs_from_tensor(P.Gather()(restype_atom14_rigid_group_positions, aatype, 0))

    # Transform each atom from its local frame to the global frame.
    # Vecs with shape (N, 14)
    pred_positions = geometry.rigids_mul_vecs(map_atoms_to_global, lit_positions)

    # Mask out non-existing atoms.
    mask = P.Gather()(restype_atom14_mask, aatype, 0)

    pred_positions = geometry.vecs_scale(pred_positions, mask)

    return pred_positions


def rigids_concate_all(xall, x5, x6, x7):
    """rigids concate all."""
    x5 = (geometry.rots_expand_dims(x5[0], -1), geometry.vecs_expand_dims(x5[1], -1))
    x6 = (geometry.rots_expand_dims(x6[0], -1), geometry.vecs_expand_dims(x6[1], -1))
    x7 = (geometry.rots_expand_dims(x7[0], -1), geometry.vecs_expand_dims(x7[1], -1))
    xall_rot = xall[0]
    xall_rot_slice = []
    for val in xall_rot:
        xall_rot_slice.append(val[:, 0:5])
    xall_trans = xall[1]
    xall_trans_slice = []
    for val in xall_trans:
        xall_trans_slice.append(val[:, 0:5])
    xall = (xall_rot_slice, xall_trans_slice)
    res_rot = []
    for i in range(9):
        res_rot.append(mnp.concatenate((xall[0][i], x5[0][i], x6[0][i], x7[0][i]), axis=-1))
    res_trans = []
    for i in range(3):
        res_trans.append(mnp.concatenate((xall[1][i], x5[1][i], x6[1][i], x7[1][i]), axis=-1))
    return (res_rot, res_trans)


def torsion_angles_to_frames(aatype, backb_to_global, torsion_angles_sin_cos, restype_rigid_group_default_frame):
    """Compute rigid group frames from torsion angles."""

    # Gather the default frames for all rigid groups.
    m = P.Gather()(restype_rigid_group_default_frame, aatype, 0)

    default_frames = rigids_from_tensor4x4(m)

    # Create the rotation matrices according to the given angles (each frame is
    # defined such that its rotation is around the x-axis).
    sin_angles = torsion_angles_sin_cos[..., 0]
    cos_angles = torsion_angles_sin_cos[..., 1]

    # insert zero rotation for backbone group.
    num_residues, = aatype.shape
    sin_angles = mnp.concatenate([mnp.zeros([num_residues, 1]), sin_angles], axis=-1)
    cos_angles = mnp.concatenate([mnp.ones([num_residues, 1]), cos_angles], axis=-1)
    zeros = mnp.zeros_like(sin_angles)
    ones = mnp.ones_like(sin_angles)

    all_rots = (ones, zeros, zeros,
                zeros, cos_angles, -sin_angles,
                zeros, sin_angles, cos_angles)

    # Apply rotations to the frames.
    all_frames = geometry.rigids_mul_rots(default_frames, all_rots)
    # chi2, chi3, and chi4 frames do not transform to the backbone frame but to
    # the previous frame. So chain them up accordingly.
    chi2_frame_to_frame = ((all_frames[0][0][:, 5], all_frames[0][1][:, 5], all_frames[0][2][:, 5],
                            all_frames[0][3][:, 5], all_frames[0][4][:, 5], all_frames[0][5][:, 5],
                            all_frames[0][6][:, 5], all_frames[0][7][:, 5], all_frames[0][8][:, 5]),
                           (all_frames[1][0][:, 5], all_frames[1][1][:, 5], all_frames[1][2][:, 5]))
    chi3_frame_to_frame = ((all_frames[0][0][:, 6], all_frames[0][1][:, 6], all_frames[0][2][:, 6],
                            all_frames[0][3][:, 6], all_frames[0][4][:, 6], all_frames[0][5][:, 6],
                            all_frames[0][6][:, 6], all_frames[0][7][:, 6], all_frames[0][8][:, 6]),
                           (all_frames[1][0][:, 6], all_frames[1][1][:, 6], all_frames[1][2][:, 6]))

    chi4_frame_to_frame = ((all_frames[0][0][:, 7], all_frames[0][1][:, 7], all_frames[0][2][:, 7],
                            all_frames[0][3][:, 7], all_frames[0][4][:, 7], all_frames[0][5][:, 7],
                            all_frames[0][6][:, 7], all_frames[0][7][:, 7], all_frames[0][8][:, 7]),
                           (all_frames[1][0][:, 7], all_frames[1][1][:, 7], all_frames[1][2][:, 7]))

    chi1_frame_to_backb = ((all_frames[0][0][:, 4], all_frames[0][1][:, 4], all_frames[0][2][:, 4],
                            all_frames[0][3][:, 4], all_frames[0][4][:, 4], all_frames[0][5][:, 4],
                            all_frames[0][6][:, 4], all_frames[0][7][:, 4], all_frames[0][8][:, 4]),
                           (all_frames[1][0][:, 4], all_frames[1][1][:, 4], all_frames[1][2][:, 4]))

    chi2_frame_to_backb = geometry.rigids_mul_rigids(chi1_frame_to_backb, chi2_frame_to_frame)
    chi3_frame_to_backb = geometry.rigids_mul_rigids(chi2_frame_to_backb, chi3_frame_to_frame)
    chi4_frame_to_backb = geometry.rigids_mul_rigids(chi3_frame_to_backb, chi4_frame_to_frame)

    # Recombine them to a Rigids with shape (N, 8).
    all_frames_to_backb = rigids_concate_all(all_frames, chi2_frame_to_backb,
                                             chi3_frame_to_backb, chi4_frame_to_backb)

    backb_to_global = (geometry.rots_expand_dims(backb_to_global[0], -1),
                       geometry.vecs_expand_dims(backb_to_global[1], -1))
    # Create the global frames.
    all_frames_to_global = geometry.rigids_mul_rigids(backb_to_global, all_frames_to_backb)
    return all_frames_to_global


def map_atoms_to_global_func(all_frames, group_mask):
    """map atoms to global."""
    all_frames_rot = all_frames[0]
    all_frames_trans = all_frames[1]
    rot = geometry.rots_scale(geometry.rots_expand_dims(all_frames_rot, 1), group_mask)
    res_rot = []
    for val in rot:
        res_rot.append(mnp.sum(val, axis=-1))
    trans = geometry.vecs_scale(geometry.vecs_expand_dims(all_frames_trans, 1), group_mask)
    res_trans = []
    for val in trans:
        res_trans.append(mnp.sum(val, axis=-1))
    return (res_rot, res_trans)


def atom14_to_atom37(atom14_data, residx_atom37_to_atom14, atom37_atom_exists, indices0):
    """Convert atom14 to atom37 representation."""

    seq_length = atom14_data.shape[0]
    residx_atom37_to_atom14 = residx_atom37_to_atom14.reshape((seq_length, 37, 1))
    new_indices = P.Concat(2)((indices0, residx_atom37_to_atom14))

    atom37_data = P.GatherNd()(atom14_data, new_indices)

    if len(atom14_data.shape) == 2:
        atom37_data *= atom37_atom_exists
    elif len(atom14_data.shape) == 3:
        atom37_data *= atom37_atom_exists[:, :, None].astype(atom37_data.dtype)

    return atom37_data


[docs]def make_atom14_positions(aatype, all_atom_mask, all_atom_positions): """ The function of transforming sparse encoding method to densely encoding method. Total coordinate encoding for atoms in proteins comes in two forms. - Sparse encoding, 20 amino acids contain a total of 37 atom types as shown in `common.residue_constants.atom_types`. So coordinates of atoms in protein can be encoded as a Tensor with shape :math:`(N_{res}, 37, 3)`. - Densely encoding. 20 amino acids contain a total of 14 atom types as shown in `common.residue_constants.restype_name_to_atom14_names`. So coordinates of atoms in protein can be encoded as a Tensor with shape :math:`(N_{res}, 14, 3)`. Args: aatype(numpy.ndarray): Protein sequence encoding. the encoding method refers to `common.residue_constants.restype_order`. Value range is :math:`[0,20]`. 20 means the amino acid is unknown (`UNK`). all_atom_mask(numpy.ndarray): Mask of coordinates of all atoms in proteins. Shape is :math:`(N_{res}, 37)`. If the corresponding position is 0, the amino acid does not contain the atom. all_atom_positions(numpy.ndarray): Coordinates of all atoms in protein. Shape is :math:`(N_{res}, 37, 3)` . Returns: - numpy.array. Densely encoding, mask of all atoms in protein, including unknown amino acid atoms. Shape is :math:`(N_{res}, 14)`. - numpy.array. Densely encoding, mask of all atoms in protein, excluding unknown amino acid atoms. Shape is :math:`(N_{res}, 14)`. - numpy.array. Densely encoding, coordinates of all atoms in protein. Shape is :math:`(N_{res}, 14, 3)`. - numpy.array. Index of mapping sparse encoding atoms with densely encoding method. Shape is :math:`(N_{res}, 14)` . - numpy.array. Index of mapping densely encoding atoms with sparse encoding method. Shape is :math:`(N_{res}, 37)` . - numpy.array. Sparse encoding, mask of all atoms in protein, including unknown amino acid atoms. Shape is :math:`(N_{res}, 37)` - numpy.array. The atomic coordinates after chiral transformation for the atomic coordinates of densely encoding method. Shape is :math:`(N_{res}, 14, 3)` . - numpy.array. Atom mask after chiral transformation. Shape is :math:`(N_{res}, 14)` . - numpy.array. Atom identifier of the chiral transformation. 1 is transformed and 0 is not transformed. Shape is :math:`(N_{res}, 14)` . Symbol: - :math:`N_{res}` - The number of amino acids in a protein, according to the sequence of the protein. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> from mindsponge.common import make_atom14_positions >>> from mindsponge.common import protein >>> import numpy as np >>> pdb_path = "YOUR_PDB_FILE" >>> with open(pdb_path, 'r', encoding = 'UTF-8') as f: >>> prot_pdb = protein.from_pdb_string(f.read()) >>> result = make_atom14_positions(prot_pdb.aatype, prot_pdb.atom_mask.astype(np.float32), >>> prot_pdb.atom_positions.astype(np.float32)) >>> for val in result: >>> print(val.shape) (Nres, 14) (Nres, 14) (Nres, 14, 3) (Nres, 14) (Nres, 37) (Nres, 37) (Nres, 14, 3) (Nres, 14) (Nres, 14) """ restype_atom14_to_atom37 = [] # mapping (restype, atom14) --> atom37 restype_atom37_to_atom14 = [] # mapping (restype, atom37) --> atom14 restype_atom14_mask = [] for rt in residue_constants.restypes: atom_names = residue_constants.restype_name_to_atom14_names[ residue_constants.restype_1to3[rt]] restype_atom14_to_atom37.append([ (residue_constants.atom_order[name] if name else 0) for name in atom_names ]) atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} restype_atom37_to_atom14.append([ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) for name in residue_constants.atom_types ]) restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) # Add dummy mapping for restype 'UNK'. restype_atom14_to_atom37.append([0] * 14) restype_atom37_to_atom14.append([0] * 37) restype_atom14_mask.append([0.] * 14) restype_atom14_to_atom37 = np.array(restype_atom14_to_atom37, dtype=np.int32) restype_atom37_to_atom14 = np.array(restype_atom37_to_atom14, dtype=np.int32) restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) # Create the mapping for (residx, atom14) --> atom37, i.e. an array # with shape (num_res, 14) containing the atom37 indices for this protein. residx_atom14_to_atom37 = restype_atom14_to_atom37[aatype] residx_atom14_mask = restype_atom14_mask[aatype] # Create a mask for known ground truth positions. residx_atom14_gt_mask = residx_atom14_mask * np.take_along_axis( all_atom_mask, residx_atom14_to_atom37, axis=1).astype(np.float32) # Gather the ground truth positions. residx_atom14_gt_positions = residx_atom14_gt_mask[:, :, None] * ( np.take_along_axis(all_atom_positions, residx_atom14_to_atom37[..., None], axis=1)) atom14_atom_exists = residx_atom14_mask atom14_gt_exists = residx_atom14_gt_mask atom14_gt_positions = residx_atom14_gt_positions residx_atom14_to_atom37 = residx_atom14_to_atom37 # Create the gather indices for mapping back. residx_atom37_to_atom14 = restype_atom37_to_atom14[aatype] # Create the corresponding mask. restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) for restype, restype_letter in enumerate(residue_constants.restypes): restype_name = residue_constants.restype_1to3[restype_letter] atom_names = residue_constants.residue_atoms[restype_name] for atom_name in atom_names: atom_type = residue_constants.atom_order[atom_name] restype_atom37_mask[restype, atom_type] = 1 atom37_atom_exists = restype_atom37_mask[aatype] # As the atom naming is ambiguous for 7 of the 20 amino acids, provide # alternative ground truth coordinates where the naming is swapped restype_3 = [ residue_constants.restype_1to3[res] for res in residue_constants.restypes ] restype_3 += ["UNK"] # Matrices for renaming ambiguous atoms. all_matrices = {res: np.eye(14, dtype=np.float32) for res in restype_3} for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): correspondences = np.arange(14) for source_atom_swap, target_atom_swap in swap.items(): source_index = residue_constants.restype_name_to_atom14_names.get(resname).index(source_atom_swap) target_index = residue_constants.restype_name_to_atom14_names.get(resname).index(target_atom_swap) correspondences[source_index] = target_index correspondences[target_index] = source_index renaming_matrix = np.zeros((14, 14), dtype=np.float32) for index, correspondence in enumerate(correspondences): renaming_matrix[index, correspondence] = 1. all_matrices[resname] = renaming_matrix.astype(np.float32) renaming_matrices = np.stack([all_matrices[restype] for restype in restype_3]) # Pick the transformation matrices for the given residue sequence # shape (num_res, 14, 14). renaming_transform = renaming_matrices[aatype] # Apply it to the ground truth positions. shape (num_res, 14, 3). alternative_gt_positions = np.einsum("rac,rab->rbc", residx_atom14_gt_positions, renaming_transform) atom14_alt_gt_positions = alternative_gt_positions # Create the mask for the alternative ground truth (differs from the # ground truth mask, if only one of the atoms in an ambiguous pair has a # ground truth position). alternative_gt_mask = np.einsum("ra,rab->rb", residx_atom14_gt_mask, renaming_transform) atom14_alt_gt_exists = alternative_gt_mask # Create an ambiguous atoms mask. shape: (21, 14). restype_atom14_is_ambiguous = np.zeros((21, 14), dtype=np.float32) for resname, swap in residue_constants.residue_atom_renaming_swaps.items(): for atom_name1, atom_name2 in swap.items(): restype = residue_constants.restype_order[ residue_constants.restype_3to1[resname]] atom_idx1 = residue_constants.restype_name_to_atom14_names.get(resname).index(atom_name1) atom_idx2 = residue_constants.restype_name_to_atom14_names.get(resname).index(atom_name2) restype_atom14_is_ambiguous[restype, atom_idx1] = 1 restype_atom14_is_ambiguous[restype, atom_idx2] = 1 # From this create an ambiguous_mask for the given sequence. atom14_atom_is_ambiguous = restype_atom14_is_ambiguous[aatype] return_pack = (atom14_atom_exists, atom14_gt_exists, atom14_gt_positions, residx_atom14_to_atom37, residx_atom37_to_atom14, atom37_atom_exists, atom14_alt_gt_positions, atom14_alt_gt_exists, atom14_atom_is_ambiguous) return return_pack
[docs]def get_pdb_info(pdb_path): """ get atom positions, residue index etc. info from pdb file. Args: pdb_path(str): the path of the input pdb. Returns: features(dict), the information of pdb, including these keys - aatype, numpy.array. Protein sequence encoding. Encoding method refers to `common.residue_constants_restype_order`, :math:`[0,20]` . 20 means the amino acid is `UNK`. Shape :math:`(N_{res}, )` . - all_atom_positions, numpy.array. Coordinates of all residues in pdb. Shape :math:`(N_{res}, 37)` . - all_atom_mask, numpy.array. Mask of atoms in pdb. Shape :math:`(N_{res}, 37)` . 0 means the atom inexistence. - atom14_atom_exists, numpy.array. Densely encoding, mask of all atoms in protein. The position with atoms is 1 and the position without atoms is 0. Shape is :math:`(N_{res}, 14)`. - atom14_gt_exists, numpy.array. Densely encoding, mask of all atoms in protein. Keep the same as `atom14_atom_exist`. Shape is :math:`(N_{res}, 14)`. - atom14_gt_positions, numpy.array. Densely encoding, coordinates of all atoms in the protein. Shape is :math:`(N_{res}, 14, 3)`. - residx_atom14_to_atom37, numpy.array. Index of mapping sparse encoding atoms with densely encoding method. Shape is :math:`(N_{res}, 14)` . - residx_atom37_to_atom14, numpy.array. Index of mapping densely encoding atoms with sparse encoding method. Shape is :math:`(N_{res}, 37)` . - atom37_atom_exists, numpy.array. Sparse encoding, mask of all atoms in protein. The position with atoms is 1 and the position without atoms is 0. Shape is :math:`(N_{res}, 37)`. - atom14_alt_gt_positions, numpy.array. Densely encoding, coordinates of all atoms in chiral proteins. Shape is :math:`(N_{res}, 14, 3)` . - atom14_alt_gt_exists, numpy.array. Densely encoding, mask of all atoms in chiral proteins. Shape is :math:`(N_{res}, 14)` . - atom14_atom_is_ambiguous, numpy.array. Because of the local symmetry of some amino acid structures, the symmetric atomic codes can be transposed. Specific atoms can be found in `common.residue_atom_renaming_swaps`. This feature records the uncertain atom encoding positions. Shape is :math:`(N_{res}, 14)` . - residue_index, numpy.array. Residue index information of protein sequence, ranging from 1 to :math:`N_{res}` . Shape is :math:`(N_{res}, )` . Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> from mindsponge.common import get_pdb_info >>> pdb_path = "YOUR PDB PATH" >>> pdb_feature = get_pdb_info(pdb_path) >>> for feature in pdb_feature: >>> print(feature, pdb_feature[feature]) # Nres represents the Amino acid num of the input pdb. aatype (Nres,) all_atom_positions (Nres, 37, 3) all_atom_mask (Nres, 37) atom14_atom_exists (Nres, 14) atom14_gt_exists (Nres, 14) atom14_gt_positions (Nres, 14, 3) residx_atom14_to_atom37 (Nres, 14) residx_atom37_to_atom14 (Nres, 37) atom37_atom_exists (Nres, 37) atom14_alt_gt_positions (Nres, 14, 3) atom14_alt_gt_exists (Nres, 14) atom14_atom_is_ambiguous (Nres, 14) residue_index (Nres, ) """ with open(pdb_path, 'r', encoding="UTF-8") as f: prot_pdb = protein.from_pdb_string(f.read()) aatype = prot_pdb.aatype atom37_positions = prot_pdb.atom_positions.astype(np.float32) atom37_mask = prot_pdb.atom_mask.astype(np.float32) # get ground truth of atom14 features = {'aatype': aatype, 'all_atom_positions': atom37_positions, 'all_atom_mask': atom37_mask} atom14_atom_exists, atom14_gt_exists, atom14_gt_positions, residx_atom14_to_atom37, residx_atom37_to_atom14, \ atom37_atom_exists, atom14_alt_gt_positions, atom14_alt_gt_exists, atom14_atom_is_ambiguous = \ make_atom14_positions(aatype, atom37_mask, atom37_positions) features.update({"atom14_atom_exists": atom14_atom_exists, "atom14_gt_exists": atom14_gt_exists, "atom14_gt_positions": atom14_gt_positions, "residx_atom14_to_atom37": residx_atom14_to_atom37, "residx_atom37_to_atom14": residx_atom37_to_atom14, "atom37_atom_exists": atom37_atom_exists, "atom14_alt_gt_positions": atom14_alt_gt_positions, "atom14_alt_gt_exists": atom14_alt_gt_exists, "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous}) features["residue_index"] = prot_pdb.residue_index return features
[docs]def get_fasta_info(pdb_path): """ Put in a pdb file and get fasta information from it. Return the sequence of the pdb. Args: pdb_path(str): path of the input pdb. Returns: fasta(str), fasta of input pdb. The sequence is the order of residues in the protein and has no relationship with residue index, such as "GSHMGVQ". Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> from mindsponge.common import get_fasta_info >>> pdb_path = "YOUR PDB PATH" >>> fasta = get_fasta_info(pdb_path) >>> print(fasta) "GSHMGVQ" """ with open(pdb_path, 'r', encoding='UTF-8') as f: prot_pdb = protein.from_pdb_string(f.read()) aatype = prot_pdb.aatype fasta = [residue_constants.order_restype_with_x.get(x, "X") for x in aatype] return ''.join(fasta)
[docs]def get_aligned_seq(gt_seq, pr_seq): """ Align two protein fasta sequence. Return two aligned sequences and the position of same residues. Args: gt_seq(str): one protein fasta sequence, such as "ABAAABAA". pr_seq(str): another protein fasta sequence, such as "A-AABBBA". Returns: - target(str), one protein fasta sequence. - align_relationship(str), the differences of the two sequences. - query(str), another protein fasta sequence. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> from mindsponge.common import get_aligned_seq >>> gt_seq = "ABAAABAA" >>> pr_seq = "AAABBBA" >>> aligned_gt_seq, aligned_info, aligned_pr_seq = get_aligned_seq(gt_seq, pr_seq) >>> print(aligned_gt_seq) ABAAABAA >>> print(aligned_info) |-||.|.| >>> print(aligned_pr_seq) A-AABBBA """ aligner = Align.PairwiseAligner() substitution_matrices.load() matrix = substitution_matrices.load("BLOSUM62") for i in range(len(str(matrix.alphabet))): res = matrix.alphabet[i] matrix['X'][res] = 0 matrix[res]['X'] = 0 aligner.substitution_matrix = matrix aligner.open_gap_score = -10 aligner.extend_gap_score = -1 # many align results, get only the one w/ highest score. gt_seq as reference alignments = aligner.align(gt_seq, pr_seq) align = alignments[0] align_str = str(align) align_str_len = len(align_str) point = [] target = '' align_relationship = '' query = '' for i in range(align_str_len): if align_str[i] == '\n': point.append(i) for i in range(int(point[0])): target = target + align_str[i] for i in range(int(point[1])-int(point[0])-1): align_relationship = align_relationship + align_str[i + int(point[0])+1] for i in range(int(point[2])-int(point[1])-1): query = query + align_str[i + int(point[1])+1] return target, align_relationship, query
[docs]def find_optimal_renaming( atom14_gt_positions, atom14_alt_gt_positions, atom14_atom_is_ambiguous, atom14_gt_exists, atom14_pred_positions, ): # (N): """ Find optimal renaming for ground truth that maximizes LDDT. Reference: `Jumper et al. (2021) Suppl. Alg. 26 "renameSymmetricGroundTruthAtoms" <https://www.nature.com/articles/s41586-021-03819-2>`_ Args: atom14_gt_positions (Tensor): Ground truth positions in global frame with shape :math:`(N_{res}, 14, 3)`. atom14_alt_gt_positions (Tensor): Alternate ground truth positions in global frame with coordinates of ambiguous atoms swapped relative to 'atom14_gt_positions'. The shape is :math:`(N_{res}, 14, 3)`. atom14_atom_is_ambiguous (Tensor): Mask denoting whether atom is among ambiguous atoms, see Jumper et al. (2021) Suppl. Table 3. The shape is :math:`(N_{res}, 14)`. atom14_gt_exists (Tensor): Mask denoting whether atom at positions exists in ground truth with shape :math:`(N_{res}, 14)`. atom14_pred_positions(Tensor): Predicted positions of atoms in global prediction frame with shape :math:`(N_{res}, 14, 3)`. Returns: Tensor, :math:`(N_{res},)` with 1.0 where atom14_alt_gt_positions is closer to prediction and otherwise 0. Supported Platforms: ``Ascend`` ``GPU`` Examples: >>> import numpy as np >>> from mindsponge.common.utils import find_optimal_renaming >>> from mindspore import Tensor >>> n_res = 16 >>> atom14_gt_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32)) >>> atom14_alt_gt_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32)) >>> atom14_atom_is_ambiguous = Tensor(np.random.randn(n_res, 14).astype(np.float32)) >>> atom14_gt_exists = Tensor(np.random.randn(n_res, 14).astype(np.float32)) >>> atom14_pred_positions = Tensor(np.random.randn(n_res, 14, 3).astype(np.float32)) >>> out = find_optimal_renaming(atom14_gt_positions, atom14_alt_gt_positions, ... atom14_atom_is_ambiguous, atom14_gt_exists, atom14_pred_positions) >>> print(out.shape) (16,) """ # Create the pred distance matrix. atom14_pred_positions = P.Pad(((0, 0), (0, 0), (0, 5)))(atom14_pred_positions) pred_dists = mnp.sqrt(1e-10 + mnp.sum( mnp.square(atom14_pred_positions[:, None, :, None, :] - atom14_pred_positions[None, :, None, :, :]), axis=-1)) # Compute distances for ground truth with original and alternative names. gt_dists = mnp.sqrt(1e-10 + mnp.sum( mnp.square(atom14_gt_positions[:, None, :, None, :] - atom14_gt_positions[None, :, None, :, :]), axis=-1)) alt_gt_dists = mnp.sqrt(1e-10 + mnp.sum( mnp.square(atom14_alt_gt_positions[:, None, :, None, :] - atom14_alt_gt_positions[None, :, None, :, :]), axis=-1)) # Compute LDDT's. lddt = mnp.sqrt(1e-10 + mnp.square(pred_dists - gt_dists)) alt_lddt = mnp.sqrt(1e-10 + mnp.square(pred_dists - alt_gt_dists)) # Create a mask for ambiguous atoms in rows vs. non-ambiguous atoms # in cols. mask = (atom14_gt_exists[:, None, :, None] * # rows atom14_atom_is_ambiguous[:, None, :, None] * # rows atom14_gt_exists[None, :, None, :] * # cols (1. - atom14_atom_is_ambiguous[None, :, None, :])) # cols # Aggregate distances for each residue to the non-amibuguous atoms. per_res_lddt = P.ReduceSum()(mask * lddt, (1, 2, 3)) alt_per_res_lddt = P.ReduceSum()(mask * alt_lddt, (1, 2, 3)) # Decide for each residue, whether alternative naming is better. alt_naming_is_better = (alt_per_res_lddt < per_res_lddt) return alt_naming_is_better