1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17// Purpose: A container for sparse weight vectors
// Maintains the sparse vector as a list of (name, value) pairs along with
19// a normalizer_. All operations assume that (name, value/normalizer_) is the
20// true value in question.
21
22#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
23#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
24
25#include <math.h>
26
27#include <iosfwd>
28#include <sstream>
29#include <string>
30#include <unordered_map>
31
32#include "common_defs.h"
33
34namespace learning_stochastic_linear {
35
36template<class Key = std::string, class Hash = std::unordered_map<Key, double> >
37class SparseWeightVector {
38 public:
39  typedef Hash Wmap;
40  typedef typename Wmap::iterator Witer;
41  typedef typename Wmap::const_iterator Witer_const;
42  SparseWeightVector() {
43    normalizer_ = 1.0;
44  }
45  ~SparseWeightVector() {}
46  explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
47    CopyFrom(other);
48  }
49  void operator=(const SparseWeightVector<Key, Hash> &other) {
50    CopyFrom(other);
51  }
52  void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
53    w_ = other.w_;
54    wmin_ = other.wmin_;
55    wmax_ = other.wmax_;
56    normalizer_ = other.normalizer_;
57  }
58
59  // This function implements checks to prevent unbounded vectors. It returns
60  // true if the checks succeed and false otherwise. A vector is deemed invalid
61  // if any of these conditions are met:
62  // 1. it has no values.
63  // 2. its normalizer is nan or inf or close to zero.
64  // 3. any of its values are nan or inf.
65  // 4. its L0 norm is close to zero.
66  bool IsValid() const;
67
68  // Normalizer getters and setters.
69  double GetNormalizer() const {
70    return normalizer_;
71  }
72  void SetNormalizer(const double norm) {
73    normalizer_ = norm;
74  }
75  void NormalizerMultUpdate(const double mul) {
76    normalizer_ = normalizer_ * mul;
77  }
78  void NormalizerAddUpdate(const double add) {
79    normalizer_ += add;
80  }
81
82  // Divides all the values by the normalizer, then it resets it to 1.0
83  void ResetNormalizer();
84
85  // Bound getters and setters.
86  // True if there is a bound with val containing the bound. false otherwise.
87  bool GetElementMinBound(const Key &fname, double *val) const {
88    return GetValue(wmin_, fname, val);
89  }
90  bool GetElementMaxBound(const Key &fname, double *val) const {
91    return GetValue(wmax_, fname, val);
92  }
93  void SetElementMinBound(const Key &fname, const double bound) {
94    wmin_[fname] = bound;
95  }
96  void SetElementMaxBound(const Key &fname, const double bound) {
97    wmax_[fname] = bound;
98  }
99  // Element getters and setters.
100  double GetElement(const Key &fname) const {
101    double val = 0;
102    GetValue(w_, fname, &val);
103    return val;
104  }
105  void SetElement(const Key &fname, const double val) {
106    //DCHECK(!isnan(val));
107    w_[fname] = val;
108  }
109  void AddUpdateElement(const Key &fname, const double val) {
110    w_[fname] += val;
111  }
112  void MultUpdateElement(const Key &fname, const double val) {
113    w_[fname] *= val;
114  }
115  // Load another weight vectors. Will overwrite the current vector.
116  void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
117    w_.clear();
118    w_.insert(vec.w_.begin(), vec.w_.end());
119    wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
120    wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
121    normalizer_ = vec.normalizer_;
122  }
123  void Clear() {
124    w_.clear();
125    wmax_.clear();
126    wmin_.clear();
127  }
128  const Wmap& GetMap() const {
129    return w_;
130  }
131  // Vector Operations.
132  void AdditiveWeightUpdate(const double multiplier,
133                            const SparseWeightVector<Key, Hash> &w1,
134                            const double additive_const);
135  void AdditiveSquaredWeightUpdate(const double multiplier,
136                                   const SparseWeightVector<Key, Hash> &w1,
137                                   const double additive_const);
138  void AdditiveInvSqrtWeightUpdate(const double multiplier,
139                                   const SparseWeightVector<Key, Hash> &w1,
140                                   const double additive_const);
141  void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
142  double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
143  // L-x norm. eg. L1, L2.
144  double LxNorm(const double x) const;
145  double L2Norm() const;
146  double L1Norm() const;
147  double L0Norm(const double epsilon) const;
148  // Bound preserving updates.
149  void AdditiveWeightUpdateBounded(const double multiplier,
150                                   const SparseWeightVector<Key, Hash> &w1,
151                                   const double additive_const);
152  void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
153  void ReprojectToBounds();
154  void ReprojectL0(const double l0_norm);
155  void ReprojectL1(const double l1_norm);
156  void ReprojectL2(const double l2_norm);
157  // Reproject using the given norm.
158  // Will also rescale regularizer_ if it gets too small/large.
159  int32 Reproject(const double norm, const RegularizationType r);
160  // Convert this vector to a string, simply for debugging.
161  std::string DebugString() const {
162    std::stringstream stream;
163    stream << *this;
164    return stream.str();
165  }
166 private:
167  // The weight map.
168  Wmap w_;
169  // Constraint bounds.
170  Wmap wmin_;
171  Wmap wmax_;
172  // Normalizing constant in magnitude measurement.
173  double normalizer_;
174  // This function is necessary since by default unordered_map inserts an
175  // element if it does not find the key through [] operator. It implements a
176  // lookup without the space overhead of an add.
177  bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
178    Witer_const iter = w1.find(fname);
179    if (iter != w1.end()) {
180      (*val) = iter->second;
181      return true;
182    } else {
183      (*val) = 0;
184      return false;
185    }
186  }
187};
188
189// Outputs a SparseWeightVector, for debugging.
190template <class Key, class Hash>
191std::ostream& operator<<(std::ostream &stream,
192                    const SparseWeightVector<Key, Hash> &vector) {
193  typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
194  stream << "[[ ";
195  for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
196       iter != w_map.end();
197       ++iter) {
198    stream << "<" << iter->first << ", " << iter->second << "> ";
199  }
200  return stream << " ]]";
201};
202
203}  // namespace learning_stochastic_linear
204#endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
205