1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "smartselect/cached-features.h"
18#include "util/base/logging.h"
19
20namespace libtextclassifier {
21
22void CachedFeatures::Extract(
23    const std::vector<std::vector<int>>& sparse_features,
24    const std::vector<std::vector<float>>& dense_features,
25    const std::function<bool(const std::vector<int>&, const std::vector<float>&,
26                             float*)>& feature_vector_fn) {
27  features_.resize(feature_vector_size_ * tokens_.size());
28  for (int i = 0; i < tokens_.size(); ++i) {
29    feature_vector_fn(sparse_features[i], dense_features[i],
30                      features_.data() + i * feature_vector_size_);
31  }
32}
33
34bool CachedFeatures::Get(int click_pos, VectorSpan<float>* features,
35                         VectorSpan<Token>* output_tokens) {
36  const int token_start = click_pos - context_size_;
37  const int token_end = click_pos + context_size_ + 1;
38  if (token_start < 0 || token_end > tokens_.size()) {
39    TC_LOG(ERROR) << "Tokens out of range: " << token_start << " " << token_end;
40    return false;
41  }
42
43  *features =
44      VectorSpan<float>(features_.begin() + token_start * feature_vector_size_,
45                        features_.begin() + token_end * feature_vector_size_);
46  *output_tokens = VectorSpan<Token>(tokens_.begin() + token_start,
47                                     tokens_.begin() + token_end);
48  if (remap_v0_feature_vector_) {
49    RemapV0FeatureVector(features);
50  }
51
52  return true;
53}
54
55void CachedFeatures::RemapV0FeatureVector(VectorSpan<float>* features) {
56  if (!remap_v0_feature_vector_) {
57    return;
58  }
59
60  auto it = features->begin();
61  int num_suffix_features =
62      feature_vector_size_ - remap_v0_chargram_embedding_size_;
63  int num_tokens = context_size_ * 2 + 1;
64  for (int t = 0; t < num_tokens; ++t) {
65    for (int i = 0; i < remap_v0_chargram_embedding_size_; ++i) {
66      v0_feature_storage_[t * remap_v0_chargram_embedding_size_ + i] = *it;
67      ++it;
68    }
69    // Rest of the features are the dense features that come to the end.
70    for (int i = 0; i < num_suffix_features; ++i) {
71      // clang-format off
72      v0_feature_storage_[num_tokens * remap_v0_chargram_embedding_size_
73                      + t * num_suffix_features
74                      + i] = *it;
75      // clang-format on
76      ++it;
77    }
78  }
79  *features = VectorSpan<float>(v0_feature_storage_);
80}
81
82}  // namespace libtextclassifier
83