10b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
20b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
30b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleLicensed under the Apache License, Version 2.0 (the "License");
40b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selleyou may not use this file except in compliance with the License.
50b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleYou may obtain a copy of the License at
60b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
70b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    http://www.apache.org/licenses/LICENSE-2.0
80b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
90b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleUnless required by applicable law or agreed to in writing, software
100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selledistributed under the License is distributed on an "AS IS" BASIS,
110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew SelleSee the License for the specific language governing permissions and
130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellelimitations under the License.
140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle==============================================================================*/
150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle#include "tensorflow/contrib/lite/models/smartreply/predictor.h"
170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
#include <algorithm>
#include <cstdio>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/strings/str_split.h"
#include "re2/re2.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
250b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
260b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellevoid RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
270b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
280b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace tflite {
290b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace custom {
300b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellenamespace smartreply {
310b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
320b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Split sentence into segments (using punctuation).
336b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlowerstd::vector<std::string> SplitSentence(const std::string& input) {
340b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  string result(input);
350b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
360b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  RE2::GlobalReplace(&result, "([?.!,])+", " \\1");
370b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  RE2::GlobalReplace(&result, "([?.!,])+\\s+", "\\1\t");
380b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  RE2::GlobalReplace(&result, "[ ]+", " ");
390b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  RE2::GlobalReplace(&result, "\t+$", "");
400b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
416b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower  return absl::StrSplit(result, '\t');
420b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
430b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
440b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle// Predict with TfLite model.
456b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlowervoid ExecuteTfLite(const std::string& sentence,
466b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower                   ::tflite::Interpreter* interpreter,
476b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower                   std::map<std::string, float>* response_map) {
480b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  {
490b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    TfLiteTensor* input = interpreter->tensor(interpreter->inputs()[0]);
500b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    tflite::DynamicBuffer buf;
510b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    buf.AddString(sentence.data(), sentence.length());
520b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    buf.WriteToTensor(input);
530b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    interpreter->AllocateTensors();
540b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
550b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    interpreter->Invoke();
560b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
570b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    TfLiteTensor* messages = interpreter->tensor(interpreter->outputs()[0]);
580b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    TfLiteTensor* confidence = interpreter->tensor(interpreter->outputs()[1]);
590b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
600b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    for (int i = 0; i < confidence->dims->data[0]; i++) {
610b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      float weight = confidence->data.f[i];
620b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      auto response_text = tflite::GetString(messages, i);
630b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      if (response_text.len > 0) {
640b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle        (*response_map)[string(response_text.str, response_text.len)] += weight;
650b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      }
660b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
670b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
680b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
690b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
700b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Sellevoid GetSegmentPredictions(
716b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower    const std::vector<std::string>& input,
726b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower    const ::tflite::FlatBufferModel& model, const SmartReplyConfig& config,
730b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    std::vector<PredictorResponse>* predictor_responses) {
740b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Initialize interpreter
750b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::unique_ptr<::tflite::Interpreter> interpreter;
760b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  ::tflite::MutableOpResolver resolver;
770b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  RegisterSelectedOps(&resolver);
780b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  ::tflite::InterpreterBuilder(model, resolver)(&interpreter);
790b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
800b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  if (!model.initialized()) {
810b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    fprintf(stderr, "Failed to mmap model \n");
820b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    return;
830b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
840b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
850b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Execute Tflite Model
866b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower  std::map<std::string, float> response_map;
876b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower  std::vector<std::string> sentences;
886b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower  for (const std::string& str : input) {
896b6244c40197b34f49bb50aa52efb082380d4637A. Unique TensorFlower    std::vector<std::string> splitted_str = SplitSentence(str);
900b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    sentences.insert(sentences.end(), splitted_str.begin(), splitted_str.end());
910b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
920b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const auto& sentence : sentences) {
930b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    ExecuteTfLite(sentence, interpreter.get(), &response_map);
940b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
950b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
960b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Generate the result.
970b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const auto& iter : response_map) {
980b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    PredictorResponse prediction(iter.first, iter.second);
990b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    predictor_responses->emplace_back(prediction);
1000b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1010b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  std::sort(predictor_responses->begin(), predictor_responses->end(),
1020b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle            [](const PredictorResponse& a, const PredictorResponse& b) {
1030b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle              return a.GetScore() > b.GetScore();
1040b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle            });
1050b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1060b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  // Add backoff response.
1070b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  for (const string& backoff : config.backoff_responses) {
1080b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    if (predictor_responses->size() >= config.num_response) {
1090b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle      break;
1100b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    }
1110b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle    predictor_responses->push_back({backoff, config.backoff_confidence});
1120b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle  }
1130b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}
1140b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle
1150b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace smartreply
1160b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace custom
1170b15439f8f0f2d4755587f4096c3ea04cb199d23Andrew Selle}  // namespace tflite
118