Source code for mindquantum.algorithm.library.qjpeg

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""QJPEG algorithm for quantum figure compression."""

from typing import Tuple, List
from mindquantum.algorithm.library.quantum_fourier import qft
from mindquantum.core.circuit import Circuit, dagger
from mindquantum.utils.type_value_check import _check_int_type, _check_value_should_not_less


def qjpeg(n_qubits: int, m_qubits: int) -> Tuple[Circuit, List[int], List[int]]:
    """
    Construct the circuit for compressing quantum figure with the QJPEG algorithm.

    Args:
        n_qubits (int): The number of qubits used to encode the quantum figure to be compressed.
        m_qubits (int): The number of qubits used to encode the compressed quantum figure.

    Note:
        The input arguments, n_qubits and m_qubits, should both be even, and the n_qubits
        must be not less than the m_qubits. Please refer to arXiv:2306.09323v2 for more
        information.

    Returns:
        - Circuit, The QJPEG circuit for quantum image compression.
        - List[int], Indices (in ascending order) of the remainder qubits that carry the
          compressed quantum image information.
        - List[int], Indices of the discarded qubits.

    Raises:
        ValueError: If n_qubits is less than m_qubits, or if either of them is odd.
            Non-integer or negative arguments are rejected by the mindquantum
            argument checkers before these checks run.

    Examples:
        >>> from mindquantum import Simulator, normalize
        >>> import numpy as np
        >>> n_qubits = 4
        >>> m_qubits = 2
        >>> circ, remainder_qubits, discard_qubits = qjpeg(n_qubits, m_qubits)
        >>> print(remainder_qubits, discard_qubits)
        [0, 2] [1, 3]
        >>> data = np.array([[1,0,0,0], [1,1,0,0], [1,1,1,0], [1,1,1,1]])
        >>> state = normalize(data.reshape(-1))
        >>> sim = Simulator('mqmatrix', n_qubits)
        >>> sim.set_qs(state)
        >>> sim.apply_circuit(circ)
        >>> rho = sim.get_partial_trace(discard_qubits)
        >>> sub_probs = rho.diagonal().real
        >>> new_data = sub_probs.reshape((2**(m_qubits//2), -1))
        >>> print(np.allclose(new_data, [[0.3, 0.0], [0.4, 0.3]]))
        True
    """
    _check_int_type("n_qubits", n_qubits)
    _check_int_type("m_qubits", m_qubits)
    _check_value_should_not_less("n_qubits", 0, n_qubits)
    _check_value_should_not_less("m_qubits", 0, m_qubits)
    if n_qubits < m_qubits:
        raise ValueError("n_qubits should be not less than m_qubits.")
    if n_qubits % 2 != 0 or m_qubits % 2 != 0:
        raise ValueError("Both n_qubits and m_qubits should be even numbers.")

    # Number of qubits removed from each of the two equal halves of the register.
    half_diff = (n_qubits - m_qubits) // 2
    former_qubits = list(range(0, n_qubits // 2))
    latter_qubits = list(range(n_qubits // 2, n_qubits))

    # Discard the highest-indexed `half_diff` qubits of each half. Slicing from
    # `len(...) - half_diff` (instead of `-half_diff`) keeps the slice empty when
    # half_diff == 0, i.e. when no compression is requested.
    mid_qubits = former_qubits[len(former_qubits) - half_diff :]
    last_qubits = latter_qubits[len(latter_qubits) - half_diff :]
    discard_qubits = mid_qubits + last_qubits

    # Build the complement explicitly in ascending order; taking it from a set
    # would make the returned order depend on set iteration order.
    discard_set = set(discard_qubits)
    remainder_qubits = [qubit for qubit in range(n_qubits) if qubit not in discard_set]

    # QFT over every qubit, followed by the inverse QFT over the lowest
    # `n_qubits - half_diff` qubits.
    circ = Circuit()
    circ += qft(range(n_qubits))
    circ += dagger(qft(range(n_qubits - half_diff)))
    return circ, remainder_qubits, discard_qubits