/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Purpose: A container for sparse weight vectors.
// Maintains the sparse vector as a map of (name, value) pairs along with
// a normalizer_. All operations assume that (name, value/normalizer_) is the
// true value in question.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <math.h> 26a08525ea290ff4edc766eda1ec80388be866a79eDan Albert 27a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <iosfwd> 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <sstream> 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <string> 30a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <unordered_map> 316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "common_defs.h" 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 36a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate<class Key = std::string, class Hash = std::unordered_map<Key, double> > 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass SparseWeightVector { 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public: 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua typedef Hash Wmap; 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua typedef typename Wmap::iterator Witer; 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua typedef typename Wmap::const_iterator Witer_const; 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua SparseWeightVector() { 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = 1.0; 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ~SparseWeightVector() {} 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) { 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 
CopyFrom(other); 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void operator=(const SparseWeightVector<Key, Hash> &other) { 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CopyFrom(other); 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void CopyFrom(const SparseWeightVector<Key, Hash> &other) { 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_ = other.w_; 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmin_ = other.wmin_; 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmax_ = other.wmax_; 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = other.normalizer_; 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // This function implements checks to prevent unbounded vectors. It returns 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // true if the checks succeed and false otherwise. A vector is deemed invalid 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // if any of these conditions are met: 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 1. it has no values. 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 2. its normalizer is nan or inf or close to zero. 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 3. any of its values are nan or inf. 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // 4. its L0 norm is close to zero. 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool IsValid() const; 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Normalizer getters and setters. 
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetNormalizer() const { 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return normalizer_; 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetNormalizer(const double norm) { 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = norm; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void NormalizerMultUpdate(const double mul) { 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = normalizer_ * mul; 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void NormalizerAddUpdate(const double add) { 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ += add; 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Divides all the values by the normalizer, then it resets it to 1.0 836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void ResetNormalizer(); 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Bound getters and setters. 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // True if there is a bound with val containing the bound. false otherwise. 
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool GetElementMinBound(const Key &fname, double *val) const { 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return GetValue(wmin_, fname, val); 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool GetElementMaxBound(const Key &fname, double *val) const { 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return GetValue(wmax_, fname, val); 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetElementMinBound(const Key &fname, const double bound) { 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmin_[fname] = bound; 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetElementMaxBound(const Key &fname, const double bound) { 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmax_[fname] = bound; 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Element getters and setters. 
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double GetElement(const Key &fname) const { 1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double val = 0; 1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua GetValue(w_, fname, &val); 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return val; 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void SetElement(const Key &fname, const double val) { 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua //DCHECK(!isnan(val)); 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[fname] = val; 1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void AddUpdateElement(const Key &fname, const double val) { 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[fname] += val; 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void MultUpdateElement(const Key &fname, const double val) { 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[fname] *= val; 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Load another weight vectors. Will overwrite the current vector. 
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) { 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_.clear(); 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_.insert(vec.w_.begin(), vec.w_.end()); 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmax_.insert(vec.wmax_.begin(), vec.wmax_.end()); 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmin_.insert(vec.wmin_.begin(), vec.wmin_.end()); 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = vec.normalizer_; 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void Clear() { 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_.clear(); 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmax_.clear(); 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua wmin_.clear(); 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const Wmap& GetMap() const { 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return w_; 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Vector Operations. 
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void AdditiveWeightUpdate(const double multiplier, 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const); 1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void AdditiveSquaredWeightUpdate(const double multiplier, 1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const); 1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void AdditiveInvSqrtWeightUpdate(const double multiplier, 1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const); 1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1); 1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double DotProduct(const SparseWeightVector<Key, Hash> &s) const; 1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // L-x norm. eg. L1, L2. 1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double LxNorm(const double x) const; 1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double L2Norm() const; 1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double L1Norm() const; 1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double L0Norm(const double epsilon) const; 1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Bound preserving updates. 
1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void AdditiveWeightUpdateBounded(const double multiplier, 1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const); 1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1); 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void ReprojectToBounds(); 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void ReprojectL0(const double l0_norm); 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void ReprojectL1(const double l1_norm); 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua void ReprojectL2(const double l2_norm); 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Reproject using the given norm. 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Will also rescale regularizer_ if it gets too small/large. 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int32 Reproject(const double norm, const RegularizationType r); 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Convert this vector to a string, simply for debugging. 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua std::string DebugString() const { 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua std::stringstream stream; 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua stream << *this; 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return stream.str(); 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private: 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // The weight map. 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua Wmap w_; 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Constraint bounds. 
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua Wmap wmin_; 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua Wmap wmax_; 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Normalizing constant in magnitude measurement. 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double normalizer_; 174a08525ea290ff4edc766eda1ec80388be866a79eDan Albert // This function is necessary since by default unordered_map inserts an 175a08525ea290ff4edc766eda1ec80388be866a79eDan Albert // element if it does not find the key through [] operator. It implements a 176a08525ea290ff4edc766eda1ec80388be866a79eDan Albert // lookup without the space overhead of an add. 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool GetValue(const Wmap &w1, const Key &fname, double *val) const { 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua Witer_const iter = w1.find(fname); 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (iter != w1.end()) { 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (*val) = iter->second; 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return true; 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (*val) = 0; 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return false; 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}; 1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Outputs a SparseWeightVector, for debugging. 
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate <class Key, class Hash> 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huastd::ostream& operator<<(std::ostream &stream, 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &vector) { 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap(); 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua stream << "[[ "; 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin(); 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_map.end(); 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua stream << "<" << iter->first << ", " << iter->second << "> "; 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return stream << " ]]"; 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}; 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#endif // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_ 205