Program Listing for File text.h

Return to documentation for file (include/text.h)

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "include/api/dual_abi_helper.h"
#include "include/api/status.h"
#include "include/dataset/constants.h"
#include "include/dataset/transforms.h"

namespace mindspore {
namespace dataset {

// Forward declarations. In this header these types are only referenced through std::shared_ptr, so their full
// definitions are not required here.
class Vocab;
class SentencePieceVocab;
class TensorOperation;

// Transform operations for text
namespace text {

#ifndef _WIN32
/// \brief Text transform `BasicTokenizer`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Only the interface is visible in this header; the tokenization rules are implemented elsewhere.
class  BasicTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] lower_case Defaults to false.
  /// \param[in] keep_whitespace Defaults to false.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNone.
  /// \param[in] preserve_unused_token Defaults to true.
  /// \param[in] with_offsets Defaults to false.
  /// \note NOTE(review): the exact effect of each flag is not visible here - confirm against the implementation
  ///     before relying on the parameter names alone.
  explicit BasicTokenizer(bool lower_case = false, bool keep_whitespace = false,
                          const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
                          bool with_offsets = false);

  /// \brief Destructor (defaulted).
  ~BasicTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `BertTokenizer`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Only the interface is visible in this header; the tokenization rules are implemented elsewhere.
class  BertTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking the string arguments as std::string.
  ///     Converts `suffix_indicator` and `unknown_token` with StringToChar and delegates to the
  ///     std::vector<char> overload below; all other arguments are forwarded unchanged.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] suffix_indicator Defaults to "##".
  /// \param[in] max_bytes_per_token Defaults to 100.
  /// \param[in] unknown_token Defaults to "[UNK]".
  /// \param[in] lower_case Defaults to false.
  /// \param[in] keep_whitespace Defaults to false.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNone.
  /// \param[in] preserve_unused_token Defaults to true.
  /// \param[in] with_offsets Defaults to false.
  /// \note NOTE(review): the exact effect of each argument is not visible here - confirm against the
  ///     implementation.
  explicit BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = "##",
                         int32_t max_bytes_per_token = 100, const std::string &unknown_token = "[UNK]",
                         bool lower_case = false, bool keep_whitespace = false,
                         const NormalizeForm normalize_form = NormalizeForm::kNone, bool preserve_unused_token = true,
                         bool with_offsets = false)
      : BertTokenizer(vocab, StringToChar(suffix_indicator), max_bytes_per_token, StringToChar(unknown_token),
                      lower_case, keep_whitespace, normalize_form, preserve_unused_token, with_offsets) {}
  /// \brief Constructor taking the string arguments as std::vector<char>; defined out of line.
  ///     This is the target of the std::string overload above and has no default arguments.
  BertTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
                int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool lower_case,
                bool keep_whitespace, const NormalizeForm normalize_form, bool preserve_unused_token,
                bool with_offsets);

  /// \brief Destructor (defaulted).
  ~BertTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `CaseFold`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Takes no configuration and declares no data members in this header; the operation itself is
///     implemented elsewhere.
class  CaseFold final : public TensorTransform {
 public:
  /// \brief Constructor; takes no arguments.
  CaseFold();

  /// \brief Destructor (defaulted).
  ~CaseFold() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;
};
#endif

/// \brief Text transform `JiebaTokenizer`.
/// \note Every std::string-based public member is a thin inline wrapper: it converts its string argument(s) to
///     std::vector<char> and forwards to the matching `...Char` member or constructor, which is defined out of line.
class JiebaTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking both paths as std::string.
  ///     Converts the two paths with StringToChar and delegates to the std::vector<char> overload.
  /// \param[in] hmm_path Path argument forwarded as a char vector.
  /// \param[in] mp_path Path argument forwarded as a char vector.
  /// \param[in] mode Defaults to JiebaMode::kMix.
  /// \param[in] with_offsets Defaults to false.
  /// \note NOTE(review): what the two files must contain is not visible here - confirm against the implementation.
  JiebaTokenizer(const std::string &hmm_path, const std::string &mp_path, const JiebaMode &mode = JiebaMode::kMix,
                 bool with_offsets = false)
      : JiebaTokenizer(StringToChar(hmm_path), StringToChar(mp_path), mode, with_offsets) {}

  /// \brief Constructor taking both paths as std::vector<char>; defined out of line.
  JiebaTokenizer(const std::vector<char> &hmm_path, const std::vector<char> &mp_path, const JiebaMode &mode,
                 bool with_offsets);

  /// \brief Destructor (defaulted).
  ~JiebaTokenizer() = default;

  /// \brief Converts `word` to a char vector and forwards it, together with `freq`, to AddWordChar.
  /// \param[in] word The word to forward.
  /// \param[in] freq Defaults to 0.
  /// \return The Status returned by AddWordChar.
  Status AddWord(const std::string &word, int64_t freq = 0) {
    const auto word_chars = StringToChar(word);
    return AddWordChar(word_chars, freq);
  }

  /// \brief Converts a list of (word, int64) pairs to its char-vector form and forwards it to AddDictChar.
  /// \param[in] user_dict Pairs of std::string and int64_t.
  /// \return The Status returned by AddDictChar.
  Status AddDict(const std::vector<std::pair<std::string, int64_t>> &user_dict) {
    const auto char_dict = PairStringInt64ToPairCharInt64(user_dict);
    return AddDictChar(char_dict);
  }

  /// \brief Converts `file_path` to a char vector and forwards it to the file-path overload of AddDictChar.
  /// \param[in] file_path Path forwarded as a char vector.
  /// \return The Status returned by AddDictChar.
  Status AddDict(const std::string &file_path) {
    const auto path_chars = StringToChar(file_path);
    return AddDictChar(path_chars);
  }

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  /// \brief Defined out of line. Takes a file path and an output pointer to a list of (string, int64) pairs.
  Status ParserFile(const std::string &file_path, std::vector<std::pair<std::string, int64_t>> *const user_dict);

  /// \brief Char-vector counterpart of AddWord; defined out of line.
  Status AddWordChar(const std::vector<char> &word, int64_t freq = 0);

  /// \brief Char-vector counterpart of AddDict(user_dict); defined out of line.
  Status AddDictChar(const std::vector<std::pair<std::vector<char>, int64_t>> &user_dict);

  /// \brief Char-vector counterpart of AddDict(file_path); defined out of line.
  Status AddDictChar(const std::vector<char> &file_path);

  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `Lookup`, constructed from a vocabulary, an optional unknown token and a data type.
/// \note Only the interface is visible in this header; the lookup itself is implemented elsewhere.
class Lookup final : public TensorTransform {
 public:
  /// \brief Constructor taking the unknown token as an optional std::string.
  ///     Converts `unknown_token` to an optional char vector (an empty optional stays empty) and delegates to
  ///     the std::vector<char> overload, so the object is constructed exactly once.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] unknown_token Optional token; defaults to an empty optional.
  /// \param[in] data_type Defaults to mindspore::DataType::kNumberTypeInt32.
  /// \note NOTE(review): how a missing `unknown_token` is handled at lookup time is not visible here - confirm
  ///     against the implementation.
  explicit Lookup(const std::shared_ptr<Vocab> &vocab, const std::optional<std::string> &unknown_token = {},
                  mindspore::DataType data_type = mindspore::DataType::kNumberTypeInt32)
      : Lookup(vocab, OptionalStringToOptionalChar(unknown_token), data_type) {}

  /// \brief Constructor taking the unknown token as an optional std::vector<char>; defined out of line.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] unknown_token Optional token as a char vector.
  /// \param[in] data_type Defaults to mindspore::DataType::kNumberTypeInt32.
  Lookup(const std::shared_ptr<Vocab> &vocab, const std::optional<std::vector<char>> &unknown_token,
         mindspore::DataType data_type = mindspore::DataType::kNumberTypeInt32);

  /// \brief Destructor (defaulted).
  ~Lookup() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  /// \brief Copies the characters of an optional string into an optional char vector.
  /// \param[in] token Optional string to convert.
  /// \return std::nullopt when `token` is empty, otherwise a vector holding the characters of `*token`.
  static std::optional<std::vector<char>> OptionalStringToOptionalChar(const std::optional<std::string> &token) {
    if (!token.has_value()) {
      return std::nullopt;
    }
    return std::vector<char>(token->begin(), token->end());
  }

  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `Ngram`.
/// \note Only the interface is visible in this header; the n-gram generation is implemented elsewhere.
class  Ngram final : public TensorTransform {
 public:
  /// \brief Constructor taking the string arguments as std::string.
  ///     Converts both pad pairs with PairStringToChar and the separator with StringToChar, then delegates to
  ///     the std::vector<char> overload below.
  /// \param[in] ngrams List of int32 values.
  /// \param[in] left_pad Pair of (string, int32); defaults to {"", 0}.
  /// \param[in] right_pad Pair of (string, int32); defaults to {"", 0}.
  /// \param[in] separator Defaults to a single space " ".
  /// \note NOTE(review): the meaning of the integer in each pad pair is not visible here - presumably a pad
  ///     width; confirm against the implementation.
  explicit Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::string, int32_t> &left_pad = {"", 0},
                 const std::pair<std::string, int32_t> &right_pad = {"", 0}, const std::string &separator = " ")
      : Ngram(ngrams, PairStringToChar(left_pad), PairStringToChar(right_pad), StringToChar(separator)) {}

  /// \brief Constructor taking the string arguments as std::vector<char>; defined out of line.
  Ngram(const std::vector<int32_t> &ngrams, const std::pair<std::vector<char>, int32_t> &left_pad,
        const std::pair<std::vector<char>, int32_t> &right_pad, const std::vector<char> &separator);

  /// \brief Destructor (defaulted).
  ~Ngram() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

#ifndef _WIN32
/// \brief Text transform `NormalizeUTF8`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Only the interface is visible in this header; the normalization is implemented elsewhere.
class  NormalizeUTF8 final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] normalize_form Defaults to NormalizeForm::kNfkc.
  explicit NormalizeUTF8(NormalizeForm normalize_form = NormalizeForm::kNfkc);

  /// \brief Destructor (defaulted).
  ~NormalizeUTF8() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `RegexReplace`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Only the interface is visible in this header; the regex matching and replacement are implemented
///     elsewhere.
class RegexReplace final : public TensorTransform {
 public:
  /// \brief Constructor taking the string arguments as std::string.
  ///     Converts `pattern` and `replace` with StringToChar and delegates to the std::vector<char> overload.
  ///     The strings are taken by const reference because they are only read for that conversion.
  /// \param[in] pattern Pattern string.
  /// \param[in] replace Replacement string.
  /// \param[in] replace_all Defaults to true.
  /// \note NOTE(review): the regex dialect and the behavior when `replace_all` is false are not visible here -
  ///     confirm against the implementation.
  RegexReplace(const std::string &pattern, const std::string &replace, bool replace_all = true)
      : RegexReplace(StringToChar(pattern), StringToChar(replace), replace_all) {}

  /// \brief Constructor taking the string arguments as std::vector<char>; defined out of line.
  RegexReplace(const std::vector<char> &pattern, const std::vector<char> &replace, bool replace_all);

  /// \brief Destructor (defaulted).
  ~RegexReplace() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `RegexTokenizer`. Declared only when `_WIN32` is not defined (see the enclosing guard).
/// \note Only the interface is visible in this header; the regex-based tokenization is implemented elsewhere.
class RegexTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking the patterns as std::string.
  ///     Converts both patterns with StringToChar and delegates to the std::vector<char> overload.
  ///     The strings are taken by const reference because they are only read for that conversion.
  /// \param[in] delim_pattern Pattern string.
  /// \param[in] keep_delim_pattern Defaults to the empty string "".
  /// \param[in] with_offsets Defaults to false.
  /// \note NOTE(review): the regex dialect and how the two patterns interact are not visible here - confirm
  ///     against the implementation.
  explicit RegexTokenizer(const std::string &delim_pattern, const std::string &keep_delim_pattern = "",
                          bool with_offsets = false)
      : RegexTokenizer(StringToChar(delim_pattern), StringToChar(keep_delim_pattern), with_offsets) {}

  /// \brief Constructor taking the patterns as std::vector<char>; defined out of line.
  explicit RegexTokenizer(const std::vector<char> &delim_pattern, const std::vector<char> &keep_delim_pattern,
                          bool with_offsets);

  /// \brief Destructor (defaulted).
  ~RegexTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};
#endif

/// \brief Text transform `SentencePieceTokenizer`.
/// \note Can be constructed either from a SentencePieceVocab object or from a vocabulary path. Only the interface
///     is visible in this header; the tokenization is implemented elsewhere.
class  SentencePieceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking a vocabulary object; defined out of line.
  /// \param[in] vocab Shared pointer to a SentencePieceVocab object.
  /// \param[in] out_type A SPieceTokenizerOutType value.
  SentencePieceTokenizer(const std::shared_ptr<SentencePieceVocab> &vocab,
                         mindspore::dataset::SPieceTokenizerOutType out_type);

  /// \brief Constructor taking the vocabulary path as std::string.
  ///     Converts `vocab_path` with StringToChar and delegates to the std::vector<char> overload.
  /// \param[in] vocab_path Path forwarded as a char vector.
  /// \param[in] out_type A SPieceTokenizerOutType value.
  SentencePieceTokenizer(const std::string &vocab_path, mindspore::dataset::SPieceTokenizerOutType out_type)
      : SentencePieceTokenizer(StringToChar(vocab_path), out_type) {}

  /// \brief Constructor taking the vocabulary path as std::vector<char>; defined out of line.
  SentencePieceTokenizer(const std::vector<char> &vocab_path, mindspore::dataset::SPieceTokenizerOutType out_type);

  /// \brief Destructor (defaulted).
  ~SentencePieceTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `SlidingWindow`.
/// \note Only the interface is visible in this header; the windowing is implemented elsewhere.
class  SlidingWindow final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] width Window width as int32; valid range is not visible here.
  /// \param[in] axis Defaults to 0.
  explicit SlidingWindow(const int32_t width, const int32_t axis = 0);

  /// \brief Destructor (defaulted).
  ~SlidingWindow() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `ToNumber`.
/// \note Only the interface is visible in this header; the conversion is implemented elsewhere.
class  ToNumber final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] data_type A mindspore::DataType value; no default is provided.
  /// \note NOTE(review): which data types are accepted is not visible here - confirm against the implementation.
  explicit ToNumber(mindspore::DataType data_type);

  /// \brief Destructor (defaulted).
  ~ToNumber() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `TruncateSequencePair`.
/// \note Only the interface is visible in this header; the truncation is implemented elsewhere.
class  TruncateSequencePair final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] max_length Length limit as int32; no default is provided.
  /// \note NOTE(review): how the limit is split across the two sequences is not visible here - confirm against
  ///     the implementation.
  explicit TruncateSequencePair(int32_t max_length);

  /// \brief Destructor (defaulted).
  ~TruncateSequencePair() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `UnicodeCharTokenizer`.
/// \note Only the interface is visible in this header; the tokenization is implemented elsewhere.
class  UnicodeCharTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] with_offsets Defaults to false.
  explicit UnicodeCharTokenizer(bool with_offsets = false);

  /// \brief Destructor (defaulted).
  ~UnicodeCharTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `WordpieceTokenizer`.
/// \note Only the interface is visible in this header; the tokenization is implemented elsewhere.
class  WordpieceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor taking the string arguments as std::string.
  ///     Converts `suffix_indicator` and `unknown_token` with StringToChar and delegates to the
  ///     std::vector<char> overload below; all other arguments are forwarded unchanged.
  /// \param[in] vocab Shared pointer to a Vocab object.
  /// \param[in] suffix_indicator Defaults to "##".
  /// \param[in] max_bytes_per_token Defaults to 100.
  /// \param[in] unknown_token Defaults to "[UNK]".
  /// \param[in] with_offsets Defaults to false.
  /// \note NOTE(review): the exact effect of each argument is not visible here - confirm against the
  ///     implementation.
  explicit WordpieceTokenizer(const std::shared_ptr<Vocab> &vocab, const std::string &suffix_indicator = "##",
                              int32_t max_bytes_per_token = 100, const std::string &unknown_token = "[UNK]",
                              bool with_offsets = false)
      : WordpieceTokenizer(vocab, StringToChar(suffix_indicator), max_bytes_per_token, StringToChar(unknown_token),
                           with_offsets) {}

  /// \brief Constructor taking the string arguments as std::vector<char>; defined out of line.
  explicit WordpieceTokenizer(const std::shared_ptr<Vocab> &vocab, const std::vector<char> &suffix_indicator,
                              int32_t max_bytes_per_token, const std::vector<char> &unknown_token, bool with_offsets);

  /// \brief Destructor (defaulted).
  ~WordpieceTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

#ifndef _WIN32
/// \brief Text transform `UnicodeScriptTokenizer`. Declared only when `_WIN32` is not defined (see the enclosing
///     guard).
/// \note Only the interface is visible in this header; the tokenization is implemented elsewhere.
class  UnicodeScriptTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] keep_whitespace Defaults to false.
  /// \param[in] with_offsets Defaults to false.
  explicit UnicodeScriptTokenizer(bool keep_whitespace = false, bool with_offsets = false);

  /// \brief Destructor (defaulted).
  ~UnicodeScriptTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief Text transform `WhitespaceTokenizer`. Declared only when `_WIN32` is not defined (see the enclosing
///     guard).
/// \note Only the interface is visible in this header; the tokenization is implemented elsewhere.
class  WhitespaceTokenizer final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] with_offsets Defaults to false.
  explicit WhitespaceTokenizer(bool with_offsets = false);

  /// \brief Destructor (defaulted).
  ~WhitespaceTokenizer() = default;

 protected:
  /// \brief Overrides the base-class Parse().
  /// \return Shared pointer to a TensorOperation.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // `Data` is only forward-declared here; its definition (and therefore the stored state) lives elsewhere.
  struct Data;
  std::shared_ptr<Data> data_;
};
#endif
}  // namespace text
}  // namespace dataset
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_TEXT_H_