sparse_weight_vector.cpp revision a08525ea290ff4edc766eda1ec80388be866a79e
16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
#include "sparse_weight_vector.h"

#include <math.h>

#include <algorithm>
#include <cmath>
#include <list>
#include <vector>
236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::vector;
256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::list;
266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huausing std::max;
276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear {
296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
// Max/Min permitted values of normalizer_ for preventing under/overflows.
// Never modified at runtime, so declared const (internal linkage kept).
static const double kNormalizerMin = 1e-20;
static const double kNormalizerMax = 1e20;
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huabool SparseWeightVector<Key, Hash>::IsValid() const {
366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (isnan(normalizer_) || __isinff(normalizer_))
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return false;
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w_.begin();
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       iter != w_.end();
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       ++iter) {
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (isnanf(iter->second) || __isinff(iter->second))
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return false;
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return true;
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveWeightUpdate(
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double multiplier,
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1,
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double additive_const) {
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w1.w_.begin();
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w1.w_.end();
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[iter->first] += ((multiplier * iter->second) / w1.normalizer_
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                        + additive_const) * normalizer_;
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveSquaredWeightUpdate(
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double multiplier,
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1,
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double additive_const) {
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w1.w_.begin();
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w1.w_.end();
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[iter->first] += ((multiplier * iter->second * iter->second) /
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          (w1.normalizer_ * w1.normalizer_)
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                        + additive_const) * normalizer_;
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveInvSqrtWeightUpdate(
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double multiplier,
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1,
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double additive_const) {
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w1.w_.begin();
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w1.w_.end();
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if(iter->second > 0.0) {
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      w_[iter->first] += ((multiplier * sqrt(w1.normalizer_)) /
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          (sqrt(iter->second))
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                          + additive_const) * normalizer_;
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::AdditiveWeightUpdateBounded(
956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double multiplier,
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1,
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double additive_const) {
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double min_bound = 0;
996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double max_bound = 0;
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w1.w_.begin();
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w1.w_.end();
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[iter->first] += ((multiplier * iter->second) / w1.normalizer_
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                        + additive_const) * normalizer_;
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound);
1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_min_bounded) {
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((w_[iter->first] / normalizer_) < min_bound) {
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        w_[iter->first] = min_bound*normalizer_;
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        continue;
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      }
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound);
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_max_bounded) {
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((w_[iter->first] / normalizer_) > max_bound)
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        w_[iter->first] = max_bound*normalizer_;
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::MultWeightUpdate(
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1) {
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iter->second *= w1.GetElement(iter->first);
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  normalizer_ *= w1.normalizer_;
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::MultWeightUpdateBounded(
1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1) {
1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double min_bound = 0;
1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double max_bound = 0;
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  normalizer_ *= w1.normalizer_;
1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iter->second *= w1.GetElement(iter->first);
1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound);
1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_min_bounded) {
1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((iter->second / normalizer_) < min_bound) {
1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter->second = min_bound*normalizer_;
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        continue;
1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      }
1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound);
1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_max_bounded) {
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((iter->second / normalizer_) > max_bound)
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter->second = max_bound*normalizer_;
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return;
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ResetNormalizer() {
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       iter != w_.end();
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       ++iter) {
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iter->second /= normalizer_;
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  normalizer_ = 1.0;
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectToBounds() {
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double min_bound = 0;
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double max_bound = 0;
1746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
1766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       iter != w_.end();
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       ++iter) {
1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_min_bounded = GetValue(wmin_, iter->first, &min_bound);
1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_min_bounded) {
1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((iter->second/normalizer_) < min_bound) {
1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter->second = min_bound*normalizer_;
1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        continue;
1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      }
1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    bool is_max_bounded = GetValue(wmax_, iter->first, &max_bound);
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (is_max_bounded) {
1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      if ((iter->second/normalizer_) > max_bound)
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter->second = max_bound*normalizer_;
1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::DotProduct(
1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const SparseWeightVector<Key, Hash> &w1) const {
1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double result = 0;
1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (w_.size() > w1.w_.size()) {
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (Witer_const iter = w1.w_.begin();
1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter != w1.w_.end();
2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        ++iter) {
2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      result += iter->second * GetElement(iter->first);
2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result /= (this->normalizer_ * w1.normalizer_);
2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } else {
2056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    for (Witer_const iter = w_.begin();
2066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        iter != w_.end();
2076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua        ++iter) {
2086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      result += iter->second * w1.GetElement(iter->first);
2096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result /= (this->normalizer_ * w1.normalizer_);
2116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
2136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
2166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::LxNorm(const double x) const {
2176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double result = 0;
2186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  CHECK_GT(x, 0);
2196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w_.begin();
2206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += pow(iter->second, x);
2236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return (pow(result, 1.0 / x) / normalizer_);
2256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
2286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L2Norm() const {
2296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double result = 0;
2306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w_.begin();
2316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += iter->second * iter->second;
2346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return sqrt(result)/normalizer_;
2366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
2396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L1Norm() const {
2406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double result = 0;
2416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w_.begin();
2426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    result += fabs(iter->second);
2456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result / normalizer_;
2476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
2506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huadouble SparseWeightVector<Key, Hash>::L0Norm(
2516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    const double epsilon) const {
2526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double result = 0;
2536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer_const iter = w_.begin();
2546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (fabs(iter->second / normalizer_) > epsilon)
2576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++result;
2586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return result;
2606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Algorithm for L0 projection which takes O(n log(n)), where n is
2636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// the number of non-zero elements in the vector.
2646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
2656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL0(const double l0_norm) {
2666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// First calculates the order-statistics of the sparse vector
2676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// and then reprojects to the L0 orthant with the requested norm.
2686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  CHECK_GT(l0_norm, 0);
2696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 req_l0_norm = static_cast<uint64>(l0_norm);
2706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Compute order statistics and the current L0 norm.
2716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  vector<double> abs_val_vec;
2726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  uint64 curr_l0_norm = 0;
2736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  const double epsilone = 1E-05;
2746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
2756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (fabs(iter->second/normalizer_) > epsilone) {
2786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      abs_val_vec.push_back(fabs(iter->second/normalizer_));
2796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++curr_l0_norm;
2806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // check if a projection is necessary
2836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (curr_l0_norm < req_l0_norm) {
2846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return;
2856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  std::nth_element(&abs_val_vec[0],
2876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua              &abs_val_vec[curr_l0_norm - req_l0_norm],
2886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua              &abs_val_vec[curr_l0_norm]);
2896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  const double theta = abs_val_vec[curr_l0_norm - req_l0_norm];
2906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // compute the final projection.
2916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
2926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
2936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
2946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if ((fabs(iter->second/normalizer_) - theta) < 0) {
2956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter->second = 0;
2966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
2976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
2996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
3006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Slow algorithm for accurate L1 projection which takes O(n log(n)), where n is
3016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// the number of non-zero elements in the vector.
3026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
3036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL1(const double l1_norm) {
3046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// First calculates the order-statistics of the sparse vector
3056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// applies a probability simplex projection to the abs(vector)
3066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// and reprojects back to the original with the appropriate sign.
3076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// For ref. see "Efficient Projections into the l1-ball for Learning
3086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// in High Dimensions"
3096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  CHECK_GT(l1_norm, 0);
3106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Compute order statistics and the current L1 norm.
3116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  list<double> abs_val_list;
3126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double curr_l1_norm = 0;
3136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
3146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
3156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
3166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    abs_val_list.push_back(fabs(iter->second/normalizer_));
3176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    curr_l1_norm += fabs(iter->second/normalizer_);
3186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // check if a projection is necessary
3206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (curr_l1_norm < l1_norm) {
3216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return;
3226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  abs_val_list.sort();
3246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  abs_val_list.reverse();
3256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Compute projection on the probability simplex.
3266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double curr_index = 1;
3276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double theta = 0;
3286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double cum_sum = 0;
3296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (list<double>::iterator val_iter = abs_val_list.begin();
3306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       val_iter != abs_val_list.end();
3316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       ++val_iter) {
3326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    cum_sum += *val_iter;
3336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    theta = (cum_sum - l1_norm)/curr_index;
3346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (((*val_iter) - theta) <= 0) {
3356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      break;
3366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
3376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ++curr_index;
3386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // compute the final projection.
3406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (Witer iter = w_.begin();
3416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      iter != w_.end();
3426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      ++iter) {
3436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    int sign_mul = iter->second > 0;
3446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    iter->second = max(sign_mul * normalizer_ *
3456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                           (fabs(iter->second/normalizer_) - theta),
3466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                       0.0);
3476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
3496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
3506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
3516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huavoid SparseWeightVector<Key, Hash>::ReprojectL2(const double l2_norm) {
3526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  CHECK_GT(l2_norm, 0);
3536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double curr_l2_norm = L2Norm();
3546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Check if a projection is necessary.
3556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (curr_l2_norm > l2_norm) {
3566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ *= curr_l2_norm / l2_norm;
3576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
3596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
3606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate<class Key, class Hash>
3616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaint32 SparseWeightVector<Key, Hash>::Reproject(const double norm,
3626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                               const RegularizationType r) {
3636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  CHECK_GT(norm, 0);
3646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (r == L0) {
3656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ReprojectL0(norm);
3666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } else if (r == L1) {
3676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ReprojectL1(norm);
3686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } else if (r == L2) {
3696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ReprojectL2(norm);
3706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  } else {
3716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // This else is just to ensure that if other RegularizationTypes are
3726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // supported in the enum later which require manipulations not related
3736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    // to SparseWeightVector then we catch the accidental argument here.
3746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ALOGE("Unsupported regularization type requested");
3756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return -1;
3766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // If the normalizer gets dangerously large or small, normalize the
3786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // entire vector. This stops projections from sending the vector
3796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // weights and the normalizer simultaneously all very small or
3806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // large, causing under/over flows. But if you hit this too often
3816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // it's a sign you've chosen a bad lambda.
3826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (normalizer_ < kNormalizerMin) {
3836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ALOGE("Resetting normalizer to 1.0 to prevent underflow. "
3846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          "Is lambda too large?");
3856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ResetNormalizer();
3866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  if (normalizer_ > kNormalizerMax) {
3886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ALOGE("Resetting normalizer to 1.0 to prevent overflow. "
3896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua          "Is lambda too small?");
3906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    ResetNormalizer();
3916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
3926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return 0;
3936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}
3946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
// Explicit template instantiations for the key types used by the learning
// framework: string, int and uint64 keys over std::unordered_map.
template class SparseWeightVector<std::string, std::unordered_map<std::string, double> >;
template class SparseWeightVector<int, std::unordered_map<int, double> >;
template class SparseWeightVector<uint64, std::unordered_map<uint64, double> >;
3986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}  // namespace learning_stochastic_linear
399