1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
18#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
19
20#include <string>
21#include <vector>
22
23#include "common/embedding-feature-extractor.h"
24#include "common/feature-extractor.h"
25#include "common/task-context.h"
26#include "common/workspace.h"
27#include "lang_id/light-sentence-features.h"
28#include "lang_id/light-sentence.h"
29#include "util/base/macros.h"
30
31namespace libtextclassifier {
32namespace nlp_core {
33namespace lang_id {
34
35// Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
36class LangIdEmbeddingFeatureExtractor
37    : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
38 public:
39  LangIdEmbeddingFeatureExtractor() {}
40  const std::string ArgPrefix() const override { return "language_identifier"; }
41
42  TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor);
43};
44
45// Handles sentence -> numeric_features and numeric_prediction -> language
46// conversions.
47class LangIdBrainInterface {
48 public:
49  LangIdBrainInterface() {}
50
51  // Initializes resources and parameters.
52  bool Init(TaskContext *context) {
53    if (!feature_extractor_.Init(context)) {
54      return false;
55    }
56    feature_extractor_.RequestWorkspaces(&workspace_registry_);
57    return true;
58  }
59
60  // Extract features from sentence.  On return, FeatureVector features[i]
61  // contains the features for the embedding space #i.
62  void GetFeatures(LightSentence *sentence,
63                   std::vector<FeatureVector> *features) const {
64    WorkspaceSet workspace;
65    workspace.Reset(workspace_registry_);
66    feature_extractor_.Preprocess(&workspace, sentence);
67    return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
68  }
69
70  int NumEmbeddings() const {
71    return feature_extractor_.NumEmbeddings();
72  }
73
74 private:
75  // Typed feature extractor for embeddings.
76  LangIdEmbeddingFeatureExtractor feature_extractor_;
77
78  // The registry of shared workspaces in the feature extractor.
79  WorkspaceRegistry workspace_registry_;
80
81  TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface);
82};
83
84}  // namespace lang_id
85}  // namespace nlp_core
86}  // namespace libtextclassifier
87
88#endif  // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
89