1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15// Unit test for TFLite Lookup op. 16 17#include <iomanip> 18#include <vector> 19 20#include <gmock/gmock.h> 21#include <gtest/gtest.h> 22#include "tensorflow/contrib/lite/interpreter.h" 23#include "tensorflow/contrib/lite/kernels/register.h" 24#include "tensorflow/contrib/lite/kernels/test_util.h" 25#include "tensorflow/contrib/lite/model.h" 26#include "tensorflow/contrib/lite/string_util.h" 27 28namespace tflite { 29namespace { 30 31using ::testing::ElementsAreArray; 32 33class HashtableLookupOpModel : public SingleOpModel { 34 public: 35 HashtableLookupOpModel(std::initializer_list<int> lookup_shape, 36 std::initializer_list<int> key_shape, 37 std::initializer_list<int> value_shape, 38 TensorType type) { 39 lookup_ = AddInput(TensorType_INT32); 40 key_ = AddInput(TensorType_INT32); 41 value_ = AddInput(type); 42 output_ = AddOutput(type); 43 hit_ = AddOutput(TensorType_UINT8); 44 SetBuiltinOp(BuiltinOperator_HASHTABLE_LOOKUP, BuiltinOptions_NONE, 0); 45 BuildInterpreter({lookup_shape, key_shape, value_shape}); 46 } 47 48 void SetLookup(std::initializer_list<int> data) { 49 PopulateTensor<int>(lookup_, data); 50 } 51 52 void SetHashtableKey(std::initializer_list<int> data) { 53 PopulateTensor<int>(key_, data); 54 } 55 56 void SetHashtableValue(const std::vector<string>& content) { 57 PopulateStringTensor(value_, content); 58 } 59 60 void SetHashtableValue(const std::function<float(int)>& function) { 61 TfLiteTensor* tensor = interpreter_->tensor(value_); 62 int rows = tensor->dims->data[0]; 63 for (int i = 0; i < rows; i++) { 64 tensor->data.f[i] = function(i); 65 } 66 } 67 68 void SetHashtableValue(const std::function<float(int, int)>& function) { 69 TfLiteTensor* tensor = interpreter_->tensor(value_); 70 int rows = tensor->dims->data[0]; 71 int features = tensor->dims->data[1]; 72 for (int i = 0; i < rows; i++) { 73 for (int j = 0; j < features; j++) { 74 tensor->data.f[i * features + j] = function(i, j); 75 } 76 } 77 } 78 79 std::vector<string> GetStringOutput() { 80 TfLiteTensor* output = interpreter_->tensor(output_); 81 int num = GetStringCount(output); 82 std::vector<string> result(num); 83 for (int i = 0; i < num; i++) { 84 auto ref = GetString(output, i); 85 result[i] = string(ref.str, ref.len); 86 } 87 return result; 88 } 89 90 std::vector<float> GetOutput() { return ExtractVector<float>(output_); } 91 std::vector<uint8_t> GetHit() { return ExtractVector<uint8_t>(hit_); } 92 93 private: 94 int lookup_; 95 int key_; 96 int value_; 97 int output_; 98 int hit_; 99}; 100 101// TODO(yichengfan): write more tests that exercise the details of the op, 102// such as lookup errors and variable input shapes. 103TEST(HashtableLookupOpTest, Test2DInput) { 104 HashtableLookupOpModel m({4}, {3}, {3, 2}, TensorType_FLOAT32); 105 106 m.SetLookup({1234, -292, -11, 0}); 107 m.SetHashtableKey({-11, 0, 1234}); 108 m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; }); 109 110 m.Invoke(); 111 112 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 113 2.0, 2.1, // 2-nd item 114 0, 0, // Not found 115 0.0, 0.1, // 0-th item 116 1.0, 1.1, // 1-st item 117 }))); 118 EXPECT_THAT(m.GetHit(), ElementsAreArray({ 119 1, 120 0, 121 1, 122 1, 123 })); 124} 125 126TEST(HashtableLookupOpTest, Test1DInput) { 127 HashtableLookupOpModel m({4}, {3}, {3}, TensorType_FLOAT32); 128 129 m.SetLookup({1234, -292, -11, 0}); 130 m.SetHashtableKey({-11, 0, 1234}); 131 m.SetHashtableValue([](int i) { return i * i / 10.0f; }); 132 133 m.Invoke(); 134 135 EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 136 0.4, // 2-nd item 137 0, // Not found 138 0.0, // 0-th item 139 0.1, // 1-st item 140 }))); 141 EXPECT_THAT(m.GetHit(), ElementsAreArray({ 142 1, 143 0, 144 1, 145 1, 146 })); 147} 148 149TEST(HashtableLookupOpTest, TestString) { 150 HashtableLookupOpModel m({4}, {3}, {3}, TensorType_STRING); 151 152 m.SetLookup({1234, -292, -11, 0}); 153 m.SetHashtableKey({-11, 0, 1234}); 154 m.SetHashtableValue({"Hello", "", "Hi"}); 155 156 m.Invoke(); 157 158 EXPECT_THAT(m.GetStringOutput(), ElementsAreArray({ 159 "Hi", // 2-nd item 160 "", // Not found 161 "Hello", // 0-th item 162 "", // 1-st item 163 })); 164 EXPECT_THAT(m.GetHit(), ElementsAreArray({ 165 1, 166 0, 167 1, 168 1, 169 })); 170} 171 172} // namespace 173} // namespace tflite 174 175int main(int argc, char** argv) { 176 ::tflite::LogToStderr(); 177 ::testing::InitGoogleTest(&argc, argv); 178 return RUN_ALL_TESTS(); 179} 180