# Source code for mindspore_rl.utils.batch_read_write

# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
BatchReadWrite
"""
from __future__ import absolute_import

from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.ops.operations._rl_inner_ops import BatchAssign


class BatchWrite(Cell):
    r"""
    Write a list of parameters to assign the target.

    .. warning::
        This is an experimental prototype that is subject to
        change and/or deletion.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> from mindspore import dtype as mstype
        >>> from mindspore.common.parameter import Parameter, ParameterTuple
        >>> from mindspore_rl.utils import BatchWrite
        >>> class SourceNet(nn.Cell):
        ...     def __init__(self):
        ...         super(SourceNet, self).__init__()
        ...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
        ...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
        >>> class DstNet(nn.Cell):
        ...     def __init__(self):
        ...         super(DstNet, self).__init__()
        ...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
        ...         self.dense = nn.Dense(in_channels=16, out_channels=1)
        >>> class Write(nn.Cell):
        ...     def __init__(self, dst, src):
        ...         super(Write, self).__init__()
        ...         self.w = BatchWrite()
        ...         self.dst = ParameterTuple(dst.trainable_params())
        ...         self.src = ParameterTuple(src.trainable_params())
        ...     def construct(self):
        ...         success = self.w(self.dst, self.src)
        ...         return success
        >>> dst_net = DstNet()
        >>> source_net = SourceNet()
        >>> nets = nn.CellList()
        >>> nets.append(dst_net)
        >>> nets.append(source_net)
        >>> success = Write(nets[0], nets[1])()
    """

    def __init__(self):
        """Initialize BatchWrite"""
        # pylint: disable=R1725
        super(BatchWrite, self).__init__()
        # lock=True is the only difference from BatchRead (which passes
        # lock=False). NOTE(review): the exact locking semantics live in
        # BatchAssign -- confirm there before relying on them.
        self.write = BatchAssign(lock=True)
        # When the configured device target is Ascend, tag the primitive with
        # primitive_target="CPU" (presumably so the assign runs on the host
        # instead of the Ascend device -- TODO confirm).
        if context.get_context("device_target") in ["Ascend"]:
            self.write.add_prim_attr("primitive_target", "CPU")

    def construct(self, dst, src):
        """
        Write the source parameter list to assign the dst.

        Args:
            dst (tuple(Parameters)): A parameter tuple of the dst model.
            src (tuple(Parameters)): A parameter tuple of the source model.

        Returns:
            True.
        """
        # The result of the assign primitive is discarded; callers always
        # receive the constant True.
        self.write(dst, src)
        return True


class BatchRead(Cell):
    r"""
    Read a list of parameters to assign the target.

    .. warning::
        This is an experimental prototype that is subject to
        change and/or deletion.

    Supported Platforms:
        ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> from mindspore import Tensor, nn
        >>> from mindspore import dtype as mstype
        >>> from mindspore.common.parameter import Parameter, ParameterTuple
        >>> from mindspore_rl.utils import BatchRead
        >>> class SNet(nn.Cell):
        ...     def __init__(self):
        ...         super(SNet, self).__init__()
        ...         self.a = Parameter(Tensor(0.5, mstype.float32), name="a")
        ...         self.dense = nn.Dense(in_channels=16, out_channels=1, weight_init=0)
        >>> class DNet(nn.Cell):
        ...     def __init__(self):
        ...         super(DNet, self).__init__()
        ...         self.a = Parameter(Tensor(0.1, mstype.float32), name="a")
        ...         self.dense = nn.Dense(in_channels=16, out_channels=1)
        >>> class Read(nn.Cell):
        ...     def __init__(self, dst, src):
        ...         super(Read, self).__init__()
        ...         self.read = BatchRead()
        ...         self.dst = ParameterTuple(dst.trainable_params())
        ...         self.src = ParameterTuple(src.trainable_params())
        ...     def construct(self):
        ...         success = self.read(self.dst, self.src)
        ...         return success
        >>> dst_net = DNet()
        >>> source_net = SNet()
        >>> nets = nn.CellList()
        >>> nets.append(dst_net)
        >>> nets.append(source_net)
        >>> success = Read(nets[0], nets[1])()
    """

    def __init__(self):
        """Initialize BatchRead"""
        # pylint: disable=R1725
        super(BatchRead, self).__init__()
        # lock=False is the only difference from BatchWrite (which passes
        # lock=True). NOTE(review): the exact locking semantics live in
        # BatchAssign -- confirm there before relying on them.
        self.read = BatchAssign(lock=False)
        # When the configured device target is Ascend, tag the primitive with
        # primitive_target="CPU" (presumably so the assign runs on the host
        # instead of the Ascend device -- TODO confirm).
        if context.get_context("device_target") in ["Ascend"]:
            self.read.add_prim_attr("primitive_target", "CPU")

    def construct(self, dst, src):
        """
        Read the source parameter list to assign the dst.

        Args:
            dst (tuple(Parameters)): A parameter tuple of the dst model.
            src (tuple(Parameters)): A parameter tuple of the source model.

        Returns:
            True.
        """
        # The result of the assign primitive is discarded; callers always
        # receive the constant True.
        self.read(dst, src)
        return True