mindspore.dataset.text.Truncate

查看源文件
class mindspore.dataset.text.Truncate(max_seq_len)[源代码]

截断输入序列,使其不超过最大长度。

参数:
  • max_seq_len (int) - 最大截断长度。

异常:
  • TypeError - 如果 max_seq_len 的类型不是int。

  • ValueError - 如果 max_seq_len 的值小于或等于0。

  • RuntimeError - 如果输入张量的数据类型不是bool、int、float、double或者str。

支持平台:

CPU

样例:

>>> import mindspore.dataset as ds
>>> import mindspore.dataset.text as text
>>>
>>> # Use the transform in dataset pipeline mode
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=[['a', 'b', 'c', 'd', 'e']], column_names=["text"],
...                                              shuffle=False)
>>> # Data before
>>> # |           col1            |
>>> # +---------------------------+
>>> # | ['a', 'b', 'c', 'd', 'e'] |
>>> # +---------------------------+
>>> truncate = text.Truncate(4)
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=truncate, input_columns=["text"])
>>> for item in numpy_slices_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
...     print(item["text"])
['a' 'b' 'c' 'd']
>>> # Data after
>>> # |          col1          |
>>> # +------------------------+
>>> # |  ['a', 'b', 'c', 'd']  |
>>> # +------------------------+
>>>
>>> # Use the transform in eager mode
>>> data = ["happy", "birthday", "to", "you"]
>>> output = text.Truncate(2)(data)
>>> print(output)
['happy' 'birthday']
教程样例: