mindspore.dataset.text.TruncateSequencePair

查看源文件
class mindspore.dataset.text.TruncateSequencePair(max_length)[源代码]

对两列 1-D 字符串输入进行截断,使其总长度小于指定长度。

参数:
  • max_length (int) - 字符串最大输出总长。当其大于或等于两列输入字符串总长时,不进行截断; 否则,优先截取两列输入中的较长者,直至其总长等于该值。

异常:
  • TypeError - 当 max_length 不为int类型。

支持平台:

CPU

样例:

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>>
>>> # Use the transform in dataset pipeline mode
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=([[1, 2, 3]], [[4, 5]]), column_names=["col1", "col2"])
>>> # Data before
>>> # |   col1    |   col2    |
>>> # +-----------+-----------|
>>> # | [1, 2, 3] |  [4, 5]   |
>>> # +-----------+-----------+
>>> truncate_sequence_pair_op = text.TruncateSequencePair(max_length=4)
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=truncate_sequence_pair_op,
...                                                 input_columns=["col1", "col2"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col1"], item["col2"])
[1 2] [4 5]
>>> # Data after
>>> # |   col1    |   col2    |
>>> # +-----------+-----------+
>>> # |  [1, 2]   |  [4, 5]   |
>>> # +-----------+-----------+
>>>
>>> # Use the transform in eager mode
>>> data = [["1", "2", "3"], ["4", "5"]]
>>> output = text.TruncateSequencePair(4)(*data)
>>> print(output)
(array(['1', '2'], dtype='<U1'), array(['4', '5'], dtype='<U1'))
教程样例: