mindspore.experimental.optim.Adamax =================================== .. image:: https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg :target: https://gitee.com/mindspore/mindspore/blob/master/docs/api/api_python/experimental/optim/mindspore.experimental.optim.Adamax.rst :alt: 查看源文件 .. py:class:: mindspore.experimental.optim.Adamax(params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, *, maximize=False) Adamax算法的实现(基于无穷范数的Adam算法)。 更新公式如下: .. math:: \begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\ &\hspace{13mm} \epsilon \text{ (epsilon)} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}if \: \lambda \neq 0 \\ &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\ &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} .. warning:: 这是一个实验性的优化器接口,需要和 `LRScheduler `_ 下的动态学习率接口配合使用。 参数: - **params** (Union[list(Parameter), list(dict)]) - 网络参数的列表或指定了参数组的列表。 - **lr** (Union[int, float, Tensor], 可选) - 学习率。默认值:``2e-3``。 - **betas** (Tuple[float, float], 可选) - 梯度及其平方的运行平均值的系数。默认值:``(0.9, 0.999)``。 - **eps** (float, 可选) - 加在分母上的值,以确保数值稳定。必须大于0。默认值:``1e-8``。 - **weight_decay** (float, 可选) - 权重衰减(L2 penalty)。默认值:``0.``。 关键字参数: - **maximize** (bool, 可选) - 是否根据目标函数最大化网络参数。默认值:``False``。 输入: - **gradients** (tuple[Tensor]) - 网络权重的梯度。 异常: - **ValueError** - 学习率不是int、float或Tensor。 - **ValueError** - 学习率小于0。 - **ValueError** - `eps` 小于0。 - **ValueError** - `betas` 范围不在[0,1)之间。 - **ValueError** - `weight_decay` 小于0。