Source code for mindspore.train.summary.summary_record

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Record the summary event."""
import os
import threading
from mindspore import log as logger
from ._summary_scheduler import WorkerScheduler, SummaryDataManager
from ._summary_adapter import get_event_file_name, package_graph_event
from ._event_writer import EventRecord
from .._utils import _make_directory
from ..._checkparam import _check_str_by_regular

# cache the summary data
_summary_tensor_cache = {}
_summary_lock = threading.Lock()


def _cache_summary_tensor_data(summary):
    """
    Get the time of ms.

    Args:
         summary (list): [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...].
    """
    with _summary_lock:
        if "SummaryRecord" in _summary_tensor_cache:
            _summary_tensor_cache["SummaryRecord"].extend(summary)
        else:
            _summary_tensor_cache["SummaryRecord"] = summary
    return True
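
# A minimal usage sketch of the cache above (the tag name and tensor value are
# illustrative, not part of this module): producers cache a list of records,
# which SummaryRecord.record() later drains in one shot.
#
#   import numpy as np
#   from mindspore import Tensor
#   _cache_summary_tensor_data([
#       {"name": "loss[:Scalar]", "data": Tensor(np.array(0.5))},
#   ])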


class SummaryRecord:
    """
    Summary log record.

    SummaryRecord is used to record the summary value. The API will create an
    event file in a given directory and add summaries and events to it.

    Args:
        log_dir (str): The log_dir is a directory location to save the summary.
        queue_max_size (int): The capacity of the event queue (reserved). Default: 0.
        flush_time (int): Frequency to flush the summaries to disk, the unit is second. Default: 120.
        file_prefix (str): The prefix of the file. Default: "events".
        file_suffix (str): The suffix of the file. Default: "_MS".
        network (Cell): Obtain a pipeline through network for saving graph summary. Default: None.

    Raises:
        TypeError: If `queue_max_size` or `flush_time` is not an int, or `file_prefix` or `file_suffix` is not a str.
        RuntimeError: If the log_dir can not be resolved to a canonicalized absolute pathname.

    Examples:
        >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
        >>>                                file_prefix="xxx_", file_suffix="_yyy")
    """

    def __init__(self,
                 log_dir,
                 queue_max_size=0,
                 flush_time=120,
                 file_prefix="events",
                 file_suffix="_MS",
                 network=None):
        _check_str_by_regular(file_prefix)
        _check_str_by_regular(file_suffix)
        self.log_path = _make_directory(log_dir)

        if not isinstance(queue_max_size, int) or not isinstance(flush_time, int):
            raise TypeError("`queue_max_size` and `flush_time` should be int.")
        if not isinstance(file_prefix, str) or not isinstance(file_suffix, str):
            raise TypeError("`file_prefix` and `file_suffix` should be str.")

        self.queue_max_size = queue_max_size
        if queue_max_size < 0:  # 0 means no limit
            logger.warning("The queue_max_size(%r) set error, will use the default value: 0", queue_max_size)
            self.queue_max_size = 0

        self.flush_time = flush_time
        if flush_time <= 0:
            logger.warning("The flush_time(%r) set error, will use the default value: 120", flush_time)
            self.flush_time = 120

        self.prefix = file_prefix
        self.suffix = file_suffix

        # create the summary writer file
        self.event_file_name = get_event_file_name(self.prefix, self.suffix)
        self.full_file_name = os.path.join(self.log_path, self.event_file_name)
        try:
            self.full_file_name = os.path.realpath(self.full_file_name)
        except Exception as ex:
            raise RuntimeError(ex)
        self.event_writer = EventRecord(self.full_file_name, self.flush_time)
        self.writer_id = SummaryDataManager.summary_file_set(self.event_writer)
        self.worker_scheduler = WorkerScheduler(self.writer_id)

        self.step = 0
        self._closed = False
        self.network = network
        self.has_graph = False
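
    # A minimal construction sketch (the directory name is illustrative): the
    # constructor resolves log_dir, derives the event file name from
    # file_prefix/file_suffix, and opens an EventRecord that flushes every
    # flush_time seconds.
    #
    #   rec = SummaryRecord(log_dir="./summary_dir", flush_time=30)
    #   print(rec.log_dir)  # absolute path of the newly created event file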
    def record(self, step, train_network=None):
        """
        Record the summary.

        Args:
            step (int): Represents the training step number.
            train_network (Cell): The network that calls the callback.

        Returns:
            bool, whether the record process is successful or not.

        Examples:
            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
            >>>                                file_prefix="xxx_", file_suffix="_yyy")
            >>> summary_record.record(step=2)
        """
        logger.info("SummaryRecord step is %r.", step)
        if self._closed:
            logger.error("The record writer is closed.")
            return False
        if not isinstance(step, int) or isinstance(step, bool):
            raise ValueError("`step` should be int.")
        # set the current train step of the summary
        self.step = step

        if self.network is not None and not self.has_graph:
            graph_proto = self.network.get_func_graph_proto()
            if graph_proto is None and train_network is not None:
                graph_proto = train_network.get_func_graph_proto()
            if graph_proto is None:
                logger.error("Failed to get proto for graph.")
            else:
                self.event_writer.write_event_to_file(
                    package_graph_event(graph_proto).SerializeToString())
                self.event_writer.flush()
                self.has_graph = True

        data = _summary_tensor_cache.get("SummaryRecord")
        if data is None:
            logger.error("The step(%r) does not have record data.", self.step)
            return False
        if self.queue_max_size > 0 and len(data) > self.queue_max_size:
            logger.error("The size of data record is %r, which is greater than queue_max_size %r.",
                         len(data), self.queue_max_size)

        # clean the cached data
        del _summary_tensor_cache["SummaryRecord"]

        # process the data
        self.worker_scheduler.dispatch(self.step, data)

        # count the event and flush on the configured cycle
        self.event_writer.count_event()
        self.event_writer.flush_cycle()

        logger.debug("Send the summary data to scheduler for saving, step = %d", self.step)
        return True
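
    # A minimal per-step sketch (the step range and `loss_tensor` are
    # illustrative): cache the tensors first, then call record() so the
    # scheduler can dispatch them for this step.
    #
    #   for step in range(1, 11):
    #       _cache_summary_tensor_data([{"name": "loss[:Scalar]", "data": loss_tensor}])
    #       rec.record(step)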
    @property
    def log_dir(self):
        """
        Get the full path of the log file.

        Returns:
            String, the full path of the log file.

        Examples:
            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
            >>>                                file_prefix="xxx_", file_suffix="_yyy")
            >>> print(summary_record.log_dir)
        """
        return self.event_writer.full_file_name
    def flush(self):
        """
        Flush the event file to disk.

        Call it to make sure that all pending events have been written to disk.

        Examples:
            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
            >>>                                file_prefix="xxx_", file_suffix="_yyy")
            >>> summary_record.flush()
        """
        if self._closed:
            logger.error("The record writer is closed and can not flush.")
        else:
            self.event_writer.flush()
    def close(self):
        """
        Flush all events and close summary records.

        Examples:
            >>> summary_record = SummaryRecord(log_dir="/opt/log", queue_max_size=50, flush_time=6,
            >>>                                file_prefix="xxx_", file_suffix="_yyy")
            >>> summary_record.close()
        """
        if not self._closed:
            self._check_data_before_close()
            self.worker_scheduler.close()
            # event writer flush and close
            self.event_writer.close()
            self._closed = True
    def __del__(self):
        """Called on process exit."""
        if hasattr(self, "worker_scheduler"):
            if self.worker_scheduler:
                self.close()

    def _check_data_before_close(self):
        """Check whether there is any data in the cache, and if so, record it."""
        data = _summary_tensor_cache.get("SummaryRecord")
        if data is not None:
            self.record(self.step)
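
# A minimal end-to-end sketch (paths and values are illustrative). SummaryRecord
# is not a context manager, so pair record() with close() via try/finally:
#
#   import numpy as np
#   from mindspore import Tensor
#   rec = SummaryRecord(log_dir="./summary_dir")
#   try:
#       _cache_summary_tensor_data([{"name": "loss[:Scalar]", "data": Tensor(np.array(0.5))}])
#       rec.record(1)
#       rec.flush()
#   finally:
#       rec.close()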