normalize.cc revision 0b15439f8f0f2d4755587f4096c3ea04cb199d23
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// Normalize the string input.
17//
18// Input:
19//     Input[0]: One sentence. string[1]
20//
21// Output:
22//     Output[0]: Normalized sentence. string[1]
23//
#include <cstddef>
#include <map>
#include <string>

#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/strip.h"
#include "re2/re2.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/string_util.h"
31
32namespace tflite {
33namespace ops {
34namespace custom {
35
36namespace normalize {
37
// Predictor transforms.

// Character class of the punctuation that Eval deletes outright:
// '.', '*', '(', ')' and '"'.
const char kPunctuationsRegex[] = "[.*()\"]";

// Contraction rewrites, stored as {RE2 pattern, RE2 rewrite string}. Eval
// applies every entry with RE2::GlobalReplace while iterating the map, i.e. in
// the lexicographic order of the pattern strings rather than in the order they
// are listed here.
//
// The map is heap-allocated once and is not freed anywhere in this file. The
// pointer itself is const and file-local so it can neither be re-seated nor
// collide with a symbol from another translation unit.
static const std::map<std::string, std::string>* const kRegexTransforms =
    new std::map<std::string, std::string>({
        {"([^\\s]+)n't", "\\1 not"},
        {"([^\\s]+)'nt", "\\1 not"},
        {"([^\\s]+)'ll", "\\1 will"},
        {"([^\\s]+)'re", "\\1 are"},
        {"([^\\s]+)'ve", "\\1 have"},
        {"i'm", "i am"},
    });
50
51static const char kStartToken[] = "<S>";
52static const char kEndToken[] = "<E>";
53static const int32 kMaxInputChars = 300;
54
55TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
56  tflite::StringRef input = tflite::GetString(GetInput(context, node, 0), 0);
57
58  string result(absl::AsciiStrToLower(absl::string_view(input.str, input.len)));
59  absl::StripAsciiWhitespace(&result);
60  // Do not remove commas, semi-colons or colons from the sentences as they can
61  // indicate the beginning of a new clause.
62  RE2::GlobalReplace(&result, kPunctuationsRegex, "");
63  RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)([\\s,;:/])",
64                     "\\1\\2");
65  RE2::GlobalReplace(&result, "\\s('t|'nt|n't|'d|'ll|'s|'m|'ve|'re)$", "\\1");
66  for (auto iter = kRegexTransforms->begin(); iter != kRegexTransforms->end();
67       iter++) {
68    RE2::GlobalReplace(&result, iter->first, iter->second);
69  }
70
71  // Treat questions & interjections as special cases.
72  RE2::GlobalReplace(&result, "([?])+", "\\1");
73  RE2::GlobalReplace(&result, "([!])+", "\\1");
74  RE2::GlobalReplace(&result, "([^?!]+)([?!])", "\\1 \\2 ");
75  RE2::GlobalReplace(&result, "([?!])([?!])", "\\1 \\2");
76
77  RE2::GlobalReplace(&result, "[\\s,:;\\-&'\"]+$", "");
78  RE2::GlobalReplace(&result, "^[\\s,:;\\-&'\"]+", "");
79  absl::StripAsciiWhitespace(&result);
80
81  // Add start and end token.
82  // Truncate input to maximum allowed size.
83  if (result.length() <= kMaxInputChars) {
84    absl::StrAppend(&result, " ", kEndToken);
85  } else {
86    result = result.substr(0, kMaxInputChars);
87  }
88  result = absl::StrCat(kStartToken, " ", result);
89
90  tflite::DynamicBuffer buf;
91  buf.AddString(result.data(), result.length());
92  buf.WriteToTensor(GetOutput(context, node, 0));
93  return kTfLiteOk;
94}
95
96}  // namespace normalize
97
98TfLiteRegistration* Register_NORMALIZE() {
99  static TfLiteRegistration r = {nullptr, nullptr, nullptr, normalize::Eval};
100  return &r;
101}
102
103}  // namespace custom
104}  // namespace ops
105}  // namespace tflite
106