# Copyright 2021-2023 @ Shenzhen Bay Laboratory &
# Peking University &
# Huawei Technologies Co., Ltd
#
# This code is a part of MindSPONGE:
# MindSpore Simulation Package tOwards Next Generation molecular modelling.
#
# MindSPONGE is open-source software based on the AI-framework:
# MindSpore (https://www.mindspore.cn/)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Neighbour list
"""
from inspect import signature
from distutils.version import LooseVersion
from typing import Tuple
import mindspore as ms
import mindspore.numpy as msnp
from mindspore import Tensor
from mindspore import Parameter
from mindspore import ops, nn
from mindspore.ops import functional as F
from mindspore.nn import Cell
from . import FullConnectNeighbours, DistanceNeighbours, GridNeighbours
from ..system import Molecule
from ..function.functions import gather_vector, get_integer, get_ms_array
from ..function.operations import GetVector
class NeighbourList(Cell):
    r"""Neighbour list

    Wraps one of three neighbour-search back ends and keeps the latest result
    in non-trainable parameters:

    - `FullConnectNeighbours` when no cutoff is available (no `cutoff` given
      and the system has no PBC box);
    - `GridNeighbours` when a cutoff is available and grids are requested
      (or selected automatically, see `use_grids`);
    - `DistanceNeighbours` otherwise.

    Args:
        system (Molecule): Simulation system.
        cutoff (float, optional): Cut-off distance. If ``None`` is given under periodic
            boundary condition (PBC), the cutoff will be assigned
            with the default value of 1 nm. If ``None`` is given without PBC,
            a fully connected neighbour list is used and `pace` is reset to 0.
            Default: ``None``.
        pace (int, optional): Update frequency for neighbour list. Must not be negative.
            Default: ``20``
        exclude_index (Tensor, optional): Tensor of the indices of the neighbouring atoms
            which could be excluded from the neighbour list. The shape
            of Tensor is :math:`(B, A, Ex)`, and the data type is int.
            Default: ``None``.
        num_neighbours (int, optional): Maximum number of neighbours.
            If ``None`` is given, this value will be calculated
            by the ratio of the number of neighbouring grids to the
            total number of grids. Default: ``None``.
        num_cell_cut (int, optional): Number of subdivision of grid cells according to cutoff.
            Only passed to `GridNeighbours`. Default: ``1``
        cutoff_scale (float, optional): Factor to scale cutoff distance. Default: ``1.2``
        cell_cap_scale (float, optional): Scale factor for `cell_capacity`.
            Only passed to `GridNeighbours`. Default: ``1.25``
        grid_num_scale (float, optional): Scale factor to calculate `num_neighbours` by ratio of grids.
            If `num_neighbours` is not ``None``, it will not be used. Default: ``2``
        use_grids (bool, optional): Whether to use grids to calculate the neighbour list.
            Only relevant when a cutoff is available.
            If ``True``, `GridNeighbours` is always used (a warning is printed when its
            neighbour capacity is not less than the number of atoms).
            If ``None``, `GridNeighbours` is tried first and `DistanceNeighbours` is used
            instead when the grid neighbour capacity is not less than the number of atoms.
            If ``False``, `DistanceNeighbours` is used.
            Default: ``False``.
        cast_fp16 (bool, optional): If this is set to ``True``, the data will be cast to float16 before sort.
            For use with some devices that only support sorting of float16 data.
            It is forced to ``True`` when the MindSpore version is lower than 2.1.0,
            and is only passed to `DistanceNeighbours`.
            Default: ``False``.

    Raises:
        ValueError: If `pace` is less than 0.

    Note:
        - B:  Batchsize, i.e. number of walkers of the simulation.
        - A:  Number of the atoms in the simulation system.
        - N:  Number of the maximum neighbouring atoms.
        - D:  Dimension of position coordinates.
        - Ex: Maximum number of excluded neighbour atoms.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Examples:
        >>> import sponge
        >>> from sponge.partition import NeighbourList
        >>> from sponge.system import Molecule
        >>> system = Molecule(template='water.spce.yaml')
        >>> neighbourlist = NeighbourList(system, 0.5)
        >>> neighbourlist(system.coordinate, system.pbc_box)
        (Tensor(shape=[1, 3, 2], dtype=Int64, value=
         [[[1, 2],
           [0, 2],
           [0, 1]]]),
         Tensor(shape=[1, 3, 2, 3], dtype=Float32, value=
         [[[[ 8.16490427e-02,  5.77358976e-02,  0.00000000e+00],
            [-8.16490427e-02,  5.77358976e-02,  0.00000000e+00]],
           [[-8.16490427e-02, -5.77358976e-02,  0.00000000e+00],
            [-1.63298085e-01,  0.00000000e+00,  0.00000000e+00]],
           [[ 8.16490427e-02, -5.77358976e-02,  0.00000000e+00],
            [ 1.63298085e-01,  0.00000000e+00,  0.00000000e+00]]]]),
         Tensor(shape=[1, 3, 2], dtype=Float32, value=
         [[[ 1.00000001e-01,  1.00000001e-01],
           [ 1.00000001e-01,  1.63298085e-01],
           [ 1.00000001e-01,  1.63298085e-01]]]),
         Tensor(shape=[1, 3, 2], dtype=Bool, value=
         [[[ True,  True],
           [ True,  True],
           [ True,  True]]]))
        >>> neighbourlist.calculate(system.coordinate, system.pbc_box)
        (Tensor(shape=[1, 3, 2], dtype=Int64, value=
         [[[1, 2],
           [0, 2],
           [0, 1]]]),
         Tensor(shape=[1, 3, 2], dtype=Bool, value=
         [[[ True,  True],
           [ True,  True],
           [ True,  True]]]))
        >>> neighbourlist.pace
        20
    """

    def __init__(self,
                 system: Molecule,
                 cutoff: float = None,
                 pace: int = 20,
                 exclude_index: Tensor = None,
                 num_neighbours: int = None,
                 num_cell_cut: int = 1,
                 cutoff_scale: float = 1.2,
                 cell_cap_scale: float = 1.25,
                 grid_num_scale: float = 2,
                 use_grids: bool = False,
                 cast_fp16: bool = False,
                 ):

        super().__init__()

        # Older MindSpore releases always sort in float16.
        # NOTE(review): `distutils` (source of `LooseVersion`) is removed in
        # Python 3.12 (PEP 632); the module-level import needs a replacement
        # before this file can run there.
        if LooseVersion(ms.__version__) < LooseVersion('2.1.0'):
            cast_fp16 = True

        self.num_walker = system.num_walker
        self.coordinate = system.get_coordinate()
        # The last two axes of the coordinate are (atoms, dimension).
        self.num_atoms = self.coordinate.shape[-2]
        self.dim = self.coordinate.shape[-1]

        self.pbc_box = system.get_pbc_box()
        use_pbc = self.pbc_box is not None

        self.atom_mask = system.atom_mask
        self.exclude_index = exclude_index
        if exclude_index is not None:
            self.exclude_index = Tensor(exclude_index, ms.int32)

        # Tri-state switch: True (force grids), None (choose automatically),
        # False (never use grids). It is resolved to a plain bool below
        # whenever a cutoff is available.
        self.use_grids = use_grids

        self._pace = get_integer(pace)
        if self._pace < 0:
            raise ValueError('pace cannot be less than 0!')

        # Under PBC a cutoff is mandatory, so fall back to 1 nm
        # expressed in the length unit of the system.
        if cutoff is None and use_pbc:
            cutoff = system.units.length(1, 'nm')

        self.no_mask = False
        if cutoff is None:
            # No cutoff at all: every atom neighbours every other atom.
            # The list never needs refreshing, hence the pace of 0.
            self.cutoff = None
            self.large_dis = Tensor(1e4, ms.float32)
            self._pace = 0
            self.neighbour_list = FullConnectNeighbours(self.num_atoms)
            if self.exclude_index is None:
                self.no_mask = True
        else:
            self.cutoff = get_ms_array(cutoff, ms.float32)
            # Offset added to masked-out neighbour vectors in `construct`.
            self.large_dis = self.cutoff * 100

            if self.use_grids or self.use_grids is None:
                self.neighbour_list = GridNeighbours(
                    cutoff=self.cutoff,
                    coordinate=self.coordinate,
                    pbc_box=self.pbc_box,
                    atom_mask=self.atom_mask,
                    exclude_index=self.exclude_index,
                    num_neighbours=num_neighbours,
                    num_cell_cut=num_cell_cut,
                    cutoff_scale=cutoff_scale,
                    cell_cap_scale=cell_cap_scale,
                    grid_num_scale=grid_num_scale,
                )
                if self.neighbour_list.neigh_capacity >= self.num_atoms:
                    if self.use_grids is True:
                        # Grids were explicitly requested: keep them but warn.
                        print(f'[WARNING] The number of neighbour atoms in `GridNeighbours` '
                              f'({self.neighbour_list.neigh_capacity}) is not less than '
                              f'the number of atoms ({self.num_atoms}). '
                              f'It would be more efficient to use `DistanceNeighbours` '
                              f'(set `use_grids` to `False` or `None`).')
                    else:
                        # Automatic mode: grids bring no benefit, fall back
                        # to the distance-based search below.
                        self.use_grids = False
                else:
                    self.use_grids = True

            if not self.use_grids:
                self.neighbour_list = DistanceNeighbours(
                    cutoff=self.cutoff,
                    num_neighbours=num_neighbours,
                    atom_mask=self.atom_mask,
                    exclude_index=self.exclude_index,
                    use_pbc=use_pbc,
                    cutoff_scale=cutoff_scale,
                    large_dis=self.large_dis,
                    cast_fp16=cast_fp16,
                )
                if num_neighbours is None:
                    self.neighbour_list.set_num_neighbours(self.coordinate, self.pbc_box)

        self.num_neighbours = self.neighbour_list.num_neighbours

        # Compute the initial list once and keep it in non-trainable
        # parameters so that `update` can overwrite it in place.
        index, mask = self.calculate(self.coordinate, self.pbc_box)

        self.neighbours = None
        self.neighbour_mask = None
        if index is not None:
            self.neighbours = Parameter(index, name='neighbours', requires_grad=False)
        if mask is not None:
            self.neighbour_mask = Parameter(mask, name='neighbour_mask', requires_grad=False)

        self.get_vector = GetVector(use_pbc)
        self.identity = ops.Identity()

        self.norm_last_dim = None
        # MindSpore < 2.0.0-rc1: `ops.norm` has no `ord` argument,
        # so use the `nn.Norm` cell over the last axis instead.
        if 'ord' not in signature(ops.norm).parameters:
            self.norm_last_dim = nn.Norm(-1, False)

    @property
    def pace(self) -> int:
        r"""Update frequency for neighbour list

        Returns:
            int, update pace

        """
        return self._pace

    def set_exclude_index(self, exclude_index: Tensor):
        r"""set exclude index

        Passes the new exclude index to the underlying neighbour search,
        then recalculates and stores the neighbour list for the coordinate
        and PBC box held by this object.

        Args:
            exclude_index (Tensor): Tensor of shape :math:`(B, A, Ex)`. Data type is int.
                If ``None`` is given, nothing is changed.

        Returns:
            NeighbourList, the object itself.

        """
        if exclude_index is None:
            return self

        self.exclude_index = self.neighbour_list.set_exclude_index(exclude_index)

        # `update` recalculates the list and already writes the result into
        # `self.neighbours` (and into `self.neighbour_mask` when it exists).
        index, mask = self.update(self.coordinate, self.pbc_box)

        # `update` returns (None, None) when no neighbour index is stored;
        # there is nothing to refresh in that case.
        if index is None:
            return self

        # The only state `update` cannot handle is a mask that appears for the
        # first time: the parameter has to be created here.
        if self.neighbour_mask is None and mask is not None:
            self.neighbour_mask = Parameter(mask, name='neighbour_mask', requires_grad=False)

        return self

    def print_info(self):
        r"""print information of neighbour list

        Returns:
            NeighbourList, the object itself.

        """
        self.neighbour_list.print_info()
        return self

    def update(self, coordinate: Tensor, pbc_box: Tensor = None) -> Tuple[Tensor, Tensor]:
        r"""
        update neighbour list.

        Recalculates the neighbour list with gradients stopped on the inputs
        and writes the result into the stored parameters.

        Args:
            coordinate (Tensor): Tensor of shape :math:`(B, A, D)`. Data type is float.
                Position coordinate.
            pbc_box (Tensor, optional): Tensor of shape :math:`(B, D)`. Data type is float.
                Size of PBC box. Default: ``None``.

        Returns:
            - neigh_idx (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is int.
              Index of neighbouring atoms of each atoms in system.
            - neigh_mask (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is bool.
              Mask for neighbour list `neigh_idx`.

            Both are ``None`` if no neighbour index is stored in this object.

        Note:
            - B:  Batchsize, i.e. number of walkers of the simulation.
            - A:  Number of the atoms in the simulation system.
            - N:  Number of the maximum neighbouring atoms.
            - D:  Dimension of position coordinates.

        """
        if self.neighbours is None:
            return None, None

        coordinate = F.stop_gradient(coordinate)
        if pbc_box is not None:
            pbc_box = F.stop_gradient(pbc_box)

        neighbours, neighbour_mask = self.calculate(coordinate, pbc_box)

        # `F.depend` ties the check and the in-place assignments to the
        # returned tensors so that they are executed with them.
        neighbours = F.depend(neighbours, self.neighbour_list.check_neighbour_list())
        neighbours = F.depend(neighbours, F.assign(self.neighbours, neighbours))
        if self.neighbour_mask is not None:
            neighbour_mask = F.depend(neighbour_mask, F.assign(self.neighbour_mask, neighbour_mask))

        return neighbours, neighbour_mask

    def calculate(self, coordinate: Tensor, pbc_box: Tensor = None) -> Tuple[Tensor, Tensor]:
        r"""
        calculate neighbour list.

        Runs the selected neighbour search without modifying the stored list.

        Args:
            coordinate (Tensor): Tensor of shape :math:`(B, A, D)`. Data type is float.
                Position coordinate.
            pbc_box (Tensor, optional): Tensor of shape :math:`(B, D)`. Data type is float.
                Size of PBC box. Default: ``None``.

        Returns:
            - neigh_idx (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is int.
              Index of neighbouring atoms of each atoms in system.
            - neigh_mask (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is bool.
              Mask for neighbour list `neigh_idx`.

        Note:
            - B:  Batchsize, i.e. number of walkers of the simulation.
            - A:  Number of the atoms in the simulation system.
            - N:  Number of the maximum neighbouring atoms.
            - D:  Dimension of position coordinates.

        """
        # Without a cutoff the fully connected list is used; it is called
        # with the atom mask and exclude index instead of the coordinate.
        if self.cutoff is None:
            return self.neighbour_list(self.atom_mask, self.exclude_index)

        if self.use_grids:
            return self.neighbour_list(coordinate, pbc_box)

        # The distance-based search returns three values;
        # the first one is not needed here.
        _, index, mask = self.neighbour_list(
            coordinate, pbc_box, self.atom_mask, self.exclude_index)

        return index, mask

    def get_neighbour_list(self) -> Tuple[Tensor, Tensor]:
        r"""
        get neighbour list.

        Returns the stored neighbour list with gradients stopped.

        Returns:
            - neigh_idx (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is int.
              Index of neighbouring atoms of each atoms in system.
            - neigh_mask (Tensor), Tensor of shape :math:`(B, A, N)`. Data type is bool.
              Mask for neighbour list `neigh_idx`. ``None`` if no mask is stored.

            Both are ``None`` if no neighbour index is stored in this object.

        Note:
            - B:  Batchsize, i.e. number of walkers of the simulation.
            - A:  Number of the atoms in the simulation system.
            - N:  Number of the maximum neighbouring atoms.

        """
        if self.neighbours is None:
            return None, None

        index = F.stop_gradient(self.identity(self.neighbours))
        mask = None
        if self.neighbour_mask is not None:
            mask = F.stop_gradient(self.identity(self.neighbour_mask))

        return index, mask

    def construct(self,
                  coordinate: Tensor,
                  pbc_box: Tensor = None
                  ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # pylint: disable=missing-docstring
        # Gather coordinate of neighbours atoms.

        # Args:
        #     coordinate (Tensor):    Tensor of shape :math:`(B, A, D)`. Data type is float.
        #                             Position coordinate.
        #     pbc_box (Tensor):       Tensor of shape :math:`(B, D)`. Data type is float.
        #                             Size of PBC box.

        # Returns:
        #     neigh_idx (Tensor):     Tensor of shape :math:`(B, A, N)`. Data type is int.
        #                             Index of neighbouring atoms of each atoms in system.
        #     neigh_vec (Tensor):     Tensor of shape :math:`(B, A, N, D)`. Data type is float.
        #                             Vectors from central atom to neighbouring atoms.
        #     neigh_dis (Tensor):     Tensor of shape :math:`(B, A, N)`. Data type is float.
        #                             Distance between center atoms and neighbouring atoms.
        #     neigh_mask (Tensor):    Tensor of shape :math:`(B, A, N)`. Data type is bool.
        #                             Mask for neighbour list `neigh_idx`.
        #
        #     All four are None if no neighbour index is stored in this object.

        # Note:
        #     - B:  Batchsize, i.e. number of walkers of the simulation.
        #     - A:  Number of the atoms in the simulation system.
        #     - N:  Number of the maximum neighbouring atoms.
        #     - D:  Dimension of position coordinates.

        if self.neighbours is None:
            return None, None, None, None

        neigh_idx, neigh_mask = self.get_neighbour_list()

        # (B, A, 1, D) <- (B, A, D)
        center_pos = F.expand_dims(coordinate, -2)
        # (B, A, N, D) <- (B, A, D)
        neigh_vec = gather_vector(coordinate, neigh_idx)
        neigh_vec = self.get_vector(center_pos, neigh_vec, pbc_box)

        # Add a non-zero value to the neighbour_vector whose mask value is False
        # to prevent them from becoming zero values after Norm operation,
        # which could lead to auto-differentiation errors
        if neigh_mask is not None:
            # (B, A, N)
            large_dis = msnp.broadcast_to(self.large_dis, neigh_mask.shape)
            # Zero offset for valid neighbours, `large_dis` for masked-out ones.
            large_dis = F.select(neigh_mask, F.zeros_like(large_dis), large_dis)
            # (B, A, N, D) = (B, A, N, D) + (B, A, N, 1)
            neigh_vec += F.expand_dims(large_dis, -1)

        # (B, A, N) <- (B, A, N, D)
        if self.norm_last_dim is None:
            neigh_dis = ops.norm(neigh_vec, None, -1)
        else:
            neigh_dis = self.norm_last_dim(neigh_vec)

        return neigh_idx, neigh_vec, neigh_dis, neigh_mask