1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFloweryou may not use this file except in compliance with the License.
594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerYou may obtain a copy of the License at
694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
1094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
1194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerSee the License for the specific language governing permissions and
1394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerlimitations under the License.
1494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower==============================================================================*/
1594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
165e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// See docs in ../ops/sdca_ops.cc.
1794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
1894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#define EIGEN_USE_THREADS
19db828ab20399697bb97c218ca6c435ee59b2a029A. Unique TensorFlower
207abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <stdint.h>
2194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include <atomic>
22270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include <limits>
237abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <memory>
247abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <new>
2594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include <string>
267abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <vector>
2794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
2894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
29270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/framework/device_base.h"
3094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/kernel_def_builder.h"
3194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op.h"
3294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op_def_builder.h"
3394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h"
3494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
3594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h"
3694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor_types.h"
37270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/framework/types.h"
38eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/hinge-loss.h"
39eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/logistic-loss.h"
407abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/kernels/loss.h"
417abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/kernels/sdca_internal.h"
42eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/smooth-hinge-loss.h"
43eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/squared-loss.h"
44e7ff84394d7ebf57fd52f5e7b12075c3197cb6e3A. Unique TensorFlower#include "tensorflow/core/lib/core/coding.h"
4594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h"
46270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/core/status.h"
47270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/core/stringpiece.h"
48270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/gtl/inlined_vector.h"
49270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/strings/stringprintf.h"
50370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower#include "tensorflow/core/platform/fingerprint.h"
517abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/platform/macros.h"
5294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/platform/mutex.h"
53270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/platform/types.h"
5494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/util/work_sharder.h"
5594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
5694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowernamespace tensorflow {
576e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower
5894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowernamespace {
5994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
607abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::Example;
617abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::Examples;
627abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::ExampleStatistics;
637abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::ModelWeights;
64982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing sdca::Regularizations;
65c2a7ea886e4ef49a870481007c209d32f317cd14A. Unique TensorFlower
661cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerstruct ComputeOptions {
676882effb863dcd0da00d3287959deac46734a0b2A. Unique TensorFlower  explicit ComputeOptions(OpKernelConstruction* const context) {
6894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    string loss_type;
69c2a7ea886e4ef49a870481007c209d32f317cd14A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
7094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    if (loss_type == "logistic_loss") {
711cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      loss_updater.reset(new LogisticLossUpdater);
722ababd9c11a60ba4293aa0a3bffb4d68fc9aa89bA. Unique TensorFlower    } else if (loss_type == "squared_loss") {
731cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      loss_updater.reset(new SquaredLossUpdater);
74cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower    } else if (loss_type == "hinge_loss") {
751cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      loss_updater.reset(new HingeLossUpdater);
76bc85766000089aef714c372f255a28691fd0df45A. Unique TensorFlower    } else if (loss_type == "smooth_hinge_loss") {
771cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      loss_updater.reset(new SmoothHingeLossUpdater);
782ababd9c11a60ba4293aa0a3bffb4d68fc9aa89bA. Unique TensorFlower    } else {
79982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen      OP_REQUIRES(
80982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          context, false,
81982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen          errors::InvalidArgument("Unsupported loss type: ", loss_type));
8294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    }
831cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative));
8494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    OP_REQUIRES_OK(
851cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        context, context->GetAttr("num_sparse_features", &num_sparse_features));
861cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
871cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                                             &num_sparse_features_with_values));
881cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    OP_REQUIRES_OK(context,
891cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                   context->GetAttr("num_dense_features", &num_dense_features));
9094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower    OP_REQUIRES(
911cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        context, num_sparse_features + num_dense_features > 0,
9294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower        errors::InvalidArgument("Requires at least one feature to train."));
93caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower
94982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen    OP_REQUIRES(context,
95982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                static_cast<int64>(num_sparse_features) +
96982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                        static_cast<int64>(num_dense_features) <=
97982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen                    std::numeric_limits<int>::max(),
98caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower                errors::InvalidArgument(
99caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower                    strings::Printf("Too many feature groups: %lld > %d",
1001cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                                    static_cast<int64>(num_sparse_features) +
1011cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                                        static_cast<int64>(num_dense_features),
102caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower                                    std::numeric_limits<int>::max())));
1031cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    OP_REQUIRES_OK(
1041cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
1059ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower    OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
1061cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                                             &num_inner_iterations));
1071cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    OP_REQUIRES_OK(context, regularizations.Initialize(context));
10894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower  }
10994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
1101cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  std::unique_ptr<DualLossUpdater> loss_updater;
1111cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  int num_sparse_features = 0;
1121cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  int num_sparse_features_with_values = 0;
1131cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  int num_dense_features = 0;
1141cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  int num_inner_iterations = 0;
1151cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  int num_loss_partitions = 0;
1161cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  bool adaptative = false;
1171cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  Regularizations regularizations;
1181cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower};
119caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower
1205e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// TODO(shengx): The helper classes/methods are changed to support multiclass
1215e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// SDCA, which lead to changes within this function. Need to revisit the
1225e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// convergence once the multiclass SDCA is in.
1231cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowervoid DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
1241cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  ModelWeights model_weights;
1251cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  OP_REQUIRES_OK(context, model_weights.Initialize(context));
1261cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1271cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  Examples examples;
1281cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  OP_REQUIRES_OK(
1291cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      context,
1301cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      examples.Initialize(context, model_weights, options.num_sparse_features,
1311cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                          options.num_sparse_features_with_values,
1321cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                          options.num_dense_features));
1331cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1341cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  const Tensor* example_state_data_t;
1351cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  OP_REQUIRES_OK(context,
1361cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                 context->input("example_state_data", &example_state_data_t));
1371cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  TensorShape expected_example_state_shape({examples.num_examples(), 4});
1381cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  OP_REQUIRES(context,
1391cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower              example_state_data_t->shape() == expected_example_state_shape,
1401cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower              errors::InvalidArgument(
1411cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                  "Expected shape ", expected_example_state_shape.DebugString(),
1421cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                  " for example_state_data, got ",
1431cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                  example_state_data_t->shape().DebugString()));
1441cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1451cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  Tensor mutable_example_state_data_t(*example_state_data_t);
1461cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  auto example_state_data = mutable_example_state_data_t.matrix<float>();
147bc225bfaa534acc25047fe844f19edc333b7a76aPeter Hawkins  OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
148bc225bfaa534acc25047fe844f19edc333b7a76aPeter Hawkins                                              mutable_example_state_data_t));
1491cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1501cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  if (options.adaptative) {
1517abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower    OP_REQUIRES_OK(context,
1527abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower                   examples.SampleAdaptativeProbabilities(
1537abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower                       options.num_loss_partitions, options.regularizations,
1547abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower                       model_weights, example_state_data, options.loss_updater,
1557abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower                       /*num_weight_vectors =*/1));
1561cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  }
157db828ab20399697bb97c218ca6c435ee59b2a029A. Unique TensorFlower
1581cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  mutex mu;
1591cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  Status train_step_status GUARDED_BY(mu);
1601cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  std::atomic<std::int64_t> atomic_index(-1);
1611cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  auto train_step = [&](const int64 begin, const int64 end) {
1621cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    // The static_cast here is safe since begin and end can be at most
1631cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    // num_examples which is an int.
1641cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    for (int id = static_cast<int>(begin); id < end; ++id) {
1651cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const int64 example_index =
1661cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          examples.sampled_index(++atomic_index, options.adaptative);
1671cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const Example& example = examples.example(example_index);
1681cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const float dual = example_state_data(example_index, 0);
1691cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const float example_weight = example.example_weight();
1701cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      float example_label = example.example_label();
1711cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const Status conversion_status =
1721cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          options.loss_updater->ConvertLabel(&example_label);
1731cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      if (!conversion_status.ok()) {
1741cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        mutex_lock l(mu);
1751cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        train_step_status = conversion_status;
1761cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        // Return from this worker thread - the calling thread is
1771cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        // responsible for checking context status and returning on error.
1781cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        return;
1791cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      }
1801cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1811cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      // Compute wx, example norm weighted by regularization, dual loss,
1821cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      // primal loss.
1835e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower      // For binary SDCA, num_weight_vectors should be one.
1841cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const ExampleStatistics example_statistics =
1855e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower          example.ComputeWxAndWeightedExampleNorm(
1865e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower              options.num_loss_partitions, model_weights,
1875e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower              options.regularizations, 1 /* num_weight_vectors */);
1881cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1891cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const double new_dual = options.loss_updater->ComputeUpdatedDual(
1901cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          options.num_loss_partitions, example_label, example_weight, dual,
1915e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower          example_statistics.wx[0], example_statistics.normalized_squared_norm);
1921cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
1931cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      // Compute new weights.
1941cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      const double normalized_bounded_dual_delta =
1951cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          (new_dual - dual) * example_weight /
1961cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          options.regularizations.symmetric_l2();
1975e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower      model_weights.UpdateDeltaWeights(
1985e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower          context->eigen_cpu_device(), example,
1995e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower          std::vector<double>{normalized_bounded_dual_delta});
2001cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
2011cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      // Update example data.
2021cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      example_state_data(example_index, 0) = new_dual;
2031cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      example_state_data(example_index, 1) =
2041cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          options.loss_updater->ComputePrimalLoss(
2055e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower              example_statistics.prev_wx[0], example_label, example_weight);
2061cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      example_state_data(example_index, 2) =
2071cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower          options.loss_updater->ComputeDualLoss(dual, example_label,
2081cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                                                example_weight);
2091cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      example_state_data(example_index, 3) = example_weight;
210f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower    }
2111cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  };
2121cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
2131cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  // number of cpus, and cost per example.
2141cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  const int64 kCostPerUnit = examples.num_features();
2151cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  const DeviceBase::CpuWorkerThreads& worker_threads =
2161cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      *context->device()->tensorflow_cpu_worker_threads();
217f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower
2181cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  Shard(worker_threads.num_threads, worker_threads.workers,
2191cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower        examples.num_examples(), kCostPerUnit, train_step);
2201cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  OP_REQUIRES_OK(context, train_step_status);
2211cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower}
22294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
2231cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower}  // namespace
2241cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
2251cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerclass SdcaOptimizer : public OpKernel {
2261cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower public:
2271cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  explicit SdcaOptimizer(OpKernelConstruction* const context)
2281cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower      : OpKernel(context), options_(context) {}
2291cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower
230db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner  void Compute(OpKernelContext* context) override {
2311cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower    DoCompute(options_, context);
23294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower  }
23394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
23494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower private:
235733321f6159f55474f4d7822b5a1c174d30e696dYuan Yu  // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
2369ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  // template the entire class to avoid the virtual table lookup penalty in
2379ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  // the inner loop.
2381cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower  ComputeOptions options_;
23994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower};
2401cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
2411cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower                        SdcaOptimizer);
24294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower
2439ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerclass SdcaShrinkL1 : public OpKernel {
2449ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public:
2456a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower  explicit SdcaShrinkL1(OpKernelConstruction* const context)
2466a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower      : OpKernel(context) {
247ad2122584430ca4e4cc97fa642c1536357402eb6A. Unique TensorFlower    OP_REQUIRES_OK(context, regularizations_.Initialize(context));
2489ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  }
2499ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
250db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner  void Compute(OpKernelContext* context) override {
251c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower    OpMutableInputList weights_inputs;
252c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower    OP_REQUIRES_OK(context,
253c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower                   context->mutable_input_list("weights", &weights_inputs));
254c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower
255c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower    auto do_work = [&](const int64 begin, const int64 end) {
256c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      for (int i = begin; i < end; ++i) {
257c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower        auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
2586e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower        prox_w.device(context->eigen_cpu_device()) =
2595e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower            regularizations_.EigenShrinkVector(prox_w);
2606e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower      }
2616e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower    };
2626d6ba9e3d88cef87c4e0b2753ffe40d939901fa4A. Unique TensorFlower
263c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower    if (weights_inputs.size() > 0) {
264c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      int64 num_weights = 0;
265c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      for (int i = 0; i < weights_inputs.size(); ++i) {
266c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower        num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
267c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      }
268c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      // TODO(sibyl-Aix6ihai): Tune this value.
269c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
270c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      const DeviceBase::CpuWorkerThreads& worker_threads =
271c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower          *context->device()->tensorflow_cpu_worker_threads();
272c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower      Shard(worker_threads.num_threads, worker_threads.workers,
273c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower            weights_inputs.size(), kCostPerUnit, do_work);
274c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower    }
2759ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  }
2769ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
2779ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower private:
2789ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  Regularizations regularizations_;
2799ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower};
2809ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
2819ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
28201fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// Computes platform independent, compact and unique (with very high
28301fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// probability) representation of an example id. It shouldn't be put in
28401fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// persistent storage, as its implementation may change in the future.
28501fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower//
28601fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// The current probability of at least one collision for 1B example_ids is
28787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower// approximately 10^-21 (ie 2^60 / 2^129).
288370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlowerclass SdcaFprint : public OpKernel {
2899ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public:
2906a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower  explicit SdcaFprint(OpKernelConstruction* const context)
2916a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower      : OpKernel(context) {}
2929ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
293db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner  void Compute(OpKernelContext* context) override {
2946a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower    const Tensor& input = context->input(0);
29587a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower    OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
29687a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower                errors::InvalidArgument("Input must be a vector, got shape ",
29787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower                                        input.shape().DebugString()));
298370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower    Tensor* out;
29987a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower    const int64 num_elements = input.NumElements();
30087a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower    OP_REQUIRES_OK(context, context->allocate_output(
30187a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower                                0, TensorShape({num_elements, 2}), &out));
3029ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower
303370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower    const auto in_values = input.flat<string>();
30487a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower    auto out_values = out->matrix<int64>();
30587a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower
30687a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower    for (int64 i = 0; i < num_elements; ++i) {
30787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower      const Fprint128 fprint = Fingerprint128(in_values(i));
30887a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower      // Never return 0 or 1 as the first value of the hash to allow these to
30987a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower      // safely be used as sentinel values (e.g. dense hash table empty key).
31087a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower      out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
31187a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower                             ? fprint.low64
31287a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower                             : fprint.low64 + ~static_cast<uint64>(1);
31387a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower      out_values(i, 1) = fprint.high64;
31456f1d64998744ad655fe5c428658a13be35b865eEugene Brevdo    }
3159ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower  }
3169ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower};
317370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
31856f1d64998744ad655fe5c428658a13be35b865eEugene Brevdo
31994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower}  // namespace tensorflow
320