mindspore.Symbol
class mindspore.Symbol(max=0, min=1, divisor=1, remainder=0, unique=False, **kawgs)
Symbol is a data structure that carries symbolic information about a shape dimension.
For dynamic-shape networks, setting a dimension to None in a Tensor only marks its size as unknown. Providing richer symbolic shape information through Symbol can help the framework optimize the computation graph more effectively and improve network execution performance.
Parameters
max (int) – The maximum length of this dimension. It takes effect only when it is greater than min. Default: 0.
min (int) – The minimum length of this dimension. Default: 1.
divisor (int) – The divisor \(d\). When remainder is 0, the length of this dimension is divisible by \(d\). Default: 1.
remainder (int) – The remainder \(r\) when the length is represented as \(d * N + r, N \ge 1\). Default: 0.
unique (bool) – Controls what is shared when the same Symbol object is used for multiple dimensions. If True, those dimensions are treated as having the same length; otherwise, the dimensions share only the symbolic constraints and may have different lengths. Default: False.
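To make the parameter semantics concrete, here is a plain-Python sketch that checks a concrete shape against a symbolic one. Sym and shape_matches are hypothetical names used only for illustration, not MindSpore APIs. The checks simply restate the documented rules (the min bound, the max bound applied only when max is greater than min, the \(d * N + r, N \ge 1\) form, and equal lengths for a unique symbol); MindSpore's own runtime checks may differ in detail.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(eq=False)  # eq=False: each Sym is hashed and compared by identity
class Sym:
    """Hypothetical plain-Python stand-in for mindspore.Symbol (illustration only)."""
    max: int = 0
    min: int = 1
    divisor: int = 1
    remainder: int = 0
    unique: bool = False

    def accepts(self, length: int) -> bool:
        """Return True if one concrete length satisfies the documented constraints."""
        if length < self.min:
            return False
        # max is only taken into account when it is greater than min.
        if self.max > self.min and length > self.max:
            return False
        # length must equal divisor * N + remainder for some N >= 1.
        n, rest = divmod(length - self.remainder, self.divisor)
        return rest == 0 and n >= 1


def shape_matches(pattern: Sequence[Optional[Sym]], shape: Sequence[int]) -> bool:
    """Check a concrete shape against a pattern of None (unknown) or Sym entries."""
    if len(pattern) != len(shape):
        return False
    bound = {}  # unique Sym -> the length it was first bound to
    for sym, length in zip(pattern, shape):
        if sym is None:
            continue  # None only marks the dimension as unknown
        if not sym.accepts(length):
            return False
        if sym.unique and bound.setdefault(sym, length) != length:
            # Every dimension that reuses a unique symbol must have the same length.
            return False
    return True
```

With s1 = Sym(divisor=8, remainder=1) and s2 = Sym(max=32, unique=True), the shape (1, 9, 17, 32, 32) matches the pattern (None, s1, s1, s2, s2): s1 may bind different lengths (9 and 17) because it is not unique. The shape (1, 9, 17, 32, 30) does not match, because the two s2 dimensions differ.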
Outputs
Symbol.
Raises
TypeError – If max, min, divisor, or remainder is not an int.
TypeError – If unique is not a bool.
ValueError – If min is not positive.
ValueError – If divisor is not positive.
ValueError – If remainder is not in the range \([0, d)\).
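The argument checks listed under Raises can be mirrored in a small dataclass. SymbolSpec is a hypothetical name used only for illustration, not the real mindspore.Symbol class; the order in which the checks run is an assumption, and MindSpore's actual error messages will differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymbolSpec:
    """Hypothetical stand-in that mirrors the documented argument checks of mindspore.Symbol."""
    max: int = 0
    min: int = 1
    divisor: int = 1
    remainder: int = 0
    unique: bool = False

    def __post_init__(self):
        # TypeError: max, min, divisor and remainder must be ints.
        for name in ("max", "min", "divisor", "remainder"):
            value = getattr(self, name)
            if not isinstance(value, int):
                raise TypeError(f"{name} must be an int, got {type(value).__name__}")
        # TypeError: unique must be a bool.
        if not isinstance(self.unique, bool):
            raise TypeError(f"unique must be a bool, got {type(self.unique).__name__}")
        # ValueError: min and divisor must be positive.
        if self.min <= 0:
            raise ValueError(f"min must be positive, got {self.min}")
        if self.divisor <= 0:
            raise ValueError(f"divisor must be positive, got {self.divisor}")
        # ValueError: remainder must lie in [0, divisor).
        if not 0 <= self.remainder < self.divisor:
            raise ValueError(f"remainder must be in [0, {self.divisor}), got {self.remainder}")
```

For example, SymbolSpec(divisor=8, remainder=1) is accepted, while SymbolSpec(divisor=8, remainder=8) raises ValueError because the remainder must be strictly smaller than the divisor.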
Examples
>>> import numpy as np
>>> import mindspore as ms
>>> from mindspore import nn, Tensor, Symbol
>>>
>>> class Net(nn.Cell):
...     def __init__(self):
...         super(Net, self).__init__()
...         self.abs = ms.ops.Abs()
...     def construct(self, x):
...         return self.abs(x)
...
>>> net = Net()
>>> s1 = Symbol(divisor=8, remainder=1)
>>> s2 = Symbol(max=32, unique=True)
>>> dyn_t = Tensor(shape=(None, s1, s1, s2, s2), dtype=ms.float32)
>>> net.set_inputs(dyn_t)
>>> # The last two dimensions must have equal lengths because s2 is set to unique.
>>> net(Tensor(np.random.randn(1, 9, 17, 32, 32), dtype=ms.float32)).shape
(1, 9, 17, 32, 32)
>>> net(Tensor(np.random.randn(8, 25, 9, 30, 30), dtype=ms.float32)).shape
(8, 25, 9, 30, 30)