/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

17#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
18#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
19
20#include <algorithm>
21#include <string>
22
23#include "common/float16.h"
24#include "common/task-context.h"
25#include "common/task-spec.pb.h"
26#include "util/base/logging.h"
27
28namespace libtextclassifier {
29namespace nlp_core {
30
// How the elements of a matrix are encoded.  NONE: plain floats.  UINT8:
// each element is an 8-bit quantized value (see embeddings_weights(), which
// returns a pointer to uint8 in that case), to be rescaled using the
// per-row quantization scales.
enum class QuantizationType { NONE = 0, UINT8 };
32
33// API for accessing parameters for a feed-forward neural network with
34// embeddings.
35//
36// Note: this API is closely related to embedding-network.proto.  The reason we
37// have a separate API is that the proto may not be the only way of packaging
38// these parameters.
39class EmbeddingNetworkParams {
40 public:
41  virtual ~EmbeddingNetworkParams() {}
42
43  // **** High-level API.
44
45  // Simple representation of a matrix.  This small struct that doesn't own any
46  // resource intentionally supports copy / assign, to simplify our APIs.
47  struct Matrix {
48    // Number of rows.
49    int rows;
50
51    // Number of columns.
52    int cols;
53
54    QuantizationType quant_type;
55
56    // Pointer to matrix elements, in row-major order
57    // (https://en.wikipedia.org/wiki/Row-major_order) Not owned.
58    const void *elements;
59
60    // Quantization scales: one scale for each row.
61    const float16 *quant_scales;
62  };
63
64  // Returns number of embedding spaces.
65  int GetNumEmbeddingSpaces() const {
66    if (embeddings_size() != embedding_num_features_size()) {
67      TC_LOG(ERROR) << "Embedding spaces mismatch " << embeddings_size()
68                    << " != " << embedding_num_features_size();
69    }
70    return std::max(0,
71                    std::min(embeddings_size(), embedding_num_features_size()));
72  }
73
74  // Returns embedding matrix for the i-th embedding space.
75  //
76  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
77  // otherwise.
78  Matrix GetEmbeddingMatrix(int i) const {
79    TC_DCHECK(InRange(i, embeddings_size()));
80    Matrix matrix;
81    matrix.rows = embeddings_num_rows(i);
82    matrix.cols = embeddings_num_cols(i);
83    matrix.elements = embeddings_weights(i);
84    matrix.quant_type = embeddings_quant_type(i);
85    matrix.quant_scales = embeddings_quant_scales(i);
86    return matrix;
87  }
88
89  // Returns number of features in i-th embedding space.
90  //
91  // NOTE: i must be in [0, GetNumEmbeddingSpaces()).  Undefined behavior
92  // otherwise.
93  int GetNumFeaturesInEmbeddingSpace(int i) const {
94    TC_DCHECK(InRange(i, embedding_num_features_size()));
95    return std::max(0, embedding_num_features(i));
96  }
97
98  // Returns number of hidden layers in the neural network.  Each such layer has
99  // weight matrix and a bias vector (a matrix with one column).
100  int GetNumHiddenLayers() const {
101    if (hidden_size() != hidden_bias_size()) {
102      TC_LOG(ERROR) << "Hidden layer mismatch " << hidden_size()
103                    << " != " << hidden_bias_size();
104    }
105    return std::max(0, std::min(hidden_size(), hidden_bias_size()));
106  }
107
108  // Returns weight matrix for i-th hidden layer.
109  //
110  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
111  // otherwise.
112  Matrix GetHiddenLayerMatrix(int i) const {
113    TC_DCHECK(InRange(i, hidden_size()));
114    Matrix matrix;
115    matrix.rows = hidden_num_rows(i);
116    matrix.cols = hidden_num_cols(i);
117
118    // Quantization not supported here.
119    matrix.quant_type = QuantizationType::NONE;
120    matrix.elements = hidden_weights(i);
121    return matrix;
122  }
123
124  // Returns bias matrix for i-th hidden layer.  Technically a Matrix, but we
125  // expect it to be a vector (i.e., num cols is 1).
126  //
127  // NOTE: i must be in [0, GetNumHiddenLayers()).  Undefined behavior
128  // otherwise.
129  Matrix GetHiddenLayerBias(int i) const {
130    TC_DCHECK(InRange(i, hidden_bias_size()));
131    Matrix matrix;
132    matrix.rows = hidden_bias_num_rows(i);
133    matrix.cols = hidden_bias_num_cols(i);
134
135    // Quantization not supported here.
136    matrix.quant_type = QuantizationType::NONE;
137    matrix.elements = hidden_bias_weights(i);
138    return matrix;
139  }
140
141  // Returns true if a softmax layer exists.
142  bool HasSoftmaxLayer() const {
143    if (softmax_size() != softmax_bias_size()) {
144      TC_LOG(ERROR) << "Softmax layer mismatch " << softmax_size()
145                    << " != " << softmax_bias_size();
146    }
147    return (softmax_size() == 1) && (softmax_bias_size() == 1);
148  }
149
150  // Returns weight matrix for the softmax layer.
151  //
152  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
153  // behavior otherwise.
154  Matrix GetSoftmaxMatrix() const {
155    TC_DCHECK(softmax_size() == 1);
156    Matrix matrix;
157    matrix.rows = softmax_num_rows(0);
158    matrix.cols = softmax_num_cols(0);
159
160    // Quantization not supported here.
161    matrix.quant_type = QuantizationType::NONE;
162    matrix.elements = softmax_weights(0);
163    return matrix;
164  }
165
166  // Returns bias for the softmax layer.  Technically a Matrix, but we expect it
167  // to be a row/column vector (i.e., num cols is 1).
168  //
169  // NOTE: Should be called only if HasSoftmaxLayer() is true.  Undefined
170  // behavior otherwise.
171  Matrix GetSoftmaxBias() const {
172    TC_DCHECK(softmax_bias_size() == 1);
173    Matrix matrix;
174    matrix.rows = softmax_bias_num_rows(0);
175    matrix.cols = softmax_bias_num_cols(0);
176
177    // Quantization not supported here.
178    matrix.quant_type = QuantizationType::NONE;
179    matrix.elements = softmax_bias_weights(0);
180    return matrix;
181  }
182
183  // Updates the EmbeddingNetwork-related parameters from task_context.  Returns
184  // true on success, false on error.
185  virtual bool UpdateTaskContextParameters(TaskContext *task_context) {
186    const TaskSpec *task_spec = GetTaskSpec();
187    if (task_spec == nullptr) {
188      TC_LOG(ERROR) << "Unable to get TaskSpec";
189      return false;
190    }
191    for (const TaskSpec::Parameter &parameter : task_spec->parameter()) {
192      task_context->SetParameter(parameter.name(), parameter.value());
193    }
194    return true;
195  }
196
197  // Returns a pointer to a TaskSpec with the EmbeddingNetwork-related
198  // parameters.  Returns nullptr in case of problems.  Ownership with the
199  // returned pointer is *not* transfered to the caller.
200  virtual const TaskSpec *GetTaskSpec() {
201    TC_LOG(ERROR) << "Not implemented";
202    return nullptr;
203  }
204
205 protected:
206  // **** Low-level API.
207  //
208  // * Most low-level API methods are documented by giving an equivalent
209  //   function call on proto, the original proto (of type
210  //   EmbeddingNetworkProto) which was used to generate the C++ code.
211  //
212  // * To simplify our generation code, optional proto fields of message type
213  //   are treated as repeated fields with 0 or 1 instances.  As such, we have
214  //   *_size() methods for such optional fields: they return 0 or 1.
215  //
216  // * "transpose(M)" denotes the transpose of a matrix M.
217  //
218  // * Behavior is undefined when trying to retrieve a piece of data that does
219  //   not exist: e.g., embeddings_num_rows(5) if embeddings_size() == 2.
220
221  // ** Access methods for repeated MatrixParams embeddings.
222  //
223  // Returns proto.embeddings_size().
224  virtual int embeddings_size() const = 0;
225
226  // Returns number of rows of transpose(proto.embeddings(i)).
227  virtual int embeddings_num_rows(int i) const = 0;
228
229  // Returns number of columns of transpose(proto.embeddings(i)).
230  virtual int embeddings_num_cols(int i) const = 0;
231
232  // Returns pointer to elements of transpose(proto.embeddings(i)), in row-major
233  // order.  NOTE: for unquantized embeddings, this returns a pointer to float;
234  // for quantized embeddings, this returns a pointer to uint8.
235  virtual const void *embeddings_weights(int i) const = 0;
236
237  virtual QuantizationType embeddings_quant_type(int i) const {
238    return QuantizationType::NONE;
239  }
240
241  virtual const float16 *embeddings_quant_scales(int i) const {
242    return nullptr;
243  }
244
245  // ** Access methods for repeated MatrixParams hidden.
246  //
247  // Returns embedding_network_proto.hidden_size().
248  virtual int hidden_size() const = 0;
249
250  // Returns embedding_network_proto.hidden(i).rows().
251  virtual int hidden_num_rows(int i) const = 0;
252
253  // Returns embedding_network_proto.hidden(i).rows().
254  virtual int hidden_num_cols(int i) const = 0;
255
256  // Returns pointer to beginning of array of floats with all values from
257  // embedding_network_proto.hidden(i).
258  virtual const void *hidden_weights(int i) const = 0;
259
260  // ** Access methods for repeated MatrixParams hidden_bias.
261  //
262  // Returns proto.hidden_bias_size().
263  virtual int hidden_bias_size() const = 0;
264
265  // Returns number of rows of proto.hidden_bias(i).
266  virtual int hidden_bias_num_rows(int i) const = 0;
267
268  // Returns number of columns of proto.hidden_bias(i).
269  virtual int hidden_bias_num_cols(int i) const = 0;
270
271  // Returns pointer to elements of proto.hidden_bias(i), in row-major order.
272  virtual const void *hidden_bias_weights(int i) const = 0;
273
274  // ** Access methods for optional MatrixParams softmax.
275  //
276  // Returns 1 if proto has optional field softmax, 0 otherwise.
277  virtual int softmax_size() const = 0;
278
279  // Returns number of rows of transpose(proto.softmax()).
280  virtual int softmax_num_rows(int i) const = 0;
281
282  // Returns number of columns of transpose(proto.softmax()).
283  virtual int softmax_num_cols(int i) const = 0;
284
285  // Returns pointer to elements of transpose(proto.softmax()), in row-major
286  // order.
287  virtual const void *softmax_weights(int i) const = 0;
288
289  // ** Access methods for optional MatrixParams softmax_bias.
290  //
291  // Returns 1 if proto has optional field softmax_bias, 0 otherwise.
292  virtual int softmax_bias_size() const = 0;
293
294  // Returns number of rows of proto.softmax_bias().
295  virtual int softmax_bias_num_rows(int i) const = 0;
296
297  // Returns number of columns of proto.softmax_bias().
298  virtual int softmax_bias_num_cols(int i) const = 0;
299
300  // Returns pointer to elements of proto.softmax_bias(), in row-major order.
301  virtual const void *softmax_bias_weights(int i) const = 0;
302
303  // ** Access methods for repeated int32 embedding_num_features.
304  //
305  // Returns proto.embedding_num_features_size().
306  virtual int embedding_num_features_size() const = 0;
307
308  // Returns proto.embedding_num_features(i).
309  virtual int embedding_num_features(int i) const = 0;
310
311  // Returns true if and only if index is in range [0, size).  Log an error
312  // message otherwise.
313  static bool InRange(int index, int size) {
314    if ((index < 0) || (index >= size)) {
315      TC_LOG(ERROR) << "Index " << index << " outside [0, " << size << ")";
316      return false;
317    }
318    return true;
319  }
320};  // class EmbeddingNetworkParams
321
322}  // namespace nlp_core
323}  // namespace libtextclassifier
324
325#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_H_
326