// sparse_weight_vector.cpp (revision 6b4eebc73439cbc3ddfb547444a341d1f9be7996)
/*
 * Copyright (C) 2012 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */ 166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "sparse_weight_vector.h" 186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <algorithm> 206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <list> 216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <vector> 226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <math.h> 236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::vector; 256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::list; 266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::max; 276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear { 296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Max/Min permitted values of normalizer_ for preventing under/overflows. 
316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huastatic double kNormalizerMin = 1e-20; 326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huastatic double kNormalizerMax = 1e20; 336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huabool SparseWeightVector<Key, Hash>::IsValid() const { 366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (isnan(normalizer_) || __isinff(normalizer_)) 376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return false; 386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w_.begin(); 396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (isnanf(iter->second) || __isinff(iter->second)) 426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return false; 436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return true; 456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveWeightUpdate( 496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double multiplier, 506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const) { 526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w1.w_.begin(); 536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w1.w_.end(); 546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] += ((multiplier * iter->second) / w1.normalizer_ 566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua + 
additive_const) * normalizer_; 576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveSquaredWeightUpdate( 636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double multiplier, 646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const) { 666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w1.w_.begin(); 676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w1.w_.end(); 686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] += ((multiplier * iter->second * iter->second) / 706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (w1.normalizer_ * w1.normalizer_) 716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua + additive_const) * normalizer_; 726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveInvSqrtWeightUpdate( 786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double multiplier, 796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const) { 816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w1.w_.begin(); 826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w1.w_.end(); 
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if(iter->second > 0.0) { 856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] += ((multiplier * sqrt(w1.normalizer_)) / 866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (sqrt(iter->second)) 876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua + additive_const) * normalizer_; 886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveWeightUpdateBounded( 956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double multiplier, 966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1, 976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double additive_const) { 986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double min_bound = 0; 996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double max_bound = 0; 1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w1.w_.begin(); 1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w1.w_.end(); 1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] += ((multiplier * iter->second) / w1.normalizer_ 1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua + additive_const) * normalizer_; 1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound); 1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_min_bounded) { 1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((w_[iter->first] / normalizer_) < min_bound) { 
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] = min_bound*normalizer_; 1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua continue; 1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound); 1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_max_bounded) { 1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((w_[iter->first] / normalizer_) > max_bound) 1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua w_[iter->first] = max_bound*normalizer_; 1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::MultWeightUpdate( 1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1) { 1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second *= w1.GetElement(iter->first); 1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ *= w1.normalizer_; 1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::MultWeightUpdateBounded( 
1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1) { 1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double min_bound = 0; 1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double max_bound = 0; 1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ *= w1.normalizer_; 1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second *= w1.GetElement(iter->first); 1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound); 1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_min_bounded) { 1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((iter->second / normalizer_) < min_bound) { 1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = min_bound*normalizer_; 1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua continue; 1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound); 1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_max_bounded) { 1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((iter->second / normalizer_) > max_bound) 1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = max_bound*normalizer_; 1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Huavoid SparseWeightVector<Key, Hash>::ResetNormalizer() { 1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second /= normalizer_; 1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ = 1.0; 1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectToBounds() { 1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double min_bound = 0; 1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double max_bound = 0; 1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound); 1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_min_bounded) { 1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((iter->second/normalizer_) < min_bound) { 1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = min_bound*normalizer_; 1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua continue; 1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound); 1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (is_max_bounded) { 1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((iter->second/normalizer_) > max_bound) 
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = max_bound*normalizer_; 1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::DotProduct( 1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const SparseWeightVector<Key, Hash> &w1) const { 1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double result = 0; 1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (w_.size() > w1.w_.size()) { 1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w1.w_.begin(); 1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w1.w_.end(); 2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += iter->second * GetElement(iter->first); 2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result /= (this->normalizer_ * w1.normalizer_); 2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w_.begin(); 2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += iter->second * w1.GetElement(iter->first); 2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result /= (this->normalizer_ * w1.normalizer_); 2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 
2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::LxNorm(const double x) const { 2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double result = 0; 2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CHECK_GT(x, 0); 2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w_.begin(); 2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += pow(iter->second, x); 2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return (pow(result, 1.0 / x) / normalizer_); 2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L2Norm() const { 2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double result = 0; 2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w_.begin(); 2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += iter->second * iter->second; 2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return sqrt(result)/normalizer_; 2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L1Norm() const { 2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double result = 0; 2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei 
Hua for (Witer_const iter = w_.begin(); 2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua result += fabs(iter->second); 2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result / normalizer_; 2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 2506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L0Norm( 2516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double epsilon) const { 2526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double result = 0; 2536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer_const iter = w_.begin(); 2546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (fabs(iter->second / normalizer_) > epsilon) 2576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++result; 2586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return result; 2606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 2626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Algorithm for L0 projection which takes O(n log(n)), where n is 2636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// the number of non-zero elements in the vector. 
2646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 2656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL0(const double l0_norm) { 2666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// First calculates the order-statistics of the sparse vector 2676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// and then reprojects to the L0 orthant with the requested norm. 2686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CHECK_GT(l0_norm, 0); 2696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 req_l0_norm = static_cast<uint64>(l0_norm); 2706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute order statistics and the current L0 norm. 2716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua vector<double> abs_val_vec; 2726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua uint64 curr_l0_norm = 0; 2736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double epsilone = 1E-05; 2746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 2756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (fabs(iter->second/normalizer_) > epsilone) { 2786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua abs_val_vec.push_back(fabs(iter->second/normalizer_)); 2796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++curr_l0_norm; 2806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // check if a projection is necessary 2836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (curr_l0_norm < req_l0_norm) { 2846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 2856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua std::nth_element(&abs_val_vec[0], 2876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua &abs_val_vec[curr_l0_norm - 
req_l0_norm], 2886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua &abs_val_vec[curr_l0_norm]); 2896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const double theta = abs_val_vec[curr_l0_norm - req_l0_norm]; 2906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // compute the final projection. 2916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 2926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 2936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 2946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if ((fabs(iter->second/normalizer_) - theta) < 0) { 2956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = 0; 2966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 2986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 2996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 3006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Slow algorithm for accurate L1 projection which takes O(n log(n)), where n is 3016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// the number of non-zero elements in the vector. 3026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 3036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL1(const double l1_norm) { 3046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// First calculates the order-statistics of the sparse vector 3056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// applies a probability simplex projection to the abs(vector) 3066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// and reprojects back to the original with the appropriate sign. 3076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// For ref. 
see "Efficient Projections into the l1-ball for Learning 3086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// in High Dimensions" 3096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CHECK_GT(l1_norm, 0); 3106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute order statistics and the current L1 norm. 3116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua list<double> abs_val_list; 3126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double curr_l1_norm = 0; 3136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 3146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 3156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 3166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua abs_val_list.push_back(fabs(iter->second/normalizer_)); 3176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua curr_l1_norm += fabs(iter->second/normalizer_); 3186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // check if a projection is necessary 3206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (curr_l1_norm < l1_norm) { 3216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return; 3226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua abs_val_list.sort(); 3246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua abs_val_list.reverse(); 3256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Compute projection on the probability simplex. 
3266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double curr_index = 1; 3276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double theta = 0; 3286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double cum_sum = 0; 3296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (list<double>::iterator val_iter = abs_val_list.begin(); 3306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua val_iter != abs_val_list.end(); 3316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++val_iter) { 3326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua cum_sum += *val_iter; 3336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua theta = (cum_sum - l1_norm)/curr_index; 3346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (((*val_iter) - theta) <= 0) { 3356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua break; 3366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++curr_index; 3386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // compute the final projection. 
3406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua for (Witer iter = w_.begin(); 3416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter != w_.end(); 3426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ++iter) { 3436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua int sign_mul = iter->second > 0; 3446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua iter->second = max(sign_mul * normalizer_ * 3456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua (fabs(iter->second/normalizer_) - theta), 3466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 0.0); 3476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 3496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 3506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 3516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL2(const double l2_norm) { 3526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CHECK_GT(l2_norm, 0); 3536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua double curr_l2_norm = L2Norm(); 3546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // Check if a projection is necessary. 
3556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (curr_l2_norm > l2_norm) { 3566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua normalizer_ *= curr_l2_norm / l2_norm; 3576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 3596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 3606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash> 3616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint32 SparseWeightVector<Key, Hash>::Reproject(const double norm, 3626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua const RegularizationType r) { 3636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua CHECK_GT(norm, 0); 3646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (r == L0) { 3656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ReprojectL0(norm); 3666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (r == L1) { 3676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ReprojectL1(norm); 3686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else if (r == L2) { 3696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ReprojectL2(norm); 3706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } else { 3716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // This else is just to ensure that if other RegularizationTypes are 3726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // supported in the enum later which require manipulations not related 3736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // to SparseWeightVector then we catch the accidental argument here. 3746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("Unsupported regularization type requested"); 3756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return -1; 3766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // If the normalizer gets dangerously large or small, normalize the 3786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // entire vector. 
This stops projections from sending the vector 3796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // weights and the normalizer simultaneously all very small or 3806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // large, causing under/over flows. But if you hit this too often 3816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua // it's a sign you've chosen a bad lambda. 3826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (normalizer_ < kNormalizerMin) { 3836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("Resetting normalizer to 1.0 to prevent underflow. " 3846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua "Is lambda too large?"); 3856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ResetNormalizer(); 3866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua if (normalizer_ > kNormalizerMax) { 3886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ALOGE("Resetting normalizer to 1.0 to prevent overflow. " 3896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua "Is lambda too small?"); 3906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua ResetNormalizer(); 3916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua } 3926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua return 0; 3936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} 3946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua 3956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate class SparseWeightVector<std::string, std::hash_map<std::string, double> >; 3966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate class SparseWeightVector<int, std::hash_map<int, double> >; 3976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate class SparseWeightVector<uint64, std::hash_map<uint64, double> >; 3986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua} // namespace learning_stochastic_linear 399