Shor’s Algorithm Based on MindSpore Quantum

Introduction to Shor’s Algorithm

The time complexity of Shor's algorithm for factoring an integer $N$ on a quantum computer is polynomial in $\log N$. This is an almost exponential speedup over the most efficient known classical factoring algorithm, and it means a sufficiently large quantum computer could break modern cryptography such as RSA.

Basic Idea of Shor’s Algorithm

Shor's algorithm solves the following problem: given an integer $N$, find its prime factors. That is, for a given large number $N$, determine two prime factors $p_1$ and $p_2$ in polynomial time such that $p_1 p_2 = N$. Before introducing Shor's algorithm, let's review some basic number theory.

Using some number theory, the factorization problem can be reduced to finding the period of the function

$$f(x) = a^x \bmod N$$

where $a$ and $N$ are coprime; otherwise a factor can be obtained immediately by computing $\gcd(a, N)$. Let $r$ be the period of $f(x)$, i.e., the smallest positive integer such that $f(x) = f(x+r)$. Then

$$a^x \equiv a^{x+r} \pmod N, \quad \forall x$$

Setting $x = 0$ gives $a^r = 1 + qN$, where $q$ is an integer. If $r$ is even,

$$a^r - 1 = (a^{r/2} - 1)(a^{r/2} + 1) = qN$$

Because $r$ is the smallest period, $N$ does not divide $a^{r/2} - 1$. If $N$ also does not divide $a^{r/2} + 1$, then each of the two terms shares a non-trivial factor with $N$. Therefore $\gcd(a^{r/2} \pm 1, N)$ yields factors of $N$.

Therefore, the main idea of Shor's algorithm is to transform the problem of factoring large numbers into the problem of finding a function's period. Quantum computing can use the superposition principle to evaluate a function on many inputs at once. As a result, a quantum algorithm can find the period $r$ of $f(x)$ efficiently (for the principles and steps, see the period finding subroutine in this document). In general, we need to implement the function $f(|x\rangle) = a^{|x\rangle} \bmod N$ in the quantum circuit. We can construct a unitary matrix $U_{a,N}$ such that $U_{a,N}|x\rangle|y\rangle \rightarrow |x\rangle|y \oplus f(x)\rangle$. We then use the Quantum Fourier Transform to find the period $r$, which satisfies $a^r \equiv 1 \pmod N$.

Taking $N = 15$ as an example, the steps of Shor's algorithm are as follows:

  1. Randomly choose a number, such as $a = 2\ (< 15)$.

  2. Compute the greatest common divisor: $\gcd(a, N) = \gcd(2, 15) = 1$.

  3. Find the period $r$ of the function $f(x) = a^x \bmod N$, so that $f(x + r) = f(x)$.

  4. Running the quantum circuit, we get $r = 4$.

  5. Compute the greatest common divisor $\gcd(a^{r/2} + 1, N) = \gcd(5, 15) = 5$.

  6. Compute the greatest common divisor $\gcd(a^{r/2} - 1, N) = \gcd(3, 15) = 3$.

  7. Hence, the prime factors of $N = 15$ are 3 and 5, and the factorization is complete.
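For small $N$, the seven steps above can be reproduced entirely on a classical computer. The following is a minimal sketch. It assumes a brute-force search replaces the quantum period finding in steps 3 and 4, which takes time exponential in the number of bits of $N$. The names `find_period` and `factor_with_period` are hypothetical helpers, not MindSpore Quantum APIs.

```python
from math import gcd


def find_period(a, N):
    """Smallest r > 0 with a**r % N == 1, by brute force (a coprime to N).

    Hypothetical classical stand-in for the quantum period finding subroutine.
    """
    r, value = 1, a % N
    while value != 1:
        value = value * a % N
        r += 1
    return r


def factor_with_period(N, a):
    """Follow steps 2-7 for a given base a; returns (r, factors or None)."""
    assert gcd(a, N) == 1, "step 2: a must be coprime to N"
    r = find_period(a, N)            # steps 3-4 (brute force, not quantum)
    if r % 2 != 0:
        return r, None               # odd period: pick another a
    half = pow(a, r // 2, N)         # a^(r/2) mod N
    p_plus = gcd(half + 1, N)        # step 5
    p_minus = gcd(half - 1, N)       # step 6
    return r, (p_plus, p_minus)


print(factor_with_period(15, 2))  # (4, (5, 3))
```

Reducing $a^{r/2}$ modulo $N$ with `pow(a, r // 2, N)` does not change either gcd, and it keeps the numbers small.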

The quantum circuit of Shor's algorithm is shown below:

[Figure: Shor's algorithm circuit]

Implementing Shor’s Algorithm Using MindSpore Quantum

First, we need to import some required modules.

[1]:
#pylint: disable=W0611
import numpy as np
from fractions import Fraction
from mindquantum.core.gates import X, H, UnivMathGate, Measure
from mindquantum.core.circuit import Circuit, UN
from mindquantum.algorithm.library import qft
from mindquantum.simulator import Simulator

From the basic idea of Shor's algorithm, we can see that its main part is the period finding subroutine executed on a quantum computer. The most difficult part of that subroutine is the operator $U$, which converts the state $|x\rangle|y\rangle$ into $|x\rangle|y \oplus f(x)\rangle$. The quantum circuit for this operator is complicated. Therefore, we first compute $U$ on a classical computer and use it as an oracle. This lets the document demonstrate Shor's algorithm as a whole and intuitively.

Constructing the Oracle

The core quantum part of Shor's algorithm is period finding, and its key is an efficient implementation of the unitary operator $U_{a,N}$. This operator acts on two quantum registers (register 1 stores the exponent $x$, register 2 stores the auxiliary result $y$) and performs reversible modular exponentiation.

Specifically, the unitary operator $U_{a,N}$ we need to construct must implement the following transformation exactly, for all possible input basis states $|x\rangle|y\rangle$:

$$U_{a,N}|x\rangle|y\rangle = |x\rangle|y \oplus (a^x \bmod N)\rangle$$

Where:

  • $|x\rangle$ is a register of $q$ qubits, used to store the exponent from $0$ to $Q-1$ ($Q = 2^q \ge N$).

  • $|y\rangle$ is also a register of $q$ qubits, used to store intermediate results.

  • $a^x \bmod N$ is the result of the classical modular exponentiation.

  • $\oplus$ represents the bitwise XOR operation. XOR is chosen to make the corresponding unitary matrix (a permutation matrix) easy to construct and to ensure the operation is reversible.

Implementing the $|y \oplus f(x)\rangle$ transformation with gates requires more complex quantum circuits (such as quantum modular adders and multipliers). Instead, we can directly construct this $2^{2q} \times 2^{2q}$ unitary matrix $U_{a,N}$. This matrix is a permutation matrix that uniquely maps each input basis state $|x\rangle|y\rangle$ to the output basis state $|x\rangle|y \oplus (a^x \bmod N)\rangle$.

Implementation Steps

  1. Determine the number of qubits:

    • Target register (register 2): needs $q = \lceil \log_2 N \rceil$ qubits to store $a^x \bmod N$ (range $0$ to $N-1$).

    • Control register (register 1): stores the exponent $x$. For the Quantum Fourier Transform to yield the period $r$ with high probability, the number of qubits $t_q$ should satisfy $2^{t_q} \ge N^2$, i.e., $t_q \ge 2\log_2 N$. Hence, in theory we choose $t_q = 2q$, where $Q_{ctrl} = 2^{t_q}$ is the size of the control register's state space.

    • Simplified approach in this tutorial: for demonstration and to save resources, we use $t_q = q$ qubits for the control register (thus $Q_{ctrl} = 2^q$), giving a total of $n_{total} = q + t_q = 2q$ qubits. This works for $N = 15$ and $a = 2$. For larger $N$, however, it significantly reduces the probability of finding the correct period $r$: more trials may be needed, the result may be only a factor of $r$, or the run may fail.

    • Note on notation: in the remainder of this tutorial, the control register's qubit count and its state space size $Q$ always refer to the simplified $q$ and $Q = 2^q$.

  2. Calculate modular exponentiation values: for all $x \in [0, Q-1]$ (where $Q = 2^q$), compute $f(x) = a^x \bmod N$.

  3. Construct the unitary matrix $U$: create a $2^n \times 2^n$ zero matrix, where $n = 2q$. For each basis state $|x\rangle|y\rangle$ ($x, y \in [0, Q-1]$):

    • compute $\text{idx}_{in} = (x \ll q) + y$ and $\text{idx}_{out} = (x \ll q) + (y \oplus f(x))$;

    • set $U[\text{idx}_{out}, \text{idx}_{in}] = 1$. The resulting permutation matrix is unitary.

  4. Create the UnivMathGate: instantiate a UnivMathGate with the constructed matrix $U$ and apply it to register2 + register1 (so that $y$ corresponds to the lower bits).
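Steps 1 to 3 can be checked without allocating the dense matrix. The following is a minimal sketch that assumes the same bit layout as above ($x$ in the high bits, $y$ in the low bits). The helper `oracle_permutation` is hypothetical. It stores $U$ as a list with `perm[idx_in] == idx_out`, which is equivalent to the single 1 in column $\text{idx}_{in}$ of $U$. It uses `(N - 1).bit_length()`, which for $N \ge 2$ is the smallest $q$ with $2^q \ge N$.

```python
def oracle_permutation(N, a):
    """Index map idx_in -> idx_out of U_{a,N} on 2q qubits (x high, y low)."""
    q = (N - 1).bit_length()       # smallest q with 2**q >= N (N >= 2)
    Q = 2 ** q
    perm = [0] * (Q * Q)
    for x in range(Q):
        fx = pow(a, x, N)          # step 2: classical modular exponentiation
        for y in range(Q):
            idx_in = (x << q) + y              # |x>|y>
            idx_out = (x << q) + (y ^ fx)      # |x>|y XOR f(x)>
            perm[idx_in] = idx_out
    return q, perm


q, perm = oracle_permutation(15, 2)
# Every output index appears exactly once, so U is a permutation matrix.
assert sorted(perm) == list(range(len(perm)))
# XOR is its own inverse, so applying U twice restores every basis state.
assert all(perm[perm[i]] == i for i in range(len(perm)))
print(q, len(perm))  # 4 256
```

The self-inverse property ($U^2 = I$) is one way to see why the XOR construction is reversible.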

Example: N=15, a=2

We need $q = 4$ qubits, because $2^4 = 16 \ge 15$. The total number of qubits is $n = 2q = 8$, and the Hilbert space dimension is $2^8 = 256$.

We can obtain x and f(x):

[2]:
q = 4  # number of qubits
N = 15
a = 2
x = []
f = []
for i in range(2**q):
    x.append(i)
    f.append(a**i % N)
print('x: ', x)
print('f(x): ', f)
x:  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
f(x):  [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8]

We can observe that f(x) is indeed a periodic function.

Next, we construct the unitary gate corresponding to the modular exponentiation operation:

[3]:
def create_mod_exp_oracle(N, a, register1, register2):
    """
    Construct the gate for modular exponentiation U|x>|y> = |x>|y XOR (a^x mod N)>.

    Args:
        N (int): The number to be factored.
        a (int): The random base chosen in Shor's algorithm.
        register1 (list): Qubit indices for register 1 (requires q qubits where 2^q >= N).
        register2 (list): Qubit indices for register 2 (requires q qubits where 2^q >= N).

    Returns:
        UnivMathGate: The gate corresponding to the modular exponentiation U|x>|y> = |x>|y XOR (a^x mod N)>.
    """
    q = len(register1)
    n_qubits = 2 * q
    dim = 2**n_qubits
    Q = 2**q
    U_matrix = np.zeros((dim, dim), dtype=complex)

    # Precompute f(x) = a^x mod N
    fx_map = {}
    for x in range(Q):
        fx_map[x] = pow(a, x, N)

    # Construct the permutation matrix
    for x in range(Q):       # Iterate through states |x> of register 1
        fx = fx_map[x]
        for y in range(Q):   # Iterate through states |y> of register 2
            idx_in = (x << q) + y         # Index of |x>|y>
            idx_out = (x << q) + (y ^ fx) # Index of |x>|y XOR f(x)>
            U_matrix[idx_out, idx_in] = 1

    # Verify unitarity
    assert np.allclose(U_matrix @ U_matrix.conj().T, np.eye(dim))

    # Create the gate
    # Note: The order in .on() is register2 + register1 to match the matrix construction where y is the lower bits
    oracle_gate = UnivMathGate(f'ModExp({a},{N})', U_matrix).on(register2 + register1)
    return oracle_gate

The gate constructed by the create_mod_exp_oracle() function computes $a^x \bmod N$ for the state $|x\rangle$ in register 1 and XORs the result into register 2. When register 2 starts in $|0\rangle$, it ends up holding $a^x \bmod N$.

Little-Endian Convention and Qubit Allocation

MindQuantum uses the little-endian convention to represent quantum states. In this convention, a qubit's index corresponds to its significance in the numerical value: the qubit with the lowest index (index 0) is the Least Significant Bit (LSB). Therefore, an $n$-qubit state is typically written as $|q_{n-1} \ldots q_1 q_0\rangle$, where $q_0$ is the LSB.

To align MindQuantum's little-endian convention with the state notation $|x\rangle|y\rangle$ used earlier, we allocate qubits as follows:

  • Register 1 (logical value $x$): qubits $q$ to $2q-1$ (higher-indexed qubits).

  • Register 2 (logical value $y$): qubits $0$ to $q-1$ (lower-indexed qubits).

This means that although register 1 might be depicted above register 2 in schematic diagrams of Shor’s algorithm, the circuit diagram drawn by MindQuantum will show register 2 above register 1. Quantum gates and measurements will be adjusted accordingly.

With this allocation, the qubits for the logical value $y$ (the lower part of the numerical value) have lower indices in MindQuantum's state vector representation. The qubits for the logical value $x$ (the higher part) have higher indices. The advantage is that the state vector index of the whole quantum state maps directly to the integer value $x \cdot 2^q + y$, which simplifies the matrix construction for operations like UnivMathGate. This index-based allocation does not change the core logical function of the oracle, which is to modify the value of register 2 ($y$) based on the value of register 1 ($x$).
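The mapping between a basis-state index and individual qubits can be spelled out with bit operations. The following is a minimal sketch; `qubit_values` and `split_registers` are hypothetical helpers that follow the allocation above (qubit 0 is the LSB, register 1 occupies qubits $q$ to $2q-1$).

```python
def qubit_values(index, n_qubits):
    """Bit held by each qubit k of a little-endian basis index (qubit 0 = LSB)."""
    return [(index >> k) & 1 for k in range(n_qubits)]


def split_registers(index, q):
    """Split a 2q-qubit basis index into (x, y): x on qubits q..2q-1, y on 0..q-1."""
    return index >> q, index & ((1 << q) - 1)


q = 4
x, y = 3, 5
index = x * 2**q + y                  # x * 2^q + y
print(index, format(index, "08b"))    # 53 00110101  (qubit 7 written leftmost)
print(qubit_values(index, 2 * q))     # [1, 0, 1, 0, 1, 1, 0, 0]  (qubit 0 first)
print(split_registers(index, q))      # (3, 5)
```

The binary string is printed with the highest-indexed qubit on the left, which matches the ket notation $|q_{n-1} \ldots q_0\rangle$.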

Verify the Oracle

We can verify that the oracle works as expected by applying it to a specific initial state. For example, let's compute $U|8\rangle|0\rangle$. We expect to obtain $|8\rangle|0 \oplus (2^8 \bmod 15)\rangle$.

Since $2^8 = 256$ and $256 \bmod 15 = 1$, we have $0 \oplus 1 = 1$. Therefore, we expect the final state to be $|8\rangle|1\rangle$.

$|8\rangle$ corresponds to binary 1000, and $|1\rangle$ corresponds to binary 0001. So the final state $|8\rangle|1\rangle$ corresponds to binary 1000 0001 (with register 1 as the higher bits).
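The expected result can also be confirmed with plain integer arithmetic. The following is a minimal sketch with a hypothetical helper `apply_oracle`. It adds a second input, $|8\rangle|3\rangle$, which is not part of the tutorial. That input shows that register 2 is XORed rather than added to: $3 \oplus 1 = 2$, whereas addition would give 4.

```python
q, N, a = 4, 15, 2


def apply_oracle(x, y):
    """Basis-state action of U_{a,N}: |x>|y> -> |x>|y XOR (a^x mod N)>."""
    return x, y ^ pow(a, x, N)


for x, y in [(8, 0), (8, 3)]:
    x_out, y_out = apply_oracle(x, y)
    index = (x_out << q) + y_out        # x in the high bits, y in the low bits
    print(f"|{x}>|{y}> -> |{x_out}>|{y_out}> = |{index:08b}>")
# |8>|0> -> |8>|1> = |10000001>
# |8>|3> -> |8>|2> = |10000010>
```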

[4]:
#pylint: disable=W0104
register1 = range(4, 8)
register2 = range(4)
circuit = Circuit(X.on(7))  # Create circuit, initialize state to |1000>|0000>, i.e., x=8, |8>|0>
circuit += create_mod_exp_oracle(15, 2, list(register1), list(register2))  # Apply the oracle operator
circuit.svg() # Print the circuit diagram
[4]:
[Output: circuit diagram of the oracle verification circuit]
[5]:
print(circuit.get_qs('mqvector', ket=True))  # Print the final state
1¦10000001⟩

The result in register 1 is 1000, and the result in register 2 is 0001. We previously calculated $f(8) = 2^8 \bmod 15 = 1$, so the output is correct.

Next, we need to implement the period finding algorithm.

Period Finding Subroutine

  1. In register 1, we need $q > \log_2 N$ qubits to record the binary representation of the variable $x \in [0, N-1]$. We also need $q$ qubits in register 2 to record $f(x) = a^x \bmod N$, $x \in [0, N-1]$, in binary form. Registers 1 and 2 can then each record the integers in $[0, Q-1]$, where $Q = 2^q > N$.

  2. Apply the Hadamard gate to every qubit in register 1, so that register 1 is in a uniform superposition of all integers in $[0, Q-1]$:

    $$|\psi\rangle = \frac{1}{\sqrt{Q}} \sum_{x=0}^{Q-1} |x\rangle$$
  3. Apply the function $f(x) = a^x \bmod N$ to the state $|\psi\rangle$ stored in register 1, and store the result in register 2. The previously constructed oracle performs this step. Because it acts directly on the superposition $|\psi\rangle$, it completes in a single step, which illustrates quantum parallelism. The circuit is now in an entangled state, which can be expressed as

    $$\frac{1}{\sqrt{Q}} \sum_{x=0}^{Q-1} |x\rangle|f(x)\rangle = \frac{1}{\sqrt{Q}} \sum_{i=0}^{r-1} \left(|i\rangle + |i+r\rangle + |i+2r\rangle + \cdots\right)|f(i)\rangle$$
  4. Perform an inverse Quantum Fourier Transform (iQFT) on register 1. This transform uses powers of the $Q$-th root of unity $\omega = e^{2\pi i/Q}$ and spreads the amplitude of any given state $|x\rangle$ evenly over the $Q$ states $|y\rangle$. As shown in step 3, the states $|i\rangle$, $|i+r\rangle$, and so on in register 1 are all entangled with the same state $|f(i)\rangle$ in register 2. Because of quantum interference, a state $|y\rangle$ is more likely to be measured when the phase factor $e^{2\pi i r y/Q}$ is closer to 1 (i.e., points toward the positive real axis). In other words, the measured $y$ most likely makes $ry/Q$ close to an integer $c$. For a more detailed mathematical description, see https://en.wikipedia.org/wiki/Shor%27s_algorithm.

  5. Measure register 1 to get a binary string, and convert it to the decimal number $y$. At this point, $y/Q \approx c/r$, where $c$ is an unknown integer. Use the continued fraction algorithm to find the irreducible fraction that approximates $y/Q$ with denominator no larger than $N$. Its denominator is the period candidate $r$. However, another fraction might be closer to $y/Q$, or $c$ and $r$ might share a common factor, in which case the candidate is only a factor of the true period. In such cases the calculation fails, and we need to repeat the process.
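The continued fraction step can be sketched in pure Python. The following is a minimal illustration of the textbook convergent approach; `continued_fraction`, `convergents`, and `period_candidate` are hypothetical helpers. The tutorial's own `period_finder` instead calls `Fraction.limit_denominator`, which returns the closest fraction with a bounded denominator. For the four outcomes below, both approaches give the same denominators.

```python
from fractions import Fraction


def continued_fraction(num, den):
    """Terms [a0; a1, a2, ...] of num/den via the Euclidean algorithm."""
    terms = []
    while den:
        whole, rem = divmod(num, den)
        terms.append(whole)
        num, den = den, rem
    return terms


def convergents(terms):
    """Successive convergents h_k / k_k of a continued fraction."""
    h_prev, h = 1, terms[0]
    k_prev, k = 0, 1
    yield Fraction(h, k)
    for term in terms[1:]:
        h_prev, h = h, term * h + h_prev
        k_prev, k = k, term * k + k_prev
        yield Fraction(h, k)


def period_candidate(y, Q, N):
    """Denominator of the last convergent of y/Q whose denominator is <= N."""
    best = Fraction(0, 1)
    for frac in convergents(continued_fraction(y, Q)):
        if frac.denominator > N:
            break
        best = frac
    return best.denominator


N, a, Q = 15, 2, 16
for y in (0, 4, 8, 12):
    r = period_candidate(y, Q, N)
    print(y, r, pow(a, r, N) == 1)
# 0 1 False    (c = 0 carries no information)
# 4 4 True
# 8 2 False    (2/4 reduces to 1/2, so r is only a factor of the period)
# 12 4 True
```

The check `pow(a, r, N) == 1` is the same test `period_finder` uses to reject wrong candidates.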

Let's take $N = 15$, $a = 2$ again. In "Constructing the Oracle" we calculated each $f(x)$, from which we can see directly that the period is 4. Now we build the corresponding period finding subroutine and run 100 simulations to see what we get.

[6]:
#pylint: disable=W0104
circuit = Circuit() # Create a quantum circuit
register1 = range(4, 8) # Set qubits 4-7 to register 1
register2 = range(4)    # Set qubits 0-3 to register 2

circuit += UN(H, register1) # Apply H gate to all bits in register 1

# Perform the modular exponentiation operation using the Oracle
# U|x>|y> -> |x>|y XOR a^x mod N>
circuit += create_mod_exp_oracle(15, 2, list(register1), list(register2))

# Perform the inverse Quantum Fourier Transform on register 1.
# Note the qubit order for QFT: [::-1] reverses the register order for correct transformation.
circuit += qft(register1[::-1]).hermitian()
circuit += UN(Measure(), register1) # Measure register 1

circuit.svg() # Draw a circuit diagram
[6]:
[Output: period finding circuit diagram]

From the circuit diagram, we can see that the period finding circuit consists of four parts: superposition generation → function evaluation → inverse Fourier transform → measurement.

Next, run the circuit 100 times and observe the measurement results.

[7]:
# pylint: disable=W0104
sim = Simulator('mqvector', circuit.n_qubits) # Create a quantum circuit simulator

# Simulate the circuit 100 times, print the measurement results, set the random seed to a random integer within 100
result = sim.sampling(circuit, shots=100, seed=np.random.randint(100))

result.svg()
[7]:
[Output: histogram of the 100 sampling results]

The statistics show that only 4 states of register 1 are measured: $y \in \{0, 4, 8, 12\}$. With $Q = 16$ and $r = 4$, the phase factor $e^{2\pi i r y/Q}$ is exactly 1 for these four values, while the amplitudes of all other states cancel to zero through quantum interference. Substituting the measurement results into $y/Q \approx c/r$ confirms the relation. $y = 4$ and $y = 12$ give $1/4$ and $3/4$, i.e., $r = 4$. $y = 8$ gives $2/4 = 1/2$, i.e., only the factor $r = 2$. $y = 0$ gives no information. So we have about a 50% probability of getting the correct period $r$, about a 25% probability of getting a factor of $r$, and about a 25% probability of getting the 0 state. The latter two cases require recalculation.
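These probabilities can be reproduced exactly by brute force for tiny cases. After the oracle, the state is $\frac{1}{\sqrt{Q}}\sum_x |x\rangle|f(x)\rangle$. The inverse QFT maps $|x\rangle$ to $\frac{1}{\sqrt{Q}}\sum_y e^{-2\pi i x y/Q}|y\rangle$, so the amplitude of $|y\rangle|v\rangle$ is $\frac{1}{Q}\sum_{x:\,f(x)=v} e^{-2\pi i x y/Q}$. The sign of the exponent does not affect the probabilities. The following is a minimal classical sketch; `measurement_distribution` is a hypothetical helper, and its cost grows exponentially with $q$.

```python
import cmath
from collections import defaultdict


def measurement_distribution(N, a, q):
    """Exact probability of each outcome y on register 1 after oracle + iQFT."""
    Q = 2 ** q
    # Group the x values by f(x) = a^x mod N (the value held in register 2).
    groups = defaultdict(list)
    for x in range(Q):
        groups[pow(a, x, N)].append(x)
    probs = []
    for y in range(Q):
        total = 0.0
        for xs in groups.values():
            # Amplitude of |y>|v>: (1/Q) * sum over x with f(x) = v
            amp = sum(cmath.exp(-2j * cmath.pi * x * y / Q) for x in xs) / Q
            total += abs(amp) ** 2
        probs.append(total)
    return probs


probs = measurement_distribution(15, 2, 4)
for y, p in enumerate(probs):
    if p > 1e-9:
        print(y, round(p, 3))
# 0 0.25
# 4 0.25
# 8 0.25
# 12 0.25
```

With 100 shots, the sampled histogram should fluctuate around these exact values.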

Next we are going to construct a general period-finding algorithm.

[8]:
def period_finder(N, a, q):
    circuit = Circuit()  # Create a quantum circuit
    register1 = range(q, 2 * q)  # Set qubits q to 2q-1 for register 1
    register2 = range(q)  # Set qubits 0 to q-1 for register 2

    circuit += UN(H, register1)  # Apply H gate to all qubits in register 1

    # Apply the modular exponentiation Oracle as one big U gate
    circuit += create_mod_exp_oracle(N, a, list(register1), list(register2))

    circuit += qft(register1[::-1]).hermitian()  # Perform inverse QFT on register 1 (note reversed order)
    circuit += UN(Measure(), register1)  # Measure register 1

    sim = Simulator('mqvector', circuit.n_qubits)  # Create a quantum circuit simulator

    # Simulate the circuit once, collect the measurement result, random seed in [0,100)
    result = sim.sampling(circuit, seed=np.random.randint(100), shots=1)

    # result.data is a dict where key is measured binary string, value is count (1)
    result = list(result.data.keys())[0]  # Get the measured binary string
    result = int(result, 2)  # Convert the result from binary to decimal

    # Use continued fraction to approximate result/2**q with denominator <= N
    eigenphase = float(result / 2**q)
    f = Fraction.from_float(eigenphase).limit_denominator(N)
    r = f.denominator  # The denominator is the period candidate

    # Verify if r is the actual period
    if pow(a, r, N) == 1:
        return r
    return None

Classical Computer Part

The classical computer part is responsible for transforming the factorization problem into the problem of finding function period. The specific steps are as follows:

  1. Randomly pick an integer $a$ less than $N$, and use the gcd algorithm to check whether $a$ and $N$ are coprime. If $a$ and $N$ share a common factor, we directly obtain a factor of $N$ and output the result.

  2. Determine the number of qubits $q$ required to store the binary representation of $N$ (such that $2^q \ge N$).

  3. Use the period finding algorithm (quantum part) to get the period $r$ of the function $f(x) = a^x \bmod N$.

  4. Determine whether r is an even number. If not, go back to step 1.

  5. Calculate $a^{r/2} + 1$ and $a^{r/2} - 1$, and compute the greatest common divisor (gcd) of each with $N$. One of these gcds may be a non-trivial factor of $N$. However, $a^{r/2} + 1$ may be divisible by $N$ (i.e., $a^{r/2} \equiv -1 \pmod N$), in which case the gcds are only $1$ or $N$. If a non-trivial factor is found, output the factors; otherwise, go back to step 1.
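Steps 4 and 5 fail only for some choices of $a$. The following is a minimal sketch for small $N$. It assumes classical brute-force order finding in place of the quantum subroutine, and `order` and `classify_bases` are hypothetical helpers. It lists which coprime bases lead to a non-trivial factor. For $N = 15$, the only failing base is $a = 14$: there $r = 2$ and $a^{r/2} = 14 \equiv -1 \pmod{15}$, so $\gcd(a^{r/2} + 1, N) = N$.

```python
from math import gcd


def order(a, N):
    """Multiplicative order of a modulo N, by brute force (a must be coprime to N)."""
    r, value = 1, a % N
    while value != 1:
        value = value * a % N
        r += 1
    return r


def classify_bases(N):
    """Split the coprime bases a in [2, N-1] by whether steps 4-5 yield a factor."""
    good, bad = [], []
    for a in range(2, N):
        if gcd(a, N) != 1:
            continue  # step 1 already returns a factor for these bases
        r = order(a, N)
        if r % 2 != 0:
            bad.append(a)  # step 4 fails: odd period
            continue
        half = pow(a, r // 2, N)
        # Keep only the non-trivial divisors produced by step 5
        factors = {gcd(half + 1, N), gcd(half - 1, N)} - {1, N}
        if factors:
            good.append(a)
        else:
            bad.append(a)  # step 5 fails: a^(r/2) = -1 (mod N)
    return good, bad


good, bad = classify_bases(15)
print("succeed:", good)  # succeed: [2, 4, 7, 8, 11, 13]
print("fail:", bad)      # fail: [14]
```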

[9]:
import math


def shor(N):
    while True:
        a = int(np.random.randint(N - 2)) + 2  # Generate a random integer a in [2, N-1]
        b = math.gcd(a, N)  # Compute gcd(a, N)
        if b != 1:
            return b, N // b  # a and N share the non-trivial factor b, so N is already factored

        # Determine the number of bits q such that 2**q >= N
        q = 0
        while 2**q < N:
            q += 1

        r = period_finder(N, a, q)  # Get the period r

        # If r was not found or is odd, retry with a new a
        if r is None or r % 2 != 0:
            continue

        # Compute gcd(a**(r/2) + 1, N) and gcd(a**(r/2) - 1, N).
        # pow(a, r // 2, N) reduces modulo N, which avoids integer overflow.
        half = pow(a, r // 2, N)
        for c in (math.gcd(half + 1, N), math.gcd(half - 1, N)):
            if 1 < c < N:
                return c, N // c  # Non-trivial factor found
        # Only trivial factors (e.g. a**(r/2) = -1 mod N): retry with a new a

Because we construct the oracle directly as a huge unitary matrix gate, simulation time increases significantly. For $N > 55$, it may take a long time to get a result. In addition, because of the simplified control register size mentioned earlier, the probability of finding the correct period $r$ decreases for larger values of $N$.
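The size of the dense oracle can be estimated with simple arithmetic. The following is a minimal sketch that assumes the matrix is stored as dense complex128 entries, as `np.zeros((dim, dim), dtype=complex)` does in create_mod_exp_oracle. It ignores the simulator's own memory and the extra matrices allocated by the unitarity check. `oracle_matrix_bytes` is a hypothetical helper.

```python
def oracle_matrix_bytes(N, bytes_per_entry=16):
    """Memory for a dense 2^(2q) x 2^(2q) oracle matrix.

    Assumes complex128 entries (16 bytes), the NumPy default for dtype=complex.
    """
    q = (N - 1).bit_length()  # smallest q with 2**q >= N (for N >= 2)
    dim = 2 ** (2 * q)        # 2q qubits in total
    return dim * dim * bytes_per_entry


for N in (15, 35, 55, 65):
    print(N, oracle_matrix_bytes(N) // 2**20, "MiB")
# 15 1 MiB
# 35 256 MiB
# 55 256 MiB
# 65 4096 MiB
```

The matrix has $2^{4q}$ entries, so each additional qubit in $q$ multiplies its memory by 16.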

Finally, let's try to factor $N = 35$ using the Shor's algorithm implementation we have written.

[10]:
N = 35
print("Factoring N = p * q =", N)

p, q = shor(N)

print("p =", p)
print("q =", q)

Factoring N = p * q = 35
p = 5
q = 7

The results show that we successfully factored 35 into the two prime factors 5 and 7.

We have now successfully implemented Shor's algorithm using MindSpore Quantum.

[11]:
from mindquantum.utils.show_info import InfoTable

InfoTable('mindquantum', 'scipy', 'numpy')
[11]:
Software        Version
mindquantum     0.10.0
scipy           1.15.2
numpy           1.26.4

System          Info
Python          3.10.16
OS              Darwin arm64
Memory          17.18 GB
CPU Max Thread  10
Date            Fri May 16 19:22:48 2025