1/* 2 * Copyright (C) 2012 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17// Purpose: A container for sparse weight vectors 18// Maintains the sparse vector as a list of (name, value) pairs alongwith 19// a normalizer_. All operations assume that (name, value/normalizer_) is the 20// true value in question. 21 22#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 23#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 24 25#include <math.h> 26 27#include <iosfwd> 28#include <sstream> 29#include <string> 30#include <unordered_map> 31 32#include "common_defs.h" 33 34namespace learning_stochastic_linear { 35 36template<class Key = std::string, class Hash = std::unordered_map<Key, double> > 37class SparseWeightVector { 38 public: 39 typedef Hash Wmap; 40 typedef typename Wmap::iterator Witer; 41 typedef typename Wmap::const_iterator Witer_const; 42 SparseWeightVector() { 43 normalizer_ = 1.0; 44 } 45 ~SparseWeightVector() {} 46 explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) { 47 CopyFrom(other); 48 } 49 void operator=(const SparseWeightVector<Key, Hash> &other) { 50 CopyFrom(other); 51 } 52 void CopyFrom(const SparseWeightVector<Key, Hash> &other) { 53 w_ = other.w_; 54 wmin_ = other.wmin_; 55 wmax_ = other.wmax_; 56 normalizer_ = other.normalizer_; 57 } 58 59 // This function implements checks to prevent unbounded vectors. It returns 60 // true if the checks succeed and false otherwise. A vector is deemed invalid 61 // if any of these conditions are met: 62 // 1. it has no values. 63 // 2. its normalizer is nan or inf or close to zero. 64 // 3. any of its values are nan or inf. 65 // 4. its L0 norm is close to zero. 66 bool IsValid() const; 67 68 // Normalizer getters and setters. 69 double GetNormalizer() const { 70 return normalizer_; 71 } 72 void SetNormalizer(const double norm) { 73 normalizer_ = norm; 74 } 75 void NormalizerMultUpdate(const double mul) { 76 normalizer_ = normalizer_ * mul; 77 } 78 void NormalizerAddUpdate(const double add) { 79 normalizer_ += add; 80 } 81 82 // Divides all the values by the normalizer, then it resets it to 1.0 83 void ResetNormalizer(); 84 85 // Bound getters and setters. 86 // True if there is a bound with val containing the bound. false otherwise. 87 bool GetElementMinBound(const Key &fname, double *val) const { 88 return GetValue(wmin_, fname, val); 89 } 90 bool GetElementMaxBound(const Key &fname, double *val) const { 91 return GetValue(wmax_, fname, val); 92 } 93 void SetElementMinBound(const Key &fname, const double bound) { 94 wmin_[fname] = bound; 95 } 96 void SetElementMaxBound(const Key &fname, const double bound) { 97 wmax_[fname] = bound; 98 } 99 // Element getters and setters. 100 double GetElement(const Key &fname) const { 101 double val = 0; 102 GetValue(w_, fname, &val); 103 return val; 104 } 105 void SetElement(const Key &fname, const double val) { 106 //DCHECK(!isnan(val)); 107 w_[fname] = val; 108 } 109 void AddUpdateElement(const Key &fname, const double val) { 110 w_[fname] += val; 111 } 112 void MultUpdateElement(const Key &fname, const double val) { 113 w_[fname] *= val; 114 } 115 // Load another weight vectors. Will overwrite the current vector. 116 void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) { 117 w_.clear(); 118 w_.insert(vec.w_.begin(), vec.w_.end()); 119 wmax_.insert(vec.wmax_.begin(), vec.wmax_.end()); 120 wmin_.insert(vec.wmin_.begin(), vec.wmin_.end()); 121 normalizer_ = vec.normalizer_; 122 } 123 void Clear() { 124 w_.clear(); 125 wmax_.clear(); 126 wmin_.clear(); 127 } 128 const Wmap& GetMap() const { 129 return w_; 130 } 131 // Vector Operations. 132 void AdditiveWeightUpdate(const double multiplier, 133 const SparseWeightVector<Key, Hash> &w1, 134 const double additive_const); 135 void AdditiveSquaredWeightUpdate(const double multiplier, 136 const SparseWeightVector<Key, Hash> &w1, 137 const double additive_const); 138 void AdditiveInvSqrtWeightUpdate(const double multiplier, 139 const SparseWeightVector<Key, Hash> &w1, 140 const double additive_const); 141 void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1); 142 double DotProduct(const SparseWeightVector<Key, Hash> &s) const; 143 // L-x norm. eg. L1, L2. 144 double LxNorm(const double x) const; 145 double L2Norm() const; 146 double L1Norm() const; 147 double L0Norm(const double epsilon) const; 148 // Bound preserving updates. 149 void AdditiveWeightUpdateBounded(const double multiplier, 150 const SparseWeightVector<Key, Hash> &w1, 151 const double additive_const); 152 void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1); 153 void ReprojectToBounds(); 154 void ReprojectL0(const double l0_norm); 155 void ReprojectL1(const double l1_norm); 156 void ReprojectL2(const double l2_norm); 157 // Reproject using the given norm. 158 // Will also rescale regularizer_ if it gets too small/large. 159 int32 Reproject(const double norm, const RegularizationType r); 160 // Convert this vector to a string, simply for debugging. 161 std::string DebugString() const { 162 std::stringstream stream; 163 stream << *this; 164 return stream.str(); 165 } 166 private: 167 // The weight map. 168 Wmap w_; 169 // Constraint bounds. 170 Wmap wmin_; 171 Wmap wmax_; 172 // Normalizing constant in magnitude measurement. 173 double normalizer_; 174 // This function is necessary since by default unordered_map inserts an 175 // element if it does not find the key through [] operator. It implements a 176 // lookup without the space overhead of an add. 177 bool GetValue(const Wmap &w1, const Key &fname, double *val) const { 178 Witer_const iter = w1.find(fname); 179 if (iter != w1.end()) { 180 (*val) = iter->second; 181 return true; 182 } else { 183 (*val) = 0; 184 return false; 185 } 186 } 187}; 188 189// Outputs a SparseWeightVector, for debugging. 190template <class Key, class Hash> 191std::ostream& operator<<(std::ostream &stream, 192 const SparseWeightVector<Key, Hash> &vector) { 193 typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap(); 194 stream << "[[ "; 195 for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin(); 196 iter != w_map.end(); 197 ++iter) { 198 stream << "<" << iter->first << ", " << iter->second << "> "; 199 } 200 return stream << " ]]"; 201}; 202 203} // namespace learning_stochastic_linear 204#endif // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 205