Program Listing for File text.h

Return to documentation for file (include/text.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/dataset/constants.h"
#include "include/dataset/transforms.h"

namespace mindspore {
namespace dataset {

// Forward declarations. This header only refers to these types through std::shared_ptr,
// so their complete definitions are not required here.
class Vocab;
class SentencePieceVocab;
class TensorOperation;

// Transform operations for text
namespace text {

#ifndef _WIN32
/// \brief BasicTokenizer text transform (declared only when _WIN32 is not defined).
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class BasicTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor. All arguments have defaults; the definition is out of line.
  /// \param[in] lower_case Defaults to false. Presumably enables lower-casing of the text - confirm in the
  ///     implementation.
  /// \param[in] keep_whitespace Defaults to false. Presumably keeps whitespace tokens in the output - confirm.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNone.
  /// \param[in] preserve_unused_token Defaults to true. Exact meaning is defined by the implementation.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false,
                          const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
                          bool with_offsets = false);

  /// \brief Destructor.
  ~BasicTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief BertTokenizer text transform (declared only when _WIN32 is not defined).
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class BertTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string arguments.
  ///     Converts suffix_indicator and unknown_token with StringToChar and delegates to the
  ///     std::vector<char> overload below; all other arguments are forwarded unchanged.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] suffix_indicator Defaults to "##". Presumably the marker for sub-word continuation - confirm.
  /// \param[in] max_bytes_per_token Defaults to 100.
  /// \param[in] unknown_token Defaults to "[UNK]".
  /// \param[in] lower_case Defaults to false.
  /// \param[in] keep_whitespace Defaults to false.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNone.
  /// \param[in] preserve_unused_token Defaults to true.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = "##",
                         int32_t max_bytes_per_token = 100, const std::string &unknown_token = "[UNK]",
                         bool lower_case = false, bool keep_whitespace = false,
                         const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
                         bool with_offsets = false)
      : BertTokenizer(vocab, StringToChar(suffix_indicator), max_bytes_per_token, StringToChar(unknown_token),
                      lower_case, keep_whitespace, normalize_form, preserve_unused_token, with_offsets) {}
  /// \brief Constructor taking std::vector<char> in place of std::string; defined out of line.
  ///     It has no default arguments and is the target of the std::string overload above.
  explicit BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
                         int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case,
                         bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token,
                         bool with_offsets);

  /// \brief Destructor.
  ~BertTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief CaseFold text transform (declared only when _WIN32 is not defined).
/// \note Takes no configuration and declares no data members; the behavior is implemented out of line.
class CaseFold final : public TensorTransform {
 public:
  /// \brief Default constructor; defined out of line.
  CaseFold();

  /// \brief Destructor.
  ~CaseFold() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;
};
#endif

/// \brief JiebaTokenizer text transform.
/// \note The std::string members below are inline adapters: they convert their arguments to
///     std::vector<char> based types and forward to the private *Char members, which are defined out of line.
class JiebaTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string paths; converts both with StringToChar and delegates.
  /// \param[in] hmm_path Path argument forwarded to the implementation.
  /// \param[in] mp_path Path argument forwarded to the implementation.
  /// \param[in] mode Defaults to JiebaMode::kMix.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit JiebaTokenizer(const std::string &hmm_path, const std::string &mp_path,
                          const JiebaMode &mode = JiebaMode::kMix, bool with_offsets = false)
      : JiebaTokenizer(StringToChar(hmm_path), StringToChar(mp_path), mode, with_offsets) {}

  /// \brief Constructor taking std::vector<char> paths; defined out of line, no default arguments.
  explicit JiebaTokenizer(const std::vector<char> &hmm_path, const std::vector<char> &mp_path, const JiebaMode &mode,
                          bool with_offsets);

  /// \brief Destructor.
  ~JiebaTokenizer() = default;

  /// \brief Forward a single word and its frequency to AddWordChar.
  /// \param[in] word The word, converted with StringToChar before forwarding.
  /// \param[in] freq Defaults to 0; passed through unchanged.
  /// \return The Status returned by AddWordChar.
  Status AddWord(const std::string &word, int64_t freq = 0) {
    const auto word_chars = StringToChar(word);
    return AddWordChar(word_chars, freq);
  }

  /// \brief Forward a list of (word, frequency) pairs to AddDictChar.
  /// \param[in] user_dict Pairs converted with PairStringInt64ToPairCharInt64 before forwarding.
  /// \return The Status returned by AddDictChar.
  Status AddDict(const std::vector<std::pair<std::string, int64_t>> &user_dict) {
    const auto dict_chars = PairStringInt64ToPairCharInt64(user_dict);
    return AddDictChar(dict_chars);
  }

  /// \brief Forward a file path to the path-based AddDictChar overload.
  /// \param[in] file_path The path, converted with StringToChar before forwarding.
  /// \return The Status returned by AddDictChar.
  Status AddDict(const std::string &file_path) {
    const auto path_chars = StringToChar(file_path);
    return AddDictChar(path_chars);
  }

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  /// \brief Out-of-line helper taking a file path and an output list of (word, frequency) pairs.
  ///     Presumably parses the file into user_dict - confirm in the implementation.
  Status ParserFile(const std::string &file_path, std::vector<std::pair<std::string, int64_t>> *const user_dict);

  /// \brief std::vector<char> counterpart of AddWord; defined out of line.
  Status AddWordChar(const std::vector<char> &word, int64_t freq = 0);

  /// \brief std::vector<char> counterpart of AddDict(user_dict); defined out of line.
  Status AddDictChar(const std::vector<std::pair<std::vector<char>, int64_t>> &user_dict);

  /// \brief std::vector<char> counterpart of AddDict(file_path); defined out of line.
  Status AddDictChar(const std::vector<char> &file_path);

  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Lookup text transform.
/// \note The lookup logic is implemented out of line and is not visible in this header.
class Lookup final : public TensorTransform {
 public:
  /// \brief Constructor taking an optional std::string token.
  ///     Converts unknown_token to std::optional<std::vector<char>> and delegates to the overload below.
  ///     Delegation is used instead of re-constructing the object with placement new inside the body, so the
  ///     base class and data_ are constructed exactly once.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] unknown_token Defaults to an empty optional. Presumably the token substituted for words that
  ///     are missing from vocab - confirm in the implementation.
  /// \param[in] data_type Defaults to mindspore::DataType::kNumberTypeInt32.
  explicit Lookup(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token = {},
                  mindspore::DataType data_type = mindspore::DataType::kNumberTypeInt32)
      : Lookup(vocab, OptionalStringToOptionalChar(unknown_token), data_type) {}

  /// \brief Constructor taking an optional std::vector<char> token; defined out of line.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] unknown_token Optional token as a character vector (no default).
  /// \param[in] data_type Defaults to mindspore::DataType::kNumberTypeInt32.
  explicit Lookup(const std::shared_ptr<Vocab> &vocab, const std::optional<std::vector<char>> &unknown_token,
                  mindspore::DataType data_type = mindspore::DataType::kNumberTypeInt32);

  /// \brief Destructor.
  ~Lookup() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  /// \brief Copy the characters of an optional string into an optional character vector.
  /// \param[in] token Optional string to convert.
  /// \return std::nullopt when token is empty, otherwise a vector holding the string's characters.
  static std::optional<std::vector<char>> OptionalStringToOptionalChar(const std::optional<std::string> &token) {
    if (!token.has_value()) {
      return std::nullopt;
    }
    return std::vector<char>(token->begin(), token->end());
  }

  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Ngram text transform.
/// \note The n-gram generation logic is implemented out of line and is not visible in this header.
class Ngram final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string based arguments.
  ///     Converts left_pad and right_pad with PairStringToChar and separator with StringToChar, then delegates
  ///     to the std::vector<char> overload below.
  /// \param[in] ngrams List of int32_t values. Presumably the n-gram sizes to generate - confirm.
  /// \param[in] left_pad (string, int32_t) pair, defaults to {"", 0}. Presumably the pad token and pad width
  ///     for the left side - confirm.
  /// \param[in] right_pad (string, int32_t) pair, defaults to {"", 0}. Presumably the right-side counterpart of
  ///     left_pad - confirm.
  /// \param[in] separator Defaults to a single space.
  explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
                 const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
      : Ngram(ngrams, PairStringToChar(left_pad), PairStringToChar(right_pad), StringToChar(separator)) {}

  /// \brief Constructor taking std::vector<char> based arguments; defined out of line, no default arguments.
  explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::vector<char>, int32_t> &left_pad,
                 const std::pair<std::vector<char>, int32_t> &right_pad, const std::vector<char> &separator);

  /// \brief Destructor.
  ~Ngram() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

#ifndef _WIN32
/// \brief NormalizeUTF8 text transform (declared only when _WIN32 is not defined).
/// \note The normalization logic is implemented out of line and is not visible in this header.
class NormalizeUTF8 final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNfkc.
  explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);

  /// \brief Destructor.
  ~NormalizeUTF8() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief RegexReplace text transform (declared only when _WIN32 is not defined).
/// \note The replacement logic is implemented out of line and is not visible in this header.
class RegexReplace final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string arguments.
  ///     Converts pattern and replace with StringToChar and delegates to the std::vector<char> overload below.
  ///     The strings are taken by const reference because they are only read for the conversion.
  /// \param[in] pattern Pattern string forwarded to the implementation.
  /// \param[in] replace Replacement string forwarded to the implementation.
  /// \param[in] replace_all Defaults to true. Presumably selects replacing every match rather than only the
  ///     first one - confirm in the implementation.
  explicit RegexReplace(const std::string &pattern, const std::string &replace, bool replace_all = true)
      : RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {}

  /// \brief Constructor taking std::vector<char> arguments; defined out of line, no default arguments.
  explicit RegexReplace(const std::vector<char> &pattern, const std::vector<char> &replace, bool replace_all);

  /// \brief Destructor.
  ~RegexReplace() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief RegexTokenizer text transform (declared only when _WIN32 is not defined).
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class RegexTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string arguments.
  ///     Converts both patterns with StringToChar and delegates to the std::vector<char> overload below.
  ///     The strings are taken by const reference because they are only read for the conversion.
  /// \param[in] delim_pattern Pattern string forwarded to the implementation. Presumably the pattern the
  ///     input is split on - confirm.
  /// \param[in] keep_delim_pattern Defaults to an empty string. Presumably selects which delimiters are kept as
  ///     tokens - confirm.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit RegexTokenizer(const std::string &delim_pattern, const std::string &keep_delim_pattern = "",
                          bool with_offsets = false)
      : RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {}

  /// \brief Constructor taking std::vector<char> arguments; defined out of line, no default arguments.
  explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern,
                          bool with_offsets);

  /// \brief Destructor.
  ~RegexTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};
#endif

/// \brief SentencePieceTokenizer text transform.
/// \note Can be constructed from a SentencePieceVocab object or from a vocab path; the tokenization logic is
///     implemented out of line and is not visible in this header.
class SentencePieceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking a vocab object; defined out of line.
  /// \param[in] vocab Shared pointer to a SentencePieceVocab object.
  /// \param[in] out_type SPieceTokenizerOutType value. Presumably selects the output token type - confirm.
  SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab,
                         mindspore::dataset::SPieceTokenizerOutType out_type);

  /// \brief Constructor taking a std::string path; converts it with StringToChar and delegates to the
  ///     std::vector<char> overload below.
  /// \param[in] vocab_path Path string forwarded to the implementation.
  /// \param[in] out_type SPieceTokenizerOutType value, forwarded unchanged.
  SentencePieceTokenizer(const std::string &vocab_path, mindspore::dataset::SPieceTokenizerOutType out_type)
      : SentencePieceTokenizer(StringToChar(vocab_path), out_type) {}

  /// \brief Constructor taking the path as std::vector<char>; defined out of line.
  SentencePieceTokenizer(const std::vector<char> &vocab_path, mindspore::dataset::SPieceTokenizerOutType out_type);

  /// \brief Destructor.
  ~SentencePieceTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief SlidingWindow text transform.
/// \note The windowing logic is implemented out of line and is not visible in this header.
class SlidingWindow final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] width Window width argument (no default). Valid range is defined by the implementation.
  /// \param[in] axis Defaults to 0. Presumably the axis the window slides along - confirm.
  explicit SlidingWindow(const int32_t width, const int32_t axis = 0);

  /// \brief Destructor.
  ~SlidingWindow() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief ToNumber text transform.
/// \note The conversion logic is implemented out of line and is not visible in this header.
class ToNumber final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] data_type mindspore::DataType value (no default). Presumably the numeric type to convert
  ///     to - confirm in the implementation.
  explicit ToNumber(mindspore::DataType data_type);

  /// \brief Destructor.
  ~ToNumber() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief TruncateSequencePair text transform.
/// \note The truncation logic is implemented out of line and is not visible in this header.
class TruncateSequencePair final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] max_length Length limit argument (no default). Presumably the maximum combined length of the
  ///     two sequences - confirm in the implementation.
  explicit TruncateSequencePair(int32_t max_length);

  /// \brief Destructor.
  ~TruncateSequencePair() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief UnicodeCharTokenizer text transform.
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class UnicodeCharTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit UnicodeCharTokenizer(bool with_offsets = false);

  /// \brief Destructor.
  ~UnicodeCharTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief WordpieceTokenizer text transform.
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class WordpieceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking std::string arguments.
  ///     Converts suffix_indicator and unknown_token with StringToChar and delegates to the
  ///     std::vector<char> overload below; all other arguments are forwarded unchanged.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] suffix_indicator Defaults to "##". Presumably the marker for sub-word continuation - confirm.
  /// \param[in] max_bytes_per_token Defaults to 100.
  /// \param[in] unknown_token Defaults to "[UNK]".
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit WordpieceTokenizer(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = "##",
                              int32_t max_bytes_per_token = 100, const std::string &unknown_token = "[UNK]",
                              bool with_offsets = false)
      : WordpieceTokenizer(vocab, StringToChar(suffix_indicator), max_bytes_per_token, StringToChar(unknown_token),
                           with_offsets) {}

  /// \brief Constructor taking std::vector<char> in place of std::string; defined out of line, no default
  ///     arguments.
  explicit WordpieceTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
                              int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool with_offsets);

  /// \brief Destructor.
  ~WordpieceTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

#ifndef _WIN32
/// \brief UnicodeScriptTokenizer text transform (declared only when _WIN32 is not defined).
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class UnicodeScriptTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] keep_whitespace Defaults to false. Presumably keeps whitespace tokens in the output - confirm.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit UnicodeScriptTokenizer(bool keep_whitespace = false, bool with_offsets = false);

  /// \brief Destructor.
  ~UnicodeScriptTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief WhitespaceTokenizer text transform (declared only when _WIN32 is not defined).
/// \note The tokenization logic is implemented out of line and is not visible in this header.
class WhitespaceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor; defined out of line.
  /// \param[in] with_offsets Defaults to false. Presumably also outputs token offsets - confirm.
  explicit WhitespaceTokenizer(bool with_offsets = false);

  /// \brief Destructor.
  ~WhitespaceTokenizer() = default;

 protected:
  /// \brief Override of the inherited Parse() hook; defined out of line.
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque state: Data is only forward-declared here and defined with the implementation.
  struct Data;
  std::shared_ptr<Data> data_;
};
#endif
}  // namespace text
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_