mindspore.nn.GRUCell
- class mindspore.nn.GRUCell(input_size: int, hidden_size: int, has_bias: bool = True)
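A GRU (Gated Recurrent Unit) cell.

$$
\begin{array}{ll}
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\
h' = (1 - z) * n + z * h
\end{array}
$$

Here $\sigma$ is the sigmoid function, $*$ is the Hadamard (elementwise) product, $h$ is the previous hidden state (the input hx), and $h'$ is the new hidden state. $W, b$ are the learnable weights and biases between the output and the input in the formula. For instance, $W_{ir}, b_{ir}$ are the weight and bias used to transform the input $x$ into $r$. Details can be found in the paper Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation.
- Parameters
input_size (int) - Number of features of the input x.
hidden_size (int) - Number of features of the hidden state h.
has_bias (bool) - Whether the cell uses the bias terms $b$. Default: True.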
- Inputs:
x (Tensor) - Tensor of shape (batch_size, input_size).
hx (Tensor) - Tensor of data type mindspore.float32 and shape (batch_size, hidden_size). The data type of hx must be the same as that of x.
- Outputs:
h' (Tensor) - The new hidden state, a Tensor of shape (batch_size, hidden_size).
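To make the update rule concrete (reset gate $r$, update gate $z$, candidate $n$, and new state $h' = (1 - z) * n + z * h$), here is a minimal NumPy sketch of a single GRU step with the shapes listed above. It is not MindSpore's implementation: the helper names `gru_cell_step` and `init_params`, the dictionary keys (`W_ir`, `W_hr`, `b_ir`, and so on), and the uniform random initialization are hypothetical choices made for illustration. Inputs are batched row vectors, so each $W x$ term is written as `x @ W.T`.

```python
import numpy as np


def sigmoid(a: np.ndarray) -> np.ndarray:
    """Elementwise logistic sigmoid."""
    return 1.0 / (1.0 + np.exp(-a))


def gru_cell_step(x: np.ndarray, h: np.ndarray, p: dict) -> np.ndarray:
    """One GRU update h -> h'.

    x: (batch_size, input_size); h: (batch_size, hidden_size).
    p: hypothetical parameter dict with W_i* of shape (hidden_size, input_size),
       W_h* of shape (hidden_size, hidden_size) and b_* of shape (hidden_size,).
    """
    # Row-vector batches: W @ x for one column vector becomes x @ W.T here.
    r = sigmoid(x @ p["W_ir"].T + p["b_ir"] + h @ p["W_hr"].T + p["b_hr"])  # reset gate
    z = sigmoid(x @ p["W_iz"].T + p["b_iz"] + h @ p["W_hz"].T + p["b_hz"])  # update gate
    # Candidate state: r scales the hidden-state term elementwise (Hadamard product).
    n = np.tanh(x @ p["W_in"].T + p["b_in"] + r * (h @ p["W_hn"].T + p["b_hn"]))
    return (1.0 - z) * n + z * h  # blend candidate and previous state


def init_params(input_size: int, hidden_size: int, seed: int = 0) -> dict:
    """Random float32 parameters (illustrative init, not MindSpore's)."""
    rng = np.random.default_rng(seed)
    bound = 1.0 / np.sqrt(hidden_size)
    params = {}
    for g in ("r", "z", "n"):
        params[f"W_i{g}"] = rng.uniform(-bound, bound, (hidden_size, input_size)).astype(np.float32)
        params[f"W_h{g}"] = rng.uniform(-bound, bound, (hidden_size, hidden_size)).astype(np.float32)
        params[f"b_i{g}"] = rng.uniform(-bound, bound, hidden_size).astype(np.float32)
        params[f"b_h{g}"] = rng.uniform(-bound, bound, hidden_size).astype(np.float32)
    return params


# Same sizes as the example: input_size=10, hidden_size=16, batch_size=3.
params = init_params(10, 16)
x = np.ones((3, 10), dtype=np.float32)
hx = np.ones((3, 16), dtype=np.float32)
h_new = gru_cell_step(x, hx, params)
print(h_new.shape)  # (3, 16)
```

A quick way to check the gating arithmetic: with every weight and bias set to zero, both gates evaluate to 0.5 and the candidate $n$ is 0, so the step returns exactly $0.5 \cdot h$.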
- Supported Platforms:
Ascend
GPU
Examples
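>>> import numpy as np
>>> from mindspore import Tensor, nn
>>> net = nn.GRUCell(10, 16)
>>> # x holds 5 time steps, each of shape (batch_size=3, input_size=10)
>>> x = Tensor(np.ones([5, 3, 10]).astype(np.float32))
>>> hx = Tensor(np.ones([3, 16]).astype(np.float32))
>>> output = []
>>> for i in range(5):
...     hx = net(x[i], hx)
...     output.append(hx)
>>> print(output[0].shape)
(3, 16)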
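For `nn.GRUCell(10, 16)` as used in the example, the number of learnable values implied by the GRU equations (three input weights $W_{ir}, W_{iz}, W_{in}$ of shape $16 \times 10$, three hidden weights $W_{hr}, W_{hz}, W_{hn}$ of shape $16 \times 16$, and, assuming has_bias=True keeps all of them, six bias vectors of length 16) works out as follows. This is arithmetic on the equations as written, not a figure read from MindSpore's own parameter list, which may group these terms into differently shaped tensors.

$$
3 \cdot (16 \cdot 10) + 3 \cdot (16 \cdot 16) + 6 \cdot 16 = 480 + 768 + 96 = 1344
$$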