1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <vector> 17 18#include <gtest/gtest.h> 19#include "tensorflow/contrib/lite/interpreter.h" 20#include "tensorflow/contrib/lite/kernels/register.h" 21#include "tensorflow/contrib/lite/kernels/test_util.h" 22#include "tensorflow/contrib/lite/model.h" 23#include "tensorflow/contrib/lite/string_util.h" 24 25namespace tflite { 26 27namespace ops { 28namespace custom { 29TfLiteRegistration* Register_PREDICT(); 30 31namespace { 32 33using ::testing::ElementsAreArray; 34 35class PredictOpModel : public SingleOpModel { 36 public: 37 PredictOpModel(std::initializer_list<int> input_signature_shape, 38 std::initializer_list<int> key_shape, 39 std::initializer_list<int> labelweight_shape, int num_output, 40 float threshold) { 41 input_signature_ = AddInput(TensorType_INT32); 42 model_key_ = AddInput(TensorType_INT32); 43 model_label_ = AddInput(TensorType_INT32); 44 model_weight_ = AddInput(TensorType_FLOAT32); 45 output_label_ = AddOutput(TensorType_INT32); 46 output_weight_ = AddOutput(TensorType_FLOAT32); 47 48 std::vector<uint8_t> predict_option; 49 writeInt32(num_output, &predict_option); 50 writeFloat32(threshold, &predict_option); 51 SetCustomOp("Predict", predict_option, Register_PREDICT); 52 BuildInterpreter({{input_signature_shape, key_shape, labelweight_shape, 53 labelweight_shape}}); 54 } 55 56 void SetInputSignature(std::initializer_list<int> data) { 57 PopulateTensor<int>(input_signature_, data); 58 } 59 60 void SetModelKey(std::initializer_list<int> data) { 61 PopulateTensor<int>(model_key_, data); 62 } 63 64 void SetModelLabel(std::initializer_list<int> data) { 65 PopulateTensor<int>(model_label_, data); 66 } 67 68 void SetModelWeight(std::initializer_list<float> data) { 69 PopulateTensor<float>(model_weight_, data); 70 } 71 72 std::vector<int> GetLabel() { return ExtractVector<int>(output_label_); } 73 std::vector<float> GetWeight() { 74 return ExtractVector<float>(output_weight_); 75 } 76 77 void writeFloat32(float value, std::vector<uint8_t>* data) { 78 union { 79 float v; 80 uint8_t r[4]; 81 } float_to_raw; 82 float_to_raw.v = value; 83 for (unsigned char i : float_to_raw.r) { 84 data->push_back(i); 85 } 86 } 87 88 void writeInt32(int32_t value, std::vector<uint8_t>* data) { 89 union { 90 int32_t v; 91 uint8_t r[4]; 92 } int32_to_raw; 93 int32_to_raw.v = value; 94 for (unsigned char i : int32_to_raw.r) { 95 data->push_back(i); 96 } 97 } 98 99 private: 100 int input_signature_; 101 int model_key_; 102 int model_label_; 103 int model_weight_; 104 int output_label_; 105 int output_weight_; 106}; 107 108TEST(PredictOpTest, AllLabelsAreValid) { 109 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 110 m.SetInputSignature({1, 3, 7, 9}); 111 m.SetModelKey({1, 2, 4, 6, 7}); 112 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 113 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 114 m.Invoke(); 115 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, 11})); 116 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0.05}))); 117} 118 119TEST(PredictOpTest, MoreLabelsThanRequired) { 120 PredictOpModel m({4}, {5}, {5, 2}, 1, 0.0001); 121 m.SetInputSignature({1, 3, 7, 9}); 122 m.SetModelKey({1, 2, 4, 6, 7}); 123 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 124 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 125 m.Invoke(); 126 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12})); 127 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1}))); 128} 129 130TEST(PredictOpTest, OneLabelDoesNotPassThreshold) { 131 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.07); 132 m.SetInputSignature({1, 3, 7, 9}); 133 m.SetModelKey({1, 2, 4, 6, 7}); 134 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 135 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 136 m.Invoke(); 137 EXPECT_THAT(m.GetLabel(), ElementsAreArray({12, -1})); 138 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.1, 0}))); 139} 140 141TEST(PredictOpTest, NoneLabelPassThreshold) { 142 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.6); 143 m.SetInputSignature({1, 3, 7, 9}); 144 m.SetModelKey({1, 2, 4, 6, 7}); 145 m.SetModelLabel({11, 12, 11, 12, 11, 12, 11, 12, 11, 12}); 146 m.SetModelWeight({0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2}); 147 m.Invoke(); 148 EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); 149 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); 150} 151 152TEST(PredictOpTest, OnlyOneLabelGenerated) { 153 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 154 m.SetInputSignature({1, 3, 7, 9}); 155 m.SetModelKey({1, 2, 4, 6, 7}); 156 m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 11, 0}); 157 m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0}); 158 m.Invoke(); 159 EXPECT_THAT(m.GetLabel(), ElementsAreArray({11, -1})); 160 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0.05, 0}))); 161} 162 163TEST(PredictOpTest, NoLabelGenerated) { 164 PredictOpModel m({4}, {5}, {5, 2}, 2, 0.0001); 165 m.SetInputSignature({5, 3, 7, 9}); 166 m.SetModelKey({1, 2, 4, 6, 7}); 167 m.SetModelLabel({11, 0, 11, 0, 11, 0, 11, 0, 0, 0}); 168 m.SetModelWeight({0.1, 0, 0.1, 0, 0.1, 0, 0.1, 0, 0, 0}); 169 m.Invoke(); 170 EXPECT_THAT(m.GetLabel(), ElementsAreArray({-1, -1})); 171 EXPECT_THAT(m.GetWeight(), ElementsAreArray(ArrayFloatNear({0, 0}))); 172} 173 174} // namespace 175} // namespace custom 176} // namespace ops 177} // namespace tflite 178 179int main(int argc, char** argv) { 180 // On Linux, add: tflite::LogToStderr(); 181 ::testing::InitGoogleTest(&argc, argv); 182 return RUN_ALL_TESTS(); 183} 184