normalize.cc revision 0b15439f8f0f2d4755587f4096c3ea04cb199d23
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// Normalize the string input. 17// 18// Input: 19// Input[0]: One sentence. string[1] 20// 21// Output: 22// Output[0]: Normalized sentence. string[1] 23// 24#include "absl/strings/ascii.h" 25#include "absl/strings/str_cat.h" 26#include "absl/strings/strip.h" 27#include "re2/re2.h" 28#include "tensorflow/contrib/lite/context.h" 29#include "tensorflow/contrib/lite/kernels/kernel_util.h" 30#include "tensorflow/contrib/lite/string_util.h" 31 32namespace tflite { 33namespace ops { 34namespace custom { 35 36namespace normalize { 37 38// Predictor transforms. 39const char kPunctuationsRegex[] = "[.*()\"]"; 40 41const std::map<string, string>* kRegexTransforms = 42 new std::map<string, string>({ 43 {"([^\\s]+)n't", "\\1 not"}, 44 {"([^\\s]+)'nt", "\\1 not"}, 45 {"([^\\s]+)'ll", "\\1 will"}, 46 {"([^\\s]+)'re", "\\1 are"}, 47 {"([^\\s]+)'ve", "\\1 have"}, 48 {"i'm", "i am"}, 49 }); 50 51static const char kStartToken[] = "<S>"; 52static const char kEndToken[] = "<E>"; 53static const int32 kMaxInputChars = 300; 54 55TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 56 tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0); 57 58 string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len))); 59 absl::StripAsciiWhitespace(&result); 60 // Do not remove commas, semi-colons or colons from the sentences as they can 61 // indicate the beginning of a new clause. 62 RE2::GlobalReplace(&result, kPunctuationsRegex, ""); 63 RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])", 64 "\\1\\2"); 65 RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1"); 66 for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end(); 67 iter++) { 68 RE2::GlobalReplace(&result, iter->first, iter->second); 69 } 70 71 // Treat questions & interjections as special cases. 72 RE2::GlobalReplace(&result, "([?])+", "\\1"); 73 RE2::GlobalReplace(&result, "([!])+", "\\1"); 74 RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 "); 75 RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2"); 76 77 RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", ""); 78 RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", ""); 79 absl::StripAsciiWhitespace(&result); 80 81 // Add start and end token. 82 // Truncate input to maximum allowed size. 83 if (result.length() <= kMaxInputChars) { 84 absl::StrAppend(&result, " ", kEndToken); 85 } else { 86 result = result.substr(0, kMaxInputChars); 87 } 88 result = absl::StrCat(kStartToken, " ", result); 89 90 tflite::DynamicBuffer buf; 91 buf.AddString(result.data(), result.length()); 92 buf.WriteToTensor(GetOutput(context, node, 0)); 93 return kTfLiteOk; 94} 95 96} // namespace normalize 97 98TfLiteRegistration* Register_NORMALIZE() { 99 static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval}; 100 return &r; 101} 102 103} // namespace custom 104} // namespace ops 105} // namespace tflite 106