1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#ifndef LIBTEXTCLASSIFIER_TYPES_H_
18#define LIBTEXTCLASSIFIER_TYPES_H_
19
20#include <algorithm>
21#include <cmath>
22#include <functional>
23#include <set>
24#include <string>
25#include <utility>
26#include <vector>
27#include "util/base/integral_types.h"
28
29#include "util/base/logging.h"
30
31namespace libtextclassifier2 {
32
// Sentinel value denoting an invalid (unset) index.
constexpr int kInvalidIndex = -1;

// Position of an element in a 0-based array of tokens.
using TokenIndex = int;

// Position of an element in a 0-based array of codepoints.
using CodepointIndex = int;

// A span in a sequence of codepoints. `first` is the index of the first
// codepoint that belongs to the span; `second` is the index of the codepoint
// one past the end of the span.
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
46
47inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) {
48  return a.first < b.second && b.first < a.second;
49}
50
51inline bool ValidNonEmptySpan(const CodepointSpan& span) {
52  return span.first < span.second && span.first >= 0 && span.second >= 0;
53}
54
// Reports whether candidates[considered_candidate] overlaps (per SpansOverlap)
// one of the candidates whose indices are stored in `chosen_indices_set`.
//
// Only the two entries adjacent to the candidate in the set's ordering are
// inspected: the first entry that is not ordered before the candidate (found
// via lower_bound) and the entry immediately preceding that one.
// NOTE(review): this looks sufficient only if the set's comparator orders
// indices by span position and the already chosen spans do not overlap each
// other -- confirm against the callers.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& candidate_span = candidates[considered_candidate].span;
  auto neighbor_it = chosen_indices_set.lower_bound(considered_candidate);

  // Neighbor on the right, if there is one.
  const bool has_right_neighbor = neighbor_it != chosen_indices_set.end();
  if (has_right_neighbor &&
      SpansOverlap(candidate_span, candidates[*neighbor_it].span)) {
    return true;
  }

  // Nothing further to the left means nothing left to conflict with.
  if (neighbor_it == chosen_indices_set.begin()) {
    return false;
  }

  // Neighbor on the left decides the outcome.
  --neighbor_it;
  return SpansOverlap(candidate_span, candidates[*neighbor_it].span);
}
86
// A span in a sequence of tokens. `first` is the index of the first token
// that belongs to the span; `second` is the index of the token one past the
// end of the span.
// TODO(b/71982294): Make it a struct.
using TokenSpan = std::pair<TokenIndex, TokenIndex>;
92
93// Returns the size of the token span. Assumes that the span is valid.
94inline int TokenSpanSize(const TokenSpan& token_span) {
95  return token_span.second - token_span.first;
96}
97
98// Returns a token span consisting of one token.
99inline TokenSpan SingleTokenSpan(int token_index) {
100  return {token_index, token_index + 1};
101}
102
103// Returns an intersection of two token spans. Assumes that both spans are valid
104// and overlapping.
105inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1,
106                                     const TokenSpan& token_span2) {
107  return {std::max(token_span1.first, token_span2.first),
108          std::min(token_span1.second, token_span2.second)};
109}
110
111// Returns and expanded token span by adding a certain number of tokens on its
112// left and on its right.
113inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span,
114                                 int num_tokens_left, int num_tokens_right) {
115  return {token_span.first - num_tokens_left,
116          token_span.second + num_tokens_right};
117}
118
119// Token holds a token, its position in the original string and whether it was
120// part of the input span.
121struct Token {
122  std::string value;
123  CodepointIndex start;
124  CodepointIndex end;
125
126  // Whether the token is a padding token.
127  bool is_padding;
128
129  // Default constructor constructs the padding-token.
130  Token()
131      : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
132
133  Token(const std::string& arg_value, CodepointIndex arg_start,
134        CodepointIndex arg_end)
135      : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
136
137  bool operator==(const Token& other) const {
138    return value == other.value && start == other.start && end == other.end &&
139           is_padding == other.is_padding;
140  }
141
142  bool IsContainedInSpan(CodepointSpan span) const {
143    return start >= span.first && end <= span.second;
144  }
145};
146
147// Pretty-printing function for Token.
148inline logging::LoggingStringStream& operator<<(
149    logging::LoggingStringStream& stream, const Token& token) {
150  if (!token.is_padding) {
151    return stream << "Token(\"" << token.value << "\", " << token.start << ", "
152                  << token.end << ")";
153  } else {
154    return stream << "Token()";
155  }
156}
157
// Granularity of a parsed datetime, listed from year down to second.
enum DatetimeGranularity {
  // Used as a proxy for the enclosing structure being uninitialized.
  GRANULARITY_UNKNOWN = -1,

  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6
};
169
170struct DatetimeParseResult {
171  // The absolute time in milliseconds since the epoch in UTC. This is derived
172  // from the reference time and the fields specified in the text - so it may
173  // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm)
174  int64 time_ms_utc;
175
176  // The precision of the estimate then in to calculating the milliseconds
177  DatetimeGranularity granularity;
178
179  DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {}
180
181  DatetimeParseResult(int64 arg_time_ms_utc,
182                      DatetimeGranularity arg_granularity)
183      : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {}
184
185  bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; }
186
187  bool operator==(const DatetimeParseResult& other) const {
188    return granularity == other.granularity && time_ms_utc == other.time_ms_utc;
189  }
190};
191
// Tolerance used when comparing float scores for (approximate) equality.
// Declared constexpr with a float literal so that it is a true compile-time
// constant and no double-to-float conversion is involved.
constexpr float kFloatCompareEpsilon = 1e-5f;
193
194struct DatetimeParseResultSpan {
195  CodepointSpan span;
196  DatetimeParseResult data;
197  float target_classification_score;
198  float priority_score;
199
200  bool operator==(const DatetimeParseResultSpan& other) const {
201    return span == other.span && data.granularity == other.data.granularity &&
202           data.time_ms_utc == other.data.time_ms_utc &&
203           std::abs(target_classification_score -
204                    other.target_classification_score) < kFloatCompareEpsilon &&
205           std::abs(priority_score - other.priority_score) <
206               kFloatCompareEpsilon;
207  }
208};
209
210// Pretty-printing function for DatetimeParseResultSpan.
211inline logging::LoggingStringStream& operator<<(
212    logging::LoggingStringStream& stream,
213    const DatetimeParseResultSpan& value) {
214  return stream << "DatetimeParseResultSpan({" << value.span.first << ", "
215                << value.span.second << "}, {/*time_ms_utc=*/ "
216                << value.data.time_ms_utc << ", /*granularity=*/ "
217                << value.data.granularity << "})";
218}
219
220struct ClassificationResult {
221  std::string collection;
222  float score;
223  DatetimeParseResult datetime_parse_result;
224
225  // Internal score used for conflict resolution.
226  float priority_score;
227
228  explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {}
229
230  ClassificationResult(const std::string& arg_collection, float arg_score)
231      : collection(arg_collection),
232        score(arg_score),
233        priority_score(arg_score) {}
234
235  ClassificationResult(const std::string& arg_collection, float arg_score,
236                       float arg_priority_score)
237      : collection(arg_collection),
238        score(arg_score),
239        priority_score(arg_priority_score) {}
240};
241
242// Pretty-printing function for ClassificationResult.
243inline logging::LoggingStringStream& operator<<(
244    logging::LoggingStringStream& stream, const ClassificationResult& result) {
245  return stream << "ClassificationResult(" << result.collection << ", "
246                << result.score << ")";
247}
248
249// Pretty-printing function for std::vector<ClassificationResult>.
250inline logging::LoggingStringStream& operator<<(
251    logging::LoggingStringStream& stream,
252    const std::vector<ClassificationResult>& results) {
253  stream = stream << "{\n";
254  for (const ClassificationResult& result : results) {
255    stream = stream << "    " << result << "\n";
256  }
257  stream = stream << "}";
258  return stream;
259}
260
261// Represents a result of Annotate call.
262struct AnnotatedSpan {
263  // Unicode codepoint indices in the input string.
264  CodepointSpan span = {kInvalidIndex, kInvalidIndex};
265
266  // Classification result for the span.
267  std::vector<ClassificationResult> classification;
268};
269
270// Pretty-printing function for AnnotatedSpan.
271inline logging::LoggingStringStream& operator<<(
272    logging::LoggingStringStream& stream, const AnnotatedSpan& span) {
273  std::string best_class;
274  float best_score = -1;
275  if (!span.classification.empty()) {
276    best_class = span.classification[0].collection;
277    best_score = span.classification[0].score;
278  }
279  return stream << "Span(" << span.span.first << ", " << span.span.second
280                << ", " << best_class << ", " << best_score << ")";
281}
282
// StringPiece analogue for std::vector<T>: a read-only view over a range of
// elements of a vector, stored as a pair of const_iterators. The view does
// not own the elements, so the underlying vector has to stay alive (and its
// iterators valid) for as long as the view is used.
template <class T>
class VectorSpan {
 public:
  // Creates a span from two value-initialized iterators.
  VectorSpan() : begin_(), end_() {}

  // Views the whole vector `v`. Intentionally implicit so that a vector can
  // be passed wherever a VectorSpan is expected.
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}

  // Views the range [begin, end).
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Returns the i-th element of the view. No bounds checking is performed.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  // Number of elements in the view.
  int size() const { return static_cast<int>(end_ - begin_); }

  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first element of the view. The element type follows T, so
  // this works for every instantiation and not only for VectorSpan<float>.
  // The first element is dereferenced, so the view must not be empty.
  const T* data() const { return &(*begin_); }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
307
// Raw fields extracted while parsing a date/time expression. Which of the
// fields carry a value is recorded in `field_set_mask`.
struct DateParseData {
  // Kind of relative reference.
  enum Relation {
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // What a relation refers to: a day of the week or a calendar unit.
  enum RelationType {
    MONDAY = 1,
    TUESDAY = 2,
    WEDNESDAY = 3,
    THURSDAY = 4,
    FRIDAY = 5,
    SATURDAY = 6,
    SUNDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11
  };

  // Bit flags, one per data field.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  enum AMPM {
    AM = 0,
    PM = 1
  };

  enum TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of fields which have been set on the struct.
  int field_set_mask;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year;
  // Month of the year starting with January = 1.
  int month;
  // Day of the month starting with 1.
  int day_of_month;
  // Hour of the day with a range of 0-23,
  // values less than 12 need the AMPM field below or heuristics
  // to definitively determine the time.
  int hour;
  // Minute of the hour with a range of 0-59.
  int minute;
  // Second of the minute with a range of 0-59.
  int second;
  // 0 == AM, 1 == PM
  int ampm;
  // Number of hours offset from UTC this date time is in.
  int zone_offset;
  // Number of hours offset for DST.
  int dst_offset;

  // The permutation from now that was made to find the date time.
  Relation relation;
  // The unit of measure of the change to the date time.
  RelationType relation_type;
  // The number of units of change that were made.
  int relation_distance;
};
393
394}  // namespace libtextclassifier2
395
396#endif  // LIBTEXTCLASSIFIER_TYPES_H_
397