sciai.common.Sampler
- class sciai.common.Sampler(dim, coords, func, name=None)[源代码]
常用的数据采样器。
- 参数:
dim (int) - 数据的维度。
coords (Union[array, list]) - 下界坐标和上界坐标,例如[[0.0, 0.0], [0.0, 1.0]]。
func (Callable) - 精确解函数。
name (str) - 采样器名称。默认值:None。
- 支持平台:
GPU
CPU
Ascend
样例:
>>> import numpy as np >>> from sciai.common import Sampler >>> def u(x_): >>> t = x_[:, 0:1] >>> x = x_[:, 1:2] >>> return np.exp(-t) * np.sin(500 * np.pi * x) >>> ics_coords = np.array([[0.0, 0.0], [0.0, 1.0]]) >>> ics_sampler = Sampler(2, ics_coords, u, name='Initial Condition 1') >>> x_batch, y_batch = ics_sampler.sample(10) >>> print(x_batch.shape, y_batch.shape) (10, 2), (10, 1)
- fetch_minibatch(n, mu_x, sigma_x)[源代码]
从采样器采出一个minibatch的数据。
- 参数:
n (int) - 一个minibatch的数据点个数。
mu_x (int) - 采样点的均值。
sigma_x (int) - 采样点的方差。
- 返回:
tuple[Tensor],一个minibatch的正则化后的采样点。