16b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua/*
26b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Copyright (C) 2012 The Android Open Source Project
36b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
46b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Licensed under the Apache License, Version 2.0 (the "License");
56b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * you may not use this file except in compliance with the License.
66b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * You may obtain a copy of the License at
76b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
86b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *      http://www.apache.org/licenses/LICENSE-2.0
96b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua *
106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * Unless required by applicable law or agreed to in writing, software
116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * distributed under the License is distributed on an "AS IS" BASIS,
126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * See the License for the specific language governing permissions and
146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua * limitations under the License.
156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua */
166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Purpose: A container for sparse weight vectors
// Maintains the sparse vector as a list of (name, value) pairs along with
196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// a normalizer_. All operations assume that (name, value/normalizer_) is the
206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// true value in question.
216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#ifndef LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#define LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <math.h>
26a08525ea290ff4edc766eda1ec80388be866a79eDan Albert
27a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <iosfwd>
286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <sstream>
296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include <string>
30a08525ea290ff4edc766eda1ec80388be866a79eDan Albert#include <unordered_map>
316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#include "common_defs.h"
336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huanamespace learning_stochastic_linear {
356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
36a08525ea290ff4edc766eda1ec80388be866a79eDan Alberttemplate<class Key = std::string, class Hash = std::unordered_map<Key, double> >
376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huaclass SparseWeightVector {
386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua public:
396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  typedef Hash Wmap;
406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  typedef typename Wmap::iterator Witer;
416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  typedef typename Wmap::const_iterator Witer_const;
426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  SparseWeightVector() {
436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ = 1.0;
446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  ~SparseWeightVector() {}
466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  explicit SparseWeightVector(const SparseWeightVector<Key, Hash> &other) {
476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    CopyFrom(other);
486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void operator=(const SparseWeightVector<Key, Hash> &other) {
506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    CopyFrom(other);
516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void CopyFrom(const SparseWeightVector<Key, Hash> &other) {
536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_ = other.w_;
546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmin_ = other.wmin_;
556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmax_ = other.wmax_;
566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ = other.normalizer_;
576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // This function implements checks to prevent unbounded vectors. It returns
606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // true if the checks succeed and false otherwise. A vector is deemed invalid
616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // if any of these conditions are met:
626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 1. it has no values.
636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 2. its normalizer is nan or inf or close to zero.
646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 3. any of its values are nan or inf.
656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // 4. its L0 norm is close to zero.
666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  bool IsValid() const;
676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Normalizer getters and setters.
696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetNormalizer() const {
706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return normalizer_;
716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetNormalizer(const double norm) {
736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ = norm;
746b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
756b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void NormalizerMultUpdate(const double mul) {
766b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ = normalizer_ * mul;
776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void NormalizerAddUpdate(const double add) {
796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ += add;
806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Divides all the values by the normalizer, then it resets it to 1.0
836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void ResetNormalizer();
846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Bound getters and setters.
866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // True if there is a bound with val containing the bound. false otherwise.
876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  bool GetElementMinBound(const Key &fname, double *val) const {
886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return GetValue(wmin_, fname, val);
896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  bool GetElementMaxBound(const Key &fname, double *val) const {
916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return GetValue(wmax_, fname, val);
926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetElementMinBound(const Key &fname, const double bound) {
946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmin_[fname] = bound;
956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetElementMaxBound(const Key &fname, const double bound) {
976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmax_[fname] = bound;
986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Element getters and setters.
1006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double GetElement(const Key &fname) const {
1016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    double val = 0;
1026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    GetValue(w_, fname, &val);
1036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return val;
1046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1056b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void SetElement(const Key &fname, const double val) {
1066b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    //DCHECK(!isnan(val));
1076b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[fname] = val;
1086b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1096b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void AddUpdateElement(const Key &fname, const double val) {
1106b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[fname] += val;
1116b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1126b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void MultUpdateElement(const Key &fname, const double val) {
1136b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_[fname] *= val;
1146b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1156b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Load another weight vectors. Will overwrite the current vector.
1166b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void LoadWeightVector(const SparseWeightVector<Key, Hash> &vec) {
1176b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_.clear();
1186b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_.insert(vec.w_.begin(), vec.w_.end());
1196b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmax_.insert(vec.wmax_.begin(), vec.wmax_.end());
1206b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmin_.insert(vec.wmin_.begin(), vec.wmin_.end());
1216b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    normalizer_ = vec.normalizer_;
1226b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1236b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void Clear() {
1246b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    w_.clear();
1256b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmax_.clear();
1266b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    wmin_.clear();
1276b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1286b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  const Wmap& GetMap() const {
1296b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return w_;
1306b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1316b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Vector Operations.
1326b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void AdditiveWeightUpdate(const double multiplier,
1336b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                            const SparseWeightVector<Key, Hash> &w1,
1346b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                            const double additive_const);
1356b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void AdditiveSquaredWeightUpdate(const double multiplier,
1366b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const SparseWeightVector<Key, Hash> &w1,
1376b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const double additive_const);
1386b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void AdditiveInvSqrtWeightUpdate(const double multiplier,
1396b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const SparseWeightVector<Key, Hash> &w1,
1406b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const double additive_const);
1416b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void MultWeightUpdate(const SparseWeightVector<Key, Hash> &w1);
1426b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double DotProduct(const SparseWeightVector<Key, Hash> &s) const;
1436b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // L-x norm. eg. L1, L2.
1446b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double LxNorm(const double x) const;
1456b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double L2Norm() const;
1466b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double L1Norm() const;
1476b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double L0Norm(const double epsilon) const;
1486b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Bound preserving updates.
1496b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void AdditiveWeightUpdateBounded(const double multiplier,
1506b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const SparseWeightVector<Key, Hash> &w1,
1516b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                                   const double additive_const);
1526b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void MultWeightUpdateBounded(const SparseWeightVector<Key, Hash> &w1);
1536b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void ReprojectToBounds();
1546b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void ReprojectL0(const double l0_norm);
1556b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void ReprojectL1(const double l1_norm);
1566b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  void ReprojectL2(const double l2_norm);
1576b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Reproject using the given norm.
1586b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Will also rescale regularizer_ if it gets too small/large.
1596b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  int32 Reproject(const double norm, const RegularizationType r);
1606b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Convert this vector to a string, simply for debugging.
1616b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  std::string DebugString() const {
1626b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    std::stringstream stream;
1636b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    stream << *this;
1646b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    return stream.str();
1656b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1666b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua private:
1676b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // The weight map.
1686b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  Wmap w_;
1696b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Constraint bounds.
1706b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  Wmap wmin_;
1716b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  Wmap wmax_;
1726b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  // Normalizing constant in magnitude measurement.
1736b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  double normalizer_;
174a08525ea290ff4edc766eda1ec80388be866a79eDan Albert  // This function is necessary since by default unordered_map inserts an
175a08525ea290ff4edc766eda1ec80388be866a79eDan Albert  // element if it does not find the key through [] operator. It implements a
176a08525ea290ff4edc766eda1ec80388be866a79eDan Albert  // lookup without the space overhead of an add.
1776b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  bool GetValue(const Wmap &w1, const Key &fname, double *val) const {
1786b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    Witer_const iter = w1.find(fname);
1796b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    if (iter != w1.end()) {
1806b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      (*val) = iter->second;
1816b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return true;
1826b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    } else {
1836b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      (*val) = 0;
1846b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua      return false;
1856b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    }
1866b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
1876b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua};
1886b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
1896b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua// Outputs a SparseWeightVector, for debugging.
1906b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huatemplate <class Key, class Hash>
1916b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Huastd::ostream& operator<<(std::ostream &stream,
1926b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua                    const SparseWeightVector<Key, Hash> &vector) {
1936b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  typename SparseWeightVector<Key, Hash>::Wmap w_map = vector.GetMap();
1946b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  stream << "[[ ";
1956b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  for (typename SparseWeightVector<Key, Hash>::Witer_const iter = w_map.begin();
1966b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       iter != w_map.end();
1976b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua       ++iter) {
1986b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua    stream << "<" << iter->first << ", " << iter->second << "> ";
1996b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  }
2006b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua  return stream << " ]]";
2016b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua};
2026b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua
2036b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua}  // namespace learning_stochastic_linear
2046b4eebc73439cbc3ddfb547444a341d1f9be7996Wei Hua#endif  // LEARNING_STOCHASTIC_LINEAR_SPARSE_WEIGHT_VECTOR_H_
205