量子神经网络在自然语言处理中的应用

下载Notebook下载样例代码查看源文件 在线运行

概述

在自然语言处理过程中,词嵌入(Word embedding)是其中的重要步骤,它是一个将高维度空间的词向量嵌入到一个维数更低的连续向量空间的过程。当给予神经网络的语料信息不断增加时,网络的训练过程将越来越困难。利用量子力学的态叠加和纠缠等特性,我们可以利用量子神经网络来处理这些经典语料信息,加入其训练过程,并提高收敛精度。下面,我们将简单地搭建一个量子经典混合神经网络来完成一个词嵌入任务。 提示:由于HiQ量子云平台JupyterLab环境中svg图片暂时无法显示,请开发者自行打印量子线路。

环境准备

设置系统所使用的线程数,当您的服务器CPU较多时,如果不设置,系统默认调用所有CPU,反而会导致模拟变慢甚至卡住。

[ ]:
import os

os.environ['OMP_NUM_THREADS'] = '1'

导入本教程所依赖模块

[1]:
import numpy as np
import time
import mindspore as ms
import mindspore.ops as ops
import mindspore.dataset as ds
from mindspore import nn
from mindquantum.framework import MQLayer
from mindquantum.core.gates import RX, RY, X, H
from mindquantum.core.circuit import Circuit, UN
from mindquantum.core.operators import Hamiltonian, QubitOperator

本教程实现的是一个CBOW模型,即利用某个词所处的环境来预测该词。例如对于“I love natural language processing”这句话,我们可以将其切分为5个词,[“I”, “love”, “natural”, “language”, “processing”],在所选窗口为2时,我们要处理的问题是利用[“I”, “love”, “language”, “processing”]来预测出目标词汇“natural”。这里我们以窗口为2为例,搭建如下的量子神经网络,来完成词嵌入任务。

quantum word embedding

这里,编码线路会将“I”、“love”、“language”和“processing”的编码信息编码到量子线路中,待训练的量子线路由四个Ansatz线路构成,最后我们在量子线路末端对量子比特做\(\text{Z}\)基矢上的测量,具体所需测量的比特的个数由所需嵌入空间的维数确定。

数据预处理

我们对所需要处理的语句进行处理,生成关于该句子的词典,并根据窗口大小来生成样本点。

[2]:
def GenerateWordDictAndSample(corpus, window=2):
    all_words = corpus.split()
    word_set = list(set(all_words))
    word_set.sort()
    word_dict = {w: i for i, w in enumerate(word_set)}
    sampling = []
    for index, _ in enumerate(all_words[window:-window]):
        around = []
        for i in range(index, index + 2*window + 1):
            if i != index + window:
                around.append(all_words[i])
        sampling.append([around, all_words[index + window]])
    return word_dict, sampling
[3]:
word_dict, sample = GenerateWordDictAndSample("I love natural language processing")
print(word_dict)
print('word dict size: ', len(word_dict))
print('samples: ', sample)
print('number of samples: ', len(sample))
{'I': 0, 'language': 1, 'love': 2, 'natural': 3, 'processing': 4}
word dict size:  5
samples:  [[['I', 'love', 'language', 'processing'], 'natural']]
number of samples:  1

根据如上信息,我们得到该句子的词典大小为5,能够产生一个样本点。

编码线路

为了简单起见,我们使用的编码线路由\(\text{RX}\)旋转门构成,结构如下。

encoder circuit

我们对每个量子门都作用一个\(\text{RX}\)旋转门。

[4]:
def GenerateEncoderCircuit(n_qubits, prefix=''):
    if prefix and prefix[-1] != '_':
        prefix += '_'
    circ = Circuit()
    for i in range(n_qubits):
        circ += RX(prefix + str(i)).on(i)
    return circ.as_encoder()
[5]:
GenerateEncoderCircuit(3, prefix='e').svg()
[5]:
_images/qnn_for_nlp_10_0.svg

我们通常用\(\left|0\right>\)\(\left|1\right>\)来标记二能级量子比特的两个状态,由态叠加原理,量子比特还可以处于这两个状态的叠加态:

\[\left|\psi\right>=\alpha\left|0\right>+\beta\left|1\right>\]

对于\(n\)比特的量子态,其将处于\(2^n\)维的希尔伯特空间中。对于上面由5个词构成的词典,我们只需要\(\lceil \log_2 5 \rceil=3\)个量子比特即可完成编码,这也体现出量子计算的优越性。

例如对于上面词典中的“love”,其对应的标签为2,2的二进制表示为010,我们只需将编码线路中的e_0e_1e_2分别设为\(0\)\(\pi\)\(0\)即可。下面来验证一下。

[6]:
from mindquantum.simulator import Simulator

n_qubits = 3 # number of qubits of this quantum circuit
label = 2 # label need to encode
label_bin = bin(label)[-1: 1: -1].ljust(n_qubits, '0') # binary form of label
label_array = np.array([int(i)*np.pi for i in label_bin]).astype(np.float32) # parameter value of encoder
encoder = GenerateEncoderCircuit(n_qubits, prefix='e') # encoder circuit
encoder_params_names = encoder.params_name # parameter names of encoder

print("Label is: ", label)
print("Binary label is: ", label_bin)
print("Parameters of encoder is: \n", np.round(label_array, 5))
print("Encoder circuit is: \n", encoder)
print("Encoder parameter names are: \n", encoder_params_names)

# quantum state evolution operator
state = encoder.get_qs(pr=dict(zip(encoder_params_names, label_array)))
amp = np.round(np.abs(state)**2, 3)

print("Amplitude of quantum state is: \n", amp)
print("Label in quantum state is: ", np.argmax(amp))
Label is:  2
Binary label is:  010
Parameters of encoder is:
 [0.      3.14159 0.     ]
Encoder circuit is:
 q0: ──RX(e_0)──

q1: ──RX(e_1)──

q2: ──RX(e_2)──
Encoder parameter names are:
 ['e_0', 'e_1', 'e_2']
Amplitude of quantum state is:
 [0. 0. 1. 0. 0. 0. 0. 0.]
Label in quantum state is:  2

通过上面的验证,我们发现,对于标签为2的数据,最后得到量子态的振幅最大的位置也是2,因此得到的量子态正是对输入标签的编码。我们将对数据编码生成参数数值的过程总结成如下函数。

[7]:
def GenerateTrainData(sample, word_dict):
    n_qubits = np.int(np.ceil(np.log2(1 + max(word_dict.values()))))
    data_x = []
    data_y = []
    for around, center in sample:
        data_x.append([])
        for word in around:
            label = word_dict[word]
            label_bin = bin(label)[-1: 1: -1].ljust(n_qubits, '0')
            label_array = [int(i)*np.pi for i in label_bin]
            data_x[-1].extend(label_array)
        data_y.append(word_dict[center])
    return np.array(data_x).astype(np.float32), np.array(data_y).astype(np.int32)
[8]:
GenerateTrainData(sample, word_dict)
<ipython-input-7-4e49928cbb0a>:2: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  n_qubits = np.int(np.ceil(np.log2(1 + max(word_dict.values()))))
[8]:
(array([[0.       , 0.       , 0.       , 0.       , 3.1415927, 0.       ,
         3.1415927, 0.       , 0.       , 0.       , 0.       , 3.1415927]],
       dtype=float32),
 array([3], dtype=int32))

根据上面的结果,我们将4个输入的词编码的信息合并为一个更长向量,便于后续神经网络调用。

Ansatz线路

Ansatz线路的选择多种多样,我们选择如下的量子线路作为Ansatz线路,它的一个单元由一层\(\text{RY}\)门和一层\(\text{CNOT}\)门构成,对此单元重复\(p\)次构成整个Ansatz线路。

ansatz circuit

定义如下函数生成Ansatz线路。

[9]:
def GenerateAnsatzCircuit(n_qubits, layers, prefix=''):
    if prefix and prefix[-1] != '_':
        prefix += '_'
    circ = Circuit()
    for l in range(layers):
        for i in range(n_qubits):
            circ += RY(prefix + str(l) + '_' + str(i)).on(i)
        for i in range(l % 2, n_qubits, 2):
            if i < n_qubits and i + 1 < n_qubits:
                circ += X.on(i + 1, i)
    return circ
[10]:
GenerateAnsatzCircuit(5, 2, 'a').svg()
[10]:
_images/qnn_for_nlp_18_0.svg

测量

我们把对不同比特位上的测量结果作为降维后的数据。具体过程与比特编码类似,例如当我们想将词向量降维为5维向量时,对于第3维的数据可以如下产生:

  • 3对应的二进制为00011

  • 测量量子线路末态对\(Z_0Z_1\)哈密顿量的期望值。

下面函数将给出产生各个维度上数据所需的哈密顿量(hams),其中n_qubits表示线路的比特数,dims表示词嵌入的维度:

[11]:
def GenerateEmbeddingHamiltonian(dims, n_qubits):
    hams = []
    for i in range(dims):
        s = ''
        for j, k in enumerate(bin(i + 1)[-1:1:-1]):
            if k == '1':
                s = s + 'Z' + str(j) + ' '
        hams.append(Hamiltonian(QubitOperator(s)))
    return hams
[12]:
GenerateEmbeddingHamiltonian(5, 5)
[12]:
[1 [Z0] , 1 [Z1] , 1 [Z0 Z1] , 1 [Z2] , 1 [Z0 Z2] ]

量子版词向量嵌入层

量子版词向量嵌入层结合前面的编码量子线路和待训练量子线路,以及测量哈密顿量,将num_embedding个词嵌入为embedding_dim维的词向量。这里我们还在量子线路的最开始加上了Hadamard门,将初态制备为均匀叠加态,用以提高量子神经网络的表达能力。

下面,我们定义量子嵌入层,它将返回一个量子线路模拟算子。

[13]:
def QEmbedding(num_embedding, embedding_dim, window, layers, n_threads):
    n_qubits = int(np.ceil(np.log2(num_embedding)))
    hams = GenerateEmbeddingHamiltonian(embedding_dim, n_qubits)
    circ = Circuit()
    circ = UN(H, n_qubits)
    encoder_param_name = []
    ansatz_param_name = []
    for w in range(2 * window):
        encoder = GenerateEncoderCircuit(n_qubits, 'Encoder_' + str(w))
        ansatz = GenerateAnsatzCircuit(n_qubits, layers, 'Ansatz_' + str(w))
        encoder.no_grad()
        circ += encoder
        circ += ansatz
        encoder_param_name.extend(encoder.params_name)
        ansatz_param_name.extend(ansatz.params_name)
    grad_ops = Simulator('projectq', circ.n_qubits).get_expectation_with_grad(hams,
                                                                              circ,
                                                                              parallel_worker=n_threads)
    return MQLayer(grad_ops)

整个训练模型跟经典网络类似,由一个嵌入层和两个全连通层构成,然而此处的嵌入层是由量子神经网络构成。下面定义量子神经网络CBOW。

[14]:
class CBOW(nn.Cell):
    def __init__(self, num_embedding, embedding_dim, window, layers, n_threads,
                 hidden_dim):
        super(CBOW, self).__init__()
        self.embedding = QEmbedding(num_embedding, embedding_dim, window,
                                    layers, n_threads)
        self.dense1 = nn.Dense(embedding_dim, hidden_dim)
        self.dense2 = nn.Dense(hidden_dim, num_embedding)
        self.relu = ops.ReLU()

    def construct(self, x):
        embed = self.embedding(x)
        out = self.dense1(embed)
        out = self.relu(out)
        out = self.dense2(out)
        return out

下面我们对一个稍长的句子来进行训练。首先定义LossMonitorWithCollection用于监督收敛过程,并搜集收敛过程的损失。

[15]:
class LossMonitorWithCollection(ms.train.callback.LossMonitor):
    def __init__(self, per_print_times=1):
        super(LossMonitorWithCollection, self).__init__(per_print_times)
        self.loss = []

    def begin(self, run_context):
        self.begin_time = time.time()

    def end(self, run_context):
        self.end_time = time.time()
        print('Total time used: {}'.format(self.end_time - self.begin_time))

    def epoch_begin(self, run_context):
        self.epoch_begin_time = time.time()

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        self.epoch_end_time = time.time()
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print('')

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs

        if isinstance(loss, (tuple, list)):
            if isinstance(loss[0], ms.Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
                loss = loss[0]

        if isinstance(loss, ms.Tensor) and isinstance(loss.asnumpy(), np.ndarray):
            loss = np.mean(loss.asnumpy())

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1

        if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
            raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
                cb_params.cur_epoch_num, cur_step_in_epoch))
        self.loss.append(loss)
        if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
            print("\repoch: %+3s step: %+3s time: %5.5s, loss is %5.5s" % (cb_params.cur_epoch_num, cur_step_in_epoch, time.time() - self.epoch_begin_time, loss), flush=True, end='')

接下来,利用量子版本的CBOW来对一个长句进行词嵌入。运行之前请在终端运行export OMP_NUM_THREADS=4,将量子模拟器的线程数设置为4个,当所需模拟的量子系统比特数较多时,可设置更多的线程数来提高模拟效率。

[16]:
import mindspore as ms
ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")
corpus = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells."""

ms.set_seed(42)
window_size = 2
embedding_dim = 10
hidden_dim = 128
word_dict, sample = GenerateWordDictAndSample(corpus, window=window_size)
train_x, train_y = GenerateTrainData(sample, word_dict)

train_loader = ds.NumpySlicesDataset({
    "around": train_x,
    "center": train_y
}, shuffle=False).batch(3)
net = CBOW(len(word_dict), embedding_dim, window_size, 3, 4, hidden_dim)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
loss_monitor = LossMonitorWithCollection(500)
model = ms.Model(net, net_loss, net_opt)
model.train(350, train_loader, callbacks=[loss_monitor], dataset_sink_mode=False)
<ipython-input-7-4e49928cbb0a>:2: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  n_qubits = np.int(np.ceil(np.log2(1 + max(word_dict.values()))))
epoch:  25 step:  20 time: 0.351, loss is 3.154
epoch:  50 step:  20 time: 0.362, loss is 3.023
epoch:  75 step:  20 time: 0.353, loss is 2.948
epoch: 100 step:  20 time: 0.389, loss is 2.299
epoch: 125 step:  20 time: 0.392, loss is 0.810
epoch: 150 step:  20 time: 0.389, loss is 0.464
epoch: 175 step:  20 time: 0.384, loss is 0.306
epoch: 200 step:  20 time: 0.383, loss is 0.217
epoch: 225 step:  20 time: 0.387, loss is 0.168
epoch: 250 step:  20 time: 0.382, loss is 0.143
epoch: 275 step:  20 time: 0.389, loss is 0.130
epoch: 300 step:  20 time: 0.386, loss is 0.122
epoch: 325 step:  20 time: 0.408, loss is 0.117
epoch: 350 step:  20 time: 0.492, loss is 0.102
Total time used: 138.5629165172577

打印收敛过程中的损失函数值:

[17]:
import matplotlib.pyplot as plt

plt.plot(loss_monitor.loss, '.')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.show()
_images/qnn_for_nlp_31_0.png

通过如下方法打印量子嵌入层的量子线路中的参数:

[18]:
net.embedding.weight.asnumpy()
[18]:
array([-1.06950325e-03, -1.62345007e-01,  6.51378045e-03,  3.30513604e-02,
        1.43976521e-03, -8.73360550e-05,  1.58920437e-02,  4.88108210e-02,
       -1.38961999e-02, -8.95568263e-03, -9.16828722e-05,  6.78092847e-03,
        9.64443013e-03,  6.65064156e-02, -2.27977871e-03, -2.90895114e-04,
        6.87254360e-03, -3.33692250e-03, -5.43189228e-01, -1.90237209e-01,
       -3.96547168e-02, -1.54710874e-01,  3.94615083e-04, -3.17311606e-05,
       -5.17031252e-01,  9.45210159e-01,  6.53367564e-02, -4.39741276e-02,
       -6.84748637e-03, -9.54589061e-03, -5.17159104e-01,  7.45301664e-01,
       -3.10309901e-04, -3.35418060e-02,  2.80578714e-03, -1.21473498e-03,
        2.32869145e-02, -2.02556834e-01, -9.99295652e-01, -2.33947067e-05,
        6.91292621e-03, -1.37111245e-04,  1.10169267e-02, -2.61709969e-02,
       -5.76490164e-01,  6.42279327e-01, -1.17960293e-02, -3.99340130e-03,
        9.62817296e-03, -2.04294510e-02,  9.17679537e-03,  6.43585920e-01,
        7.80070573e-03,  1.40992356e-02, -1.67036298e-04, -7.76478276e-03,
       -3.02837696e-02, -2.40557283e-01,  2.06130613e-02,  7.22330203e-03,
        4.16821009e-03,  2.04327740e-02,  1.80713329e-02, -1.01204574e-01,
        1.14764208e-02,  2.05871137e-03, -5.73002594e-03,  2.16162428e-01,
       -1.32567063e-02, -1.02419645e-01,  4.16066934e-04, -9.28813033e-03],
      dtype=float32)

经典版词向量嵌入层

这里我们利用经典的词向量嵌入层来搭建一个经典的CBOW神经网络,并与量子版本进行对比。

首先,搭建经典的CBOW神经网络,其中的参数跟量子版本的类似。

[19]:
class CBOWClassical(nn.Cell):
    def __init__(self, num_embedding, embedding_dim, window, hidden_dim):
        super(CBOWClassical, self).__init__()
        self.dim = 2 * window * embedding_dim
        self.embedding = nn.Embedding(num_embedding, embedding_dim, True)
        self.dense1 = nn.Dense(self.dim, hidden_dim)
        self.dense2 = nn.Dense(hidden_dim, num_embedding)
        self.relu = ops.ReLU()
        self.reshape = ops.Reshape()

    def construct(self, x):
        embed = self.embedding(x)
        embed = self.reshape(embed, (-1, self.dim))
        out = self.dense1(embed)
        out = self.relu(out)
        out = self.dense2(out)
        return out

生成适用于经典CBOW神经网络的数据集。

[20]:
train_x = []
train_y = []
for i in sample:
    around, center = i
    train_y.append(word_dict[center])
    train_x.append([])
    for j in around:
        train_x[-1].append(word_dict[j])
train_x = np.array(train_x).astype(np.int32)
train_y = np.array(train_y).astype(np.int32)
print("train_x shape: ", train_x.shape)
print("train_y shape: ", train_y.shape)
train_x shape:  (58, 4)
train_y shape:  (58,)

我们对经典CBOW网络进行训练。

[21]:
ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")

train_loader = ds.NumpySlicesDataset({
    "around": train_x,
    "center": train_y
}, shuffle=False).batch(3)
net = CBOWClassical(len(word_dict), embedding_dim, window_size, hidden_dim)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)
loss_monitor = LossMonitorWithCollection(500)
model = ms.Model(net, net_loss, net_opt)
model.train(350, train_loader, callbacks=[loss_monitor], dataset_sink_mode=False)
epoch:  25 step:  20 time: 0.023, loss is 3.155
epoch:  50 step:  20 time: 0.014, loss is 3.027
epoch:  75 step:  20 time: 0.022, loss is 3.010
epoch: 100 step:  20 time: 0.021, loss is 2.955
epoch: 125 step:  20 time: 0.021, loss is 0.630
epoch: 150 step:  20 time: 0.022, loss is 0.059
epoch: 175 step:  20 time: 0.023, loss is 0.008
epoch: 200 step:  20 time: 0.022, loss is 0.003
epoch: 225 step:  20 time: 0.023, loss is 0.001
epoch: 250 step:  20 time: 0.021, loss is 0.001
epoch: 275 step:  20 time: 0.021, loss is 0.000
epoch: 300 step:  20 time: 0.018, loss is 0.000
epoch: 325 step:  20 time: 0.022, loss is 0.000
epoch: 350 step:  20 time: 0.019, loss is 0.000
Total time used: 8.10720443725586

打印收敛过程中的损失函数值:

[22]:
import matplotlib.pyplot as plt

plt.plot(loss_monitor.loss, '.')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.show()
_images/qnn_for_nlp_41_0.png

由上可知,通过量子模拟得到的量子版词嵌入模型也能很好的完成嵌入任务。当数据集大到经典计算机算力难以承受时,量子计算机将能够轻松处理这类问题。

参考文献

[1] Tomas Mikolov, Kai Chen, Greg Corrado, Jeffrey Dean. Efficient Estimation of Word Representations in Vector Space