1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
// Purpose: A container for sparse weight vectors.
// Maintains the sparse vector as a map of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value / normalizer_) is the
// true value in question.
21
22#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
23#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
24
25#include <hash_map>
26#include <iosfwd>
27#include <math.h>
28#include <sstream>
29#include <string>
30
31#include "common_defs.h"
32
33namespace learning_stochastic_linear {
34
35template<class Key = std::string, class Hash = std::hash_map<Key, double> >
36class SparseWeightVector {
37 public:
38  typedef Hash Wmap;
39  typedef typename Wmap::iterator Witer;
40  typedef typename Wmap::const_iterator Witer_const;
41  SparseWeightVector() {
42    normalizer_ = 1.0;
43  }
44  ~SparseWeightVector() {}
45  explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
46    CopyFrom(other);
47  }
48  void operator=(const SparseWeightVector<Key, Hash> &other) {
49    CopyFrom(other);
50  }
51  void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
52    w_ = other.w_;
53    wmin_ = other.wmin_;
54    wmax_ = other.wmax_;
55    normalizer_ = other.normalizer_;
56  }
57
58  // This function implements checks to prevent unbounded vectors. It returns
59  // true if the checks succeed and false otherwise. A vector is deemed invalid
60  // if any of these conditions are met:
61  // 1. it has no values.
62  // 2. its normalizer is nan or inf or close to zero.
63  // 3. any of its values are nan or inf.
64  // 4. its L0 norm is close to zero.
65  bool IsValid() const;
66
67  // Normalizer getters and setters.
68  double GetNormalizer() const {
69    return normalizer_;
70  }
71  void SetNormalizer(const double norm) {
72    normalizer_ = norm;
73  }
74  void NormalizerMultUpdate(const double mul) {
75    normalizer_ = normalizer_ * mul;
76  }
77  void NormalizerAddUpdate(const double add) {
78    normalizer_ += add;
79  }
80
81  // Divides all the values by the normalizer, then it resets it to 1.0
82  void ResetNormalizer();
83
84  // Bound getters and setters.
85  // True if there is a bound with val containing the bound. false otherwise.
86  bool GetElementMinBound(const Key &fname, double *val) const {
87    return GetValue(wmin_, fname, val);
88  }
89  bool GetElementMaxBound(const Key &fname, double *val) const {
90    return GetValue(wmax_, fname, val);
91  }
92  void SetElementMinBound(const Key &fname, const double bound) {
93    wmin_[fname] = bound;
94  }
95  void SetElementMaxBound(const Key &fname, const double bound) {
96    wmax_[fname] = bound;
97  }
98  // Element getters and setters.
99  double GetElement(const Key &fname) const {
100    double val = 0;
101    GetValue(w_, fname, &val);
102    return val;
103  }
104  void SetElement(const Key &fname, const double val) {
105    //DCHECK(!isnan(val));
106    w_[fname] = val;
107  }
108  void AddUpdateElement(const Key &fname, const double val) {
109    w_[fname] += val;
110  }
111  void MultUpdateElement(const Key &fname, const double val) {
112    w_[fname] *= val;
113  }
114  // Load another weight vectors. Will overwrite the current vector.
115  void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
116    w_.clear();
117    w_.insert(vec.w_.begin(), vec.w_.end());
118    wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
119    wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
120    normalizer_ = vec.normalizer_;
121  }
122  void Clear() {
123    w_.clear();
124    wmax_.clear();
125    wmin_.clear();
126  }
127  const Wmap& GetMap() const {
128    return w_;
129  }
130  // Vector Operations.
131  void AdditiveWeightUpdate(const double multiplier,
132                            const SparseWeightVector<Key, Hash> &w1,
133                            const double additive_const);
134  void AdditiveSquaredWeightUpdate(const double multiplier,
135                                   const SparseWeightVector<Key, Hash> &w1,
136                                   const double additive_const);
137  void AdditiveInvSqrtWeightUpdate(const double multiplier,
138                                   const SparseWeightVector<Key, Hash> &w1,
139                                   const double additive_const);
140  void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
141  double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
142  // L-x norm. eg. L1, L2.
143  double LxNorm(const double x) const;
144  double L2Norm() const;
145  double L1Norm() const;
146  double L0Norm(const double epsilon) const;
147  // Bound preserving updates.
148  void AdditiveWeightUpdateBounded(const double multiplier,
149                                   const SparseWeightVector<Key, Hash> &w1,
150                                   const double additive_const);
151  void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
152  void ReprojectToBounds();
153  void ReprojectL0(const double l0_norm);
154  void ReprojectL1(const double l1_norm);
155  void ReprojectL2(const double l2_norm);
156  // Reproject using the given norm.
157  // Will also rescale regularizer_ if it gets too small/large.
158  int32 Reproject(const double norm, const RegularizationType r);
159  // Convert this vector to a string, simply for debugging.
160  std::string DebugString() const {
161    std::stringstream stream;
162    stream << *this;
163    return stream.str();
164  }
165 private:
166  // The weight map.
167  Wmap w_;
168  // Constraint bounds.
169  Wmap wmin_;
170  Wmap wmax_;
171  // Normalizing constant in magnitude measurement.
172  double normalizer_;
173  // This function in necessary since by default hash_map inserts an element
174  // if it does not find the key through [] operator. It implements a lookup
175  // without the space overhead of an add.
176  bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
177    Witer_const iter = w1.find(fname);
178    if (iter != w1.end()) {
179      (*val) = iter->second;
180      return true;
181    } else {
182      (*val) = 0;
183      return false;
184    }
185  }
186};
187
188// Outputs a SparseWeightVector, for debugging.
189template <class Key, class Hash>
190std::ostream& operator<<(std::ostream &stream,
191                    const SparseWeightVector<Key, Hash> &vector) {
192  typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
193  stream << "[[ ";
194  for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
195       iter != w_map.end();
196       ++iter) {
197    stream << "<" << iter->first << ", " << iter->second << "> ";
198  }
199  return stream << " ]]";
200};
201
202}  // namespace learning_stochastic_linear
203#endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
204