sponge.core.WithForceCell

class sponge.core.WithForceCell(system: Molecule = None, force: ForceCell = None, neighbour_list: NeighbourList = None, modifier: ForceModifier = None)

Cell that wraps the simulation system with the atomic force function.

Parameters
  • system (sponge.system.Molecule) – Simulation system.

  • force (sponge.potential.ForceCell) – Atomic force calculation cell.

  • neighbour_list (sponge.partition.NeighbourList, optional) – Neighbour list. Default: None.

  • modifier (sponge.sampling.modifier.ForceModifier, optional) – Force modifier. Default: None.

Inputs:
  • energy (Tensor) - Total potential energy of the simulation system. Tensor of shape \((B, 1)\). Here B is batch size, i.e. the number of walkers in simulation. Data type is float.

  • force (Tensor) - Force on each atom of the simulation system. Tensor of shape \((B, A, D)\). Here \(B\) is batch size, i.e. the number of walkers in simulation, \(A\) is the number of atoms, and \(D\) is the spatial dimension of the simulation system, which is usually 3. Data type is float.

  • virial (Tensor) - Virial tensor of the simulation system. Tensor of shape \((B, D)\). Data type is float.

Outputs:
  • energy (Tensor) - with shape of \((B, 1)\). Total potential energy of the simulation system. Data type is float.

  • force (Tensor) - with shape of \((B, A, D)\). Force on each atom of the simulation system. Data type is float.

  • virial (Tensor) - with shape of \((B, D)\). Virial tensor of the simulation system. Data type is float.
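
The shapes listed above can be sketched with plain nested lists standing in for tensors. This is only an illustration of the layout: the sizes and the `shape` helper are hypothetical and not part of sponge.

```python
# Illustrative only: plain nested lists stand in for MindSpore tensors.
B, A, D = 2, 5, 3  # walkers, atoms, spatial dimensions (hypothetical sizes)


def shape(nested):
    """Return the shape of a regular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


energy = [[0.0] for _ in range(B)]                         # (B, 1)
force = [[[0.0] * D for _ in range(A)] for _ in range(B)]  # (B, A, D)
virial = [[0.0] * D for _ in range(B)]                     # (B, D)

print(shape(energy), shape(force), shape(virial))  # (2, 1) (2, 5, 3) (2, 3)
```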

Supported Platforms:

Ascend GPU

Examples

>>> # You can find case2.pdb file under MindSPONGE/tutorials/basic/case2.pdb
>>> from sponge import Protein
>>> from sponge.potential.forcefield import ForceField
>>> from sponge.partition import NeighbourList
>>> from sponge.core.simulation import WithEnergyCell, WithForceCell
>>> from sponge.sampling import MaskedDriven
>>> system = Protein(pdb='case2.pdb', rebuild_hydrogen=True)
>>> energy = ForceField(system, 'AMBER.FF99SB')
>>> neighbour_list = NeighbourList(system, cutoff=None, cast_fp16=True)
>>> with_energy = WithEnergyCell(system, energy, neighbour_list=neighbour_list)
>>> modifier = MaskedDriven(length_unit=with_energy.length_unit,
...                         energy_unit=with_energy.energy_unit,
...                         mask=system.heavy_atom_mask)
>>> with_force = WithForceCell(system, neighbour_list=neighbour_list, modifier=modifier)
property cutoff: mindspore.common.tensor.Tensor

Cutoff distance for neighbour list

Returns

Tensor, cutoff

property energy_unit: str

Energy unit

Returns

str, energy unit

get_neighbour_list()

Get neighbour list

Returns

  • neigh_idx, Tensor of shape \((B, A, N)\). Index of the neighbouring atoms of each atom in the system. Here \(B\) is the number of walkers in simulation, \(A\) is the number of atoms, and \(N\) is the number of neighbouring atoms. Data type is int.

  • neigh_mask, Tensor of shape \((B, A, N)\). Mask for neighbour list neigh_idx. Data type is bool.
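
To make the index/mask pairing concrete, here is a minimal brute-force sketch in plain Python. It is not the sponge implementation: the function name `build_neighbour_list`, the choice of listing every other atom as a candidate (so \(N = A - 1\)), and masking by a distance cutoff are all assumptions made for illustration.

```python
import math


def build_neighbour_list(coords, cutoff):
    """Brute-force neighbour search for a single walker (hypothetical helper).

    coords: list of A positions, each a list of D floats.
    Returns (neigh_idx, neigh_mask), both of shape (A, N) with N = A - 1.
    """
    num_atoms = len(coords)
    neigh_idx, neigh_mask = [], []
    for i in range(num_atoms):
        # Every other atom is a candidate neighbour of atom i.
        others = [j for j in range(num_atoms) if j != i]
        neigh_idx.append(others)
        # The mask marks which candidates actually lie within the cutoff.
        neigh_mask.append([math.dist(coords[i], coords[j]) < cutoff for j in others])
    return neigh_idx, neigh_mask


coords = [[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [1.0, 0.0, 0.0]]
neigh_idx, neigh_mask = build_neighbour_list(coords, cutoff=0.5)

# Add a leading batch axis so both results have shape (B, A, N) with B = 1.
neigh_idx, neigh_mask = [neigh_idx], [neigh_mask]
print(neigh_idx)   # [[[1, 2], [0, 2], [0, 1]]]
print(neigh_mask)  # [[[True, False], [True, False], [False, False]]]
```

Entries of neigh_idx whose mask value is False are placeholders and would be ignored in any downstream calculation.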

property length_unit: str

Length unit

Returns

str, length unit

property neighbour_list_pace: int

Update step for neighbour list

Returns

int, step
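
One plausible reading of this property is an update interval measured in simulation steps. The sketch below assumes that reading and assumes the update fires whenever the step number is a multiple of the pace; `steps_with_update` is a hypothetical helper, and the actual scheduling in sponge may differ.

```python
def steps_with_update(total_steps, pace):
    """Return the step numbers at which the neighbour list would be rebuilt,
    assuming an update whenever step % pace == 0 (an assumption, not the
    documented behaviour)."""
    return [step for step in range(total_steps) if step % pace == 0]


print(steps_with_update(total_steps=50, pace=20))  # [0, 20, 40]
```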

set_pbc_grad(grad_box: bool)

Set whether to calculate the gradient of PBC box

Parameters

grad_box (bool) – Whether to calculate the gradient of PBC box.

update_modifier(step: int)

Update force modifier

Parameters

step (int) – Simulation step.

update_neighbour_list()

Update neighbour list

Parameters
  • coordinate (Tensor) – Position coordinate. Tensor of shape \((B, A, D)\). Here \(B\) is the number of walkers in simulation, \(A\) is the number of atoms, \(D\) is the spatial dimension of the simulation system, which is usually 3. Data type is float.

  • pbc_box (Tensor) – Size of PBC box. Tensor of shape \((B, D)\). Data type is float.

Returns

  • neigh_idx, Tensor of shape \((B, A, N)\). Index of the neighbouring atoms of each atom in the system. Here \(N\) is the number of neighbouring atoms. Data type is int.

  • neigh_mask, Tensor of shape \((B, A, N)\). Mask for neighbour list neigh_idx. Data type is bool.
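
When a PBC box is present, neighbour distances are normally measured between nearest periodic images. The following sketch shows the standard minimum-image wrap for one walker with an orthorhombic box of \(D\) edge lengths, matching one row of pbc_box. It illustrates the general technique only; `minimum_image` is a hypothetical helper, and whether sponge computes distances this way internally is an assumption.

```python
def minimum_image(displacement, box):
    """Wrap a displacement vector onto its nearest periodic image.

    displacement, box: lists of D floats (one walker, orthorhombic box).
    """
    return [d - length * round(d / length) for d, length in zip(displacement, box)]


box = [2.0, 2.0, 2.0]
# A separation of 1.5 along x is closer through the periodic boundary (-0.5).
print(minimum_image([1.5, -0.4, 0.0], box))  # [-0.5, -0.4, 0.0]
```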