mindspore.dataset.text.TruncateSequencePair
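- class mindspore.dataset.text.TruncateSequencePair(max_length)
Truncates a pair of 1-D string inputs so that their combined length does not exceed the specified length.
- Parameters:
max_length (int) - Maximum total length of the two outputs. If it is greater than or equal to the combined length of the two inputs, nothing is truncated; otherwise, the longer of the two inputs is truncated first, repeatedly, until the combined length equals this value.
- Raises:
TypeError - If max_length is not of type int.
- Supported Platforms:
CPU
Examples: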
>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>>
>>> # Use the transform in dataset pipeline mode
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=([[1, 2, 3]], [[4, 5]]), column_names=["col1", "col2"])
>>> # Data before
>>> # |   col1    |   col2    |
>>> # +-----------+-----------+
>>> # | [1, 2, 3] |  [4, 5]   |
>>> # +-----------+-----------+
>>> truncate_sequence_pair_op = text.TruncateSequencePair(max_length=4)
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=truncate_sequence_pair_op,
...                                                 input_columns=["col1", "col2"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["col1"], item["col2"])
[1 2] [4 5]
>>> # Data after
>>> # |   col1    |   col2    |
>>> # +-----------+-----------+
>>> # |  [1, 2]   |  [4, 5]   |
>>> # +-----------+-----------+
>>>
>>> # Use the transform in eager mode
>>> data = [["1", "2", "3"], ["4", "5"]]
>>> output = text.TruncateSequencePair(4)(*data)
>>> print(output)
(array(['1', '2'], dtype='<U1'), array(['4', '5'], dtype='<U1'))
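The truncation rule described under max_length can be illustrated with a small pure-Python sketch. This is not MindSpore's implementation: `truncate_pair` is a hypothetical helper written for illustration. It assumes elements are dropped from the end of a sequence (consistent with the example output above), that the second sequence is shortened when both have equal length, and that a negative max_length is rejected; the last two points are assumptions, since this page does not state them.

```python
def truncate_pair(first, second, max_length):
    """Illustrative sketch of pairwise truncation (hypothetical helper, not the MindSpore API)."""
    if not isinstance(max_length, int):
        raise TypeError("max_length must be an int")
    if max_length < 0:
        # Assumption for this sketch only: negative limits are rejected.
        raise ValueError("max_length must be non-negative")

    len1, len2 = len(first), len(second)
    # Shorten the longer sequence one element at a time until the pair fits.
    while len1 + len2 > max_length:
        if len1 > len2:
            len1 -= 1
        else:
            # Assumption: on a tie, the second sequence is shortened.
            len2 -= 1
    # Keep the leading elements, i.e. drop from the end.
    return first[:len1], second[:len2]


print(truncate_pair([1, 2, 3], [4, 5], 4))  # ([1, 2], [4, 5])
```

With max_length=4, the three-element input loses its last element while the two-element input is unchanged, matching the pipeline example. When max_length is at least the combined length, both inputs are returned as they are.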
- Tutorial Examples: