# Source code for mindspore.mindrecord.config

# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
The configuration module provides various functions to set and get the supported
configuration parameters.

Common imported modules in corresponding API examples are as follows:

.. code-block::

    from mindspore.mindrecord import set_enc_key, set_enc_mode, set_dec_mode
"""

import os
import shutil
import stat
import time

from mindspore import log as logger
from mindspore._c_expression import _encrypt, _decrypt_data
from .shardutils import MIN_FILE_SIZE


# Names exported as the public API of this module.
__all__ = ['set_enc_key',
           'set_enc_mode',
           'set_dec_mode']


# Module-level encryption settings; they are updated through set_enc_key,
# set_enc_mode and set_dec_mode.
# ENC_KEY: ``None`` means that encryption is not enabled.
ENC_KEY = None
# ENC_MODE: name of the built-in mode, or a user supplied callable.
ENC_MODE = "AES-GCM"
# DEC_MODE: ``None`` means that ENC_MODE is used when a decode mode is requested.
DEC_MODE = None
# NOTE(review): HASH_MODE is not referenced by the code visible in this file - confirm it is still needed.
HASH_MODE = None


# Layout of a mindrecord file after encryption:
# 1. when creating a new mindrecord
#    mindrecord -> enc_mindrecord + 'ENCRYPT'
# 2. when reading a mindrecord
#    enc_mindrecord + 'ENCRYPT' -> mindrecord


# End flag of an encrypted mindrecord file; 'ENCRYPT' is appended to the end of the file.
ENCRYPT_END_FLAG = str('ENCRYPT').encode('utf-8')


# Name of the directory (created next to the encrypted file) that stores the decrypted mindrecord files.
DECRYPT_DIRECTORY = ".decrypt_mindrecord"
# Decrypt directories already reported to the user, so that the warning is logged once per directory.
DECRYPT_DIRECTORY_LIST = []


# Accumulated encrypt / decrypt time in seconds; a warning is logged each time
# the accumulated value exceeds WARNING_INTERVAL.
ENCRYPT_TIME = 0
DECRYPT_TIME = 0
WARNING_INTERVAL = 30   # 30s


def set_enc_key(enc_key):
    """
    Set the encode key.

    Args:
        enc_key (str): Str-type key used for encryption. The valid length is 16, 24, or 32.
            ``None`` indicates that encryption is not enabled.

    Raises:
        ValueError: The input is not str or length error.

    Examples:
        >>> from mindspore.mindrecord import set_enc_key
        >>>
        >>> set_enc_key("0123456789012345")
    """
    global ENC_KEY

    # ``None`` switches encryption off, so it bypasses the type and length checks.
    if enc_key is None:
        ENC_KEY = None
        return

    if not isinstance(enc_key, str):
        raise ValueError("The input enc_key is not str.")

    if len(enc_key) not in (16, 24, 32):
        raise ValueError("The length of input enc_key is not 16, 24, 32.")

    ENC_KEY = enc_key
def _get_enc_key():
    """Return the configured encode key, or ``None`` when no enc_key has been set."""
    # Reading a module-level name does not need a ``global`` declaration.
    return ENC_KEY
def set_enc_mode(enc_mode="AES-GCM"):
    """
    Set the encode mode.

    Args:
        enc_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set
            to ``None`` . Specifies the encryption mode or customized encryption function, currently
            supports ``"AES-GCM"`` . Default: ``"AES-GCM"`` . If it is customized encryption, users need
            to ensure its correctness, the security of the encryption algorithm and raise exceptions
            when errors occur.

    Raises:
        ValueError: The input is not valid encode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_enc_mode
        >>>
        >>> set_enc_mode("AES-GCM")
    """
    global ENC_MODE

    # A customized encryption function is stored as is; it is not validated here.
    if callable(enc_mode):
        ENC_MODE = enc_mode
        return

    if not isinstance(enc_mode, str):
        raise ValueError("The input enc_mode is not str.")

    if enc_mode not in ("AES-GCM",):
        raise ValueError("The input enc_mode is invalid.")

    ENC_MODE = enc_mode
def _get_enc_mode():
    """Return the configured encode mode; it is the default ``"AES-GCM"`` when enc_mode has not been set."""
    # Reading a module-level name does not need a ``global`` declaration.
    return ENC_MODE
def set_dec_mode(dec_mode="AES-GCM"):
    """
    Set the decode mode.

    If the built-in `enc_mode` is used and `dec_mode` is not specified, the encryption algorithm specified
    by `enc_mode` is used for decryption. If you are using customized encryption function, you must specify
    customized decryption function at read time.

    Args:
        dec_mode (Union[str, function], optional): This parameter is valid only when enc_key is not set
            to ``None`` . Specifies the decryption mode or customized decryption function, currently
            supports ``"AES-GCM"`` . Default: ``"AES-GCM"`` . ``None`` indicates that decryption mode is
            not defined. If it is customized decryption, users need to ensure its correctness and raise
            exceptions when errors occur.

    Raises:
        ValueError: The input is not valid decode mode or callable function.

    Examples:
        >>> from mindspore.mindrecord import set_dec_mode
        >>>
        >>> set_dec_mode("AES-GCM")
    """
    global DEC_MODE

    # ``None`` clears the decode mode.
    if dec_mode is None:
        DEC_MODE = None
        return

    # A customized decryption function is stored as is; it is not validated here.
    if callable(dec_mode):
        DEC_MODE = dec_mode
        return

    if not isinstance(dec_mode, str):
        raise ValueError("The input dec_mode is not str.")

    if dec_mode not in ("AES-GCM",):
        raise ValueError("The input dec_mode is invalid.")

    DEC_MODE = dec_mode
def _get_dec_mode():
    """
    Get the decode mode. If the dec_mode is not set, it will return encode mode.

    Raises:
        RuntimeError: The dec_mode is not set while the enc_mode is a customized function.
    """
    if DEC_MODE is not None:
        return DEC_MODE
    if callable(ENC_MODE):
        raise RuntimeError("You use custom encryption, so you must also define custom decryption.")
    return ENC_MODE


def _get_enc_mode_as_str():
    """
    Get the encode mode as bytes. The length of mode should be 7.

    A customized encryption function is represented by the marker ``"UDF-ENC"``.

    Raises:
        RuntimeError: The length of the mode string is not 7.
    """
    valid_enc_mode = "UDF-ENC" if callable(ENC_MODE) else ENC_MODE

    # the mode is stored in a fixed 7 bytes field of the encrypted file
    if len(valid_enc_mode) != 7:
        raise RuntimeError("The length of enc_mode string is not 7.")

    return str(valid_enc_mode).encode('utf-8')


def _get_dec_mode_as_str():
    """
    Get the decode mode as bytes. The length of mode should be 7.

    A customized decryption function is represented by the marker ``"UDF-ENC"``, which is the same
    marker that is written for a customized encryption function.

    Raises:
        RuntimeError: The dec_mode is not set while the enc_mode is a customized function, or the length
            of the mode string is not 7.
    """
    dec_mode = _get_dec_mode()
    valid_dec_mode = "UDF-ENC" if callable(dec_mode) else dec_mode

    if len(valid_dec_mode) != 7:
        raise RuntimeError("The length of dec_mode string is not 7.")

    return str(valid_dec_mode).encode('utf-8')


def _check_regular_file(filename):
    """Raise RuntimeError when `filename` does not exist or is not a regular file."""
    if not os.path.exists(filename):
        raise RuntimeError("The input: {} is not exists.".format(filename))

    if not os.path.isfile(filename):
        raise RuntimeError("The input: {} should be a regular file.".format(filename))


def _seek_file(f, filename, position):
    """Seek the opened file `f` to `position` and convert any failure to RuntimeError."""
    try:
        f.seek(position)
    except Exception as e:  # pylint: disable=W0703
        raise RuntimeError("Seek the file: {} to position: {} failed. Error: {}"
                           .format(filename, position, str(e))) from e


def _read_from_end(filename, distance, size):
    """Read `size` bytes which start `distance` bytes before the end of `filename`."""
    _check_regular_file(filename)

    offset = os.path.getsize(filename) - distance
    with open(filename, 'rb') as f:
        _seek_file(f, filename, offset)
        return f.read(size)


def _read_block_length(f, filename):
    """Read the 4 bytes big-endian length field of the next encrypted block from `f`."""
    raw_length = f.read(4)
    # An empty read would silently be decoded as 0 (the terminator) and a negative length would make
    # `f.read(length)` read the whole rest of the file, so both are rejected here.
    length = int.from_bytes(raw_length, byteorder='big', signed=True)
    if len(raw_length) != 4 or length < 0:
        raise RuntimeError("Invalid encrypted mindrecord file: {}. The length field of an encrypted "
                           "data block is missing or negative.".format(filename))
    return length


def encrypt(filename, enc_key, enc_mode):
    """
    Encrypt the file and the original file will be replaced by the encrypted one.

    The encrypted data is written to `filename` + ".encrypt" first and then moved over `filename`.

    Args:
        filename (str): Path of the file to be encrypted.
        enc_key (str): The key which is passed to the encryption.
        enc_mode (Union[str, function]): The built-in mode name, or a customized function which is
            called as ``enc_mode(f, file_size, f_encrypt, enc_key)``.

    Returns:
        bool, ``True`` when the file has been encrypted.

    Raises:
        RuntimeError: The input is not a regular file or seeking the file failed.
    """
    global ENCRYPT_TIME

    _check_regular_file(filename)

    logger.info("Begin to encrypt file: {}.".format(filename))
    start = time.time()

    block_size = 64 * 1024 * 1024    # the file is read in blocks of 64M
    file_size = os.path.getsize(filename)

    # create new encrypt file
    encrypt_filename = filename + ".encrypt"

    try:
        with open(filename, 'rb') as f, open(encrypt_filename, 'wb') as f_encrypt:
            if callable(enc_mode):
                enc_mode(f, file_size, f_encrypt, enc_key)
            else:
                # read the file block by block and do encrypt
                # original mindrecord file like:
                # |64M|64M|64M|64M|...
                # encrypted mindrecord file like:
                # len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
                current_offset = 0    # use this to seek file
                while current_offset < file_size:
                    read_size = min(block_size, file_size - current_offset)

                    _seek_file(f, filename, current_offset)
                    data = f.read(read_size)
                    encode_data = _encrypt(data, len(data), enc_key, len(enc_key), enc_mode)

                    # write length of data to encrypt file
                    f_encrypt.write(len(encode_data).to_bytes(length=4, byteorder='big', signed=True))

                    # write data to encrypt file
                    f_encrypt.write(encode_data)

                    current_offset += read_size

            # writing 0 at the end indicates that all encrypted data has been written.
            f_encrypt.write((0).to_bytes(length=4, byteorder='big', signed=True))

            # write enc_mode
            # NOTE: the stored mode comes from the module-level ENC_MODE, not from the `enc_mode` argument.
            f_encrypt.write(_get_enc_mode_as_str())

            # write ENCRYPT_END_FLAG
            f_encrypt.write(ENCRYPT_END_FLAG)
    except Exception:
        # both files are closed by the `with` statement; restrict the partial output to the owner
        if os.path.isfile(encrypt_filename):
            os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
        raise

    end = time.time()
    ENCRYPT_TIME += end - start
    if ENCRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to encrypt the mindrecord file.")
        ENCRYPT_TIME = ENCRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(encrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    # move the encrypt file to origin file
    shutil.move(encrypt_filename, filename)

    return True


def _get_encrypt_end_flag(filename):
    """Get the encrypt end flag, which is the last len(ENCRYPT_END_FLAG) bytes of the file."""
    flag_size = len(ENCRYPT_END_FLAG)
    return _read_from_end(filename, flag_size, flag_size)


def _get_enc_mode_from_file(filename):
    """Get the enc_mode, which is the 7 bytes stored right before the encrypt end flag of the file."""
    return _read_from_end(filename, len(ENCRYPT_END_FLAG) + 7, 7)


def decrypt(filename, enc_key, dec_mode):
    """
    Decrypt the file by enc_key and dec_mode.

    The decrypted copy is stored in the DECRYPT_DIRECTORY directory next to the (resolved) input file
    and it is not deleted by this function.

    Args:
        filename (str): Path of the mindrecord file.
        enc_key (str): The key which is passed to the decryption.
        dec_mode (Union[str, function]): The built-in mode name, or a customized function which is
            called as ``dec_mode(f, file_size, f_decrypt, enc_key)``, where `file_size` is the size of
            the file without the encrypt end flag.

    Returns:
        str, `filename` itself when no enc_key is configured and the file is not encrypted, otherwise
        the path of the decrypted file.

    Raises:
        RuntimeError: The input is not a valid file, the encryption state of the file does not match the
            configured enc_key, the stored enc_mode does not match the dec_mode, or the data can not be
            decrypted.
    """
    global DECRYPT_TIME

    _check_regular_file(filename)

    whole_file_size = os.path.getsize(filename)
    if whole_file_size < MIN_FILE_SIZE:
        raise RuntimeError("Invalid file, the size of mindrecord file: " + str(whole_file_size) +
                           " is smaller than the lower limit: " + str(MIN_FILE_SIZE) +
                           ".\n Please check file path: " + filename +
                           " and use 'FileWriter' to generate valid mindrecord files.")

    # check ENCRYPT_END_FLAG
    stored_encrypt_end_flag = _get_encrypt_end_flag(filename)
    if _get_enc_key() is not None:
        if stored_encrypt_end_flag != ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is not encrypted. You can set " +
                               "'mindspore.mindrecord.config.set_enc_key(None)' to disable the decryption.")
    else:
        if stored_encrypt_end_flag == ENCRYPT_END_FLAG:
            raise RuntimeError("The mindrecord file is encrypted. You need to configure " +
                               "'mindspore.mindrecord.config.set_enc_key(...)' and " +
                               "'mindspore.mindrecord.config.set_enc_mode(...)' for decryption.")
        # no key is configured and the file is plain, so it can be used directly
        return filename

    # check dec_mode with enc_mode
    enc_mode_from_file = _get_enc_mode_from_file(filename)
    if enc_mode_from_file != _get_dec_mode_as_str():
        raise RuntimeError("Failed to decrypt data, please check if enc_key and enc_mode / dec_mode is valid.")

    logger.info("Begin to decrypt file: {}.".format(filename))
    start = time.time()

    file_size = os.path.getsize(filename) - len(ENCRYPT_END_FLAG)

    real_path_filename = os.path.realpath(filename)
    parent_dir = os.path.dirname(real_path_filename)
    only_filename = os.path.basename(real_path_filename)
    current_decrypt_dir = os.path.join(parent_dir, DECRYPT_DIRECTORY)
    if not os.path.exists(current_decrypt_dir):
        os.mkdir(current_decrypt_dir)
        os.chmod(current_decrypt_dir, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        logger.info("Create directory: {} to store decrypt mindrecord files."
                    .format(os.path.join(parent_dir, DECRYPT_DIRECTORY)))

    # warn only once for every decrypt directory
    if current_decrypt_dir not in DECRYPT_DIRECTORY_LIST:
        DECRYPT_DIRECTORY_LIST.append(current_decrypt_dir)
        logger.warning("The decrypt mindrecord file will be stored in [" + current_decrypt_dir + "] directory. "
                       "If you don't use it anymore after train / eval, you need to delete it manually.")

    # create new decrypt file
    decrypt_filename = os.path.join(current_decrypt_dir, only_filename)
    if os.path.isfile(decrypt_filename):
        # the file which had been decrypted early maybe update by user, so we remove the old decrypted one
        os.remove(decrypt_filename)

    try:
        with open(filename, 'rb') as f, open(decrypt_filename, 'wb+') as f_decrypt:
            if callable(dec_mode):
                dec_mode(f, file_size, f_decrypt, enc_key)
            else:
                # read the file and do decrypt
                # encrypted mindrecord file like:
                # len+encrypt_data|len+encrypt_data|len+encrypt_data|...|0|enc_mode|ENCRYPT_END_FLAG
                current_offset = 0    # use this to seek file
                length = _read_block_length(f, filename)
                while length != 0:
                    # current_offset is the encrypted data
                    current_offset += 4
                    _seek_file(f, filename, current_offset)

                    data = f.read(length)
                    if len(data) != length:
                        raise RuntimeError("Invalid encrypted mindrecord file: {}. The encrypted data "
                                           "block at position: {} is incomplete."
                                           .format(filename, current_offset))

                    decode_data = _decrypt_data(data, len(data), enc_key, len(enc_key), dec_mode)
                    if decode_data is None:
                        raise RuntimeError("Failed to decrypt data, " +
                                           "please check if enc_key and enc_mode / dec_mode is valid.")

                    # write to decrypt file
                    f_decrypt.write(decode_data)

                    # current_offset is the length of next encrypted data block
                    current_offset += length
                    _seek_file(f, filename, current_offset)

                    length = _read_block_length(f, filename)
    except Exception:
        # both files are closed by the `with` statement; restrict the partial output to the owner
        if os.path.isfile(decrypt_filename):
            os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)
        raise

    end = time.time()
    DECRYPT_TIME += end - start
    if DECRYPT_TIME > WARNING_INTERVAL:
        logger.warning("It takes another " + str(WARNING_INTERVAL) + "s to decrypt the mindrecord file.")
        DECRYPT_TIME = DECRYPT_TIME - WARNING_INTERVAL

    # change the file mode
    os.chmod(decrypt_filename, stat.S_IRUSR | stat.S_IWUSR)

    return decrypt_filename