# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""control_ops"""
from ...common import dtype as mstype
from ..._checkparam import ParamValidator as validator
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
[docs]class ControlDepend(Primitive):
"""
Adds control dependency relation between source and destination operation.
In many cases, we need to control the execution order of operations. ControlDepend is designed for this.
ControlDepend will indicate the execution engine to run the operations in specific order. ControlDepend
tells the engine that the destination operations should depend on the source operation which means the source
operations should be executed before the destination.
Args:
depend_mode (int): Use 0 for normal depend, 1 for depend on operations that used the parameter. Default: 0.
Inputs:
- **src** (Any) - The source input. It can be a tuple of operations output or a single operation output. We do
not concern about the input data, but concern about the operation that generates the input data.
If `depend_mode = 1` is specified and the source input is parameter, we will try to find the operations that
used the parameter as input.
- **dst** (Any) - The destination input. It can be a tuple of operations output or a single operation output.
We do not concern about the input data, but concern about the operation that generates the input data.
If `depend_mode = 1` is specified and the source input is parameter, we will try to find the operations that
used the parameter as input.
Outputs:
Bool. This operation has no actual data output, it will be used to setup the order of relative operations.
Examples:
>>> # In the following example, the data calculation uses original global_step. After the calculation the global
>>> # step should be increased, so the add operation should depend on the data calculation operation.
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.global_step = Parameter(initializer(0, [1]), name="global_step")
>>> self.rate = 0.2
>>> self.control_depend = ControlDepend()
>>>
>>> def construct(self, x):
>>> data = self.rate * self.global_step + x
>>> added_global_step = self.global_step + 1
>>> self.global_step = added_global_step
>>> self.control_depend(data, added_global_step)
>>> return data
"""
@prim_attr_register
def __init__(self, depend_mode=0):
"""init"""
def __call__(self, src, dst):
return src
[docs]class GeSwitch(PrimitiveWithInfer):
"""
Adds control switch to data.
Switch data to flow into false or true branch depend on the condition. If the condition is true,
the true branch will be activated, or vise verse.
Inputs:
- **data** (Tensor) - The data to be used for switch control.
- **pred** (Tensor) - It should be a scalar whose type is bool and shape is `()`, It is used as condition for
switch control.
Outputs:
tuple. Output is tuple(false_output, true_output). The Elements in the tuple has the same shape of input data.
The false_output connects with the false_branch and the true_output connects with the true_branch.
Examples:
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.square = P.Square()
>>> self.add = P.TensorAdd()
>>> self.value = Tensor(np.full((1), 3, dtype=np.float32))
>>> self.switch = P.GeSwitch()
>>> self.merge = P.Merge()
>>> self.less = P.Less()
>>>
>>> def construct(self, x, y):
>>> cond = self.less(x, y)
>>> st1, sf1 = self.switch(x, cond)
>>> st2, sf2 = self.switch(y, cond)
>>> add_ret = self.add(st1, st2)
>>> st3, sf3 = self.switch(self.value, cond)
>>> sq_ret = self.square(sf3)
>>> ret = self.merge((add_ret, sq_ret))
>>> return ret[0]
>>>
>>> x = Tensor(x_init, dtype=mindspore.float32)
>>> y = Tensor(y_init, dtype=mindspore.float32)
>>> net = Net()
>>> output = net(x, y)
"""
@prim_attr_register
def __init__(self):
"""init"""
def __call__(self, data, pred):
raise NotImplementedError
def infer_shape(self, data, pred):
validator.check_scalar_shape_input("pred", pred)
return (data, data)
def infer_dtype(self, data_type, pred_type):
validator.check_type("pred", pred_type, [type(mstype.bool_)])
return (data_type, data_type)
[docs]class Merge(PrimitiveWithInfer):
"""
Merges all input data to one.
One and only one of the inputs should be selected as the output
Inputs:
- **inputs** (Tuple) - The data to be merged. All tuple elements should have same shape.
Outputs:
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.
"""
@prim_attr_register
def __init__(self):
"""init"""
def __call__(self, *args):
raise NotImplementedError
def infer_shape(self, inputs):
return (inputs[0], [1])
def infer_dtype(self, inputs):
return (inputs[0], mstype.int32)