1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFloweryou may not use this file except in compliance with the License. 594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerYou may obtain a copy of the License at 694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 1094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 1194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 1294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerSee the License for the specific language governing permissions and 1394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowerlimitations under the License. 1494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower==============================================================================*/ 1594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 165e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// See docs in ../ops/sdca_ops.cc. 1794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 1894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#define EIGEN_USE_THREADS 19db828ab20399697bb97c218ca6c435ee59b2a029A. Unique TensorFlower 207abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <stdint.h> 2194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include <atomic> 22270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include <limits> 237abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <memory> 247abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <new> 2594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include <string> 267abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include <vector> 2794a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 2894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 29270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/framework/device_base.h" 3094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/kernel_def_builder.h" 3194a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op.h" 3294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op_def_builder.h" 3394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/op_kernel.h" 3494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 3594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor_shape.h" 3694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/framework/tensor_types.h" 37270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/framework/types.h" 38eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/hinge-loss.h" 39eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/logistic-loss.h" 407abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/kernels/loss.h" 417abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/kernels/sdca_internal.h" 42eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/smooth-hinge-loss.h" 43eb4421e17157a80ae97f258a2c935e0ae00f1ed9A. Unique TensorFlower#include "tensorflow/core/kernels/squared-loss.h" 44e7ff84394d7ebf57fd52f5e7b12075c3197cb6e3A. Unique TensorFlower#include "tensorflow/core/lib/core/coding.h" 4594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/lib/core/errors.h" 46270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/core/status.h" 47270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/core/stringpiece.h" 48270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/gtl/inlined_vector.h" 49270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/lib/strings/stringprintf.h" 50370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower#include "tensorflow/core/platform/fingerprint.h" 517abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower#include "tensorflow/core/platform/macros.h" 5294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/platform/mutex.h" 53270b010b6d51ada90cbd9f7790b5dab71ce58bf7A. Unique TensorFlower#include "tensorflow/core/platform/types.h" 5494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower#include "tensorflow/core/util/work_sharder.h" 5594a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 5694a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowernamespace tensorflow { 576e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower 5894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlowernamespace { 5994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 607abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::Example; 617abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::Examples; 627abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::ExampleStatistics; 637abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlowerusing sdca::ModelWeights; 64982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsenusing sdca::Regularizations; 65c2a7ea886e4ef49a870481007c209d32f317cd14A. Unique TensorFlower 661cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerstruct ComputeOptions { 676882effb863dcd0da00d3287959deac46734a0b2A. Unique TensorFlower explicit ComputeOptions(OpKernelConstruction* const context) { 6894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower string loss_type; 69c2a7ea886e4ef49a870481007c209d32f317cd14A. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type)); 7094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower if (loss_type == "logistic_loss") { 711cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower loss_updater.reset(new LogisticLossUpdater); 722ababd9c11a60ba4293aa0a3bffb4d68fc9aa89bA. Unique TensorFlower } else if (loss_type == "squared_loss") { 731cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower loss_updater.reset(new SquaredLossUpdater); 74cbcaf4e85afb0f9eacbb5d2aadb12407786e0e2dA. Unique TensorFlower } else if (loss_type == "hinge_loss") { 751cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower loss_updater.reset(new HingeLossUpdater); 76bc85766000089aef714c372f255a28691fd0df45A. Unique TensorFlower } else if (loss_type == "smooth_hinge_loss") { 771cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower loss_updater.reset(new SmoothHingeLossUpdater); 782ababd9c11a60ba4293aa0a3bffb4d68fc9aa89bA. Unique TensorFlower } else { 79982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES( 80982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen context, false, 81982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen errors::InvalidArgument("Unsupported loss type: ", loss_type)); 8294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower } 831cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative)); 8494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower OP_REQUIRES_OK( 851cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context, context->GetAttr("num_sparse_features", &num_sparse_features)); 861cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values", 871cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower &num_sparse_features_with_values)); 881cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, 891cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context->GetAttr("num_dense_features", &num_dense_features)); 9094a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower OP_REQUIRES( 911cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context, num_sparse_features + num_dense_features > 0, 9294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower errors::InvalidArgument("Requires at least one feature to train.")); 93caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower 94982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen OP_REQUIRES(context, 95982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int64>(num_sparse_features) + 96982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen static_cast<int64>(num_dense_features) <= 97982549ea3423df4270ff154e5c764beb43d472daRasmus Munk Larsen std::numeric_limits<int>::max(), 98caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower errors::InvalidArgument( 99caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower strings::Printf("Too many feature groups: %lld > %d", 1001cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower static_cast<int64>(num_sparse_features) + 1011cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower static_cast<int64>(num_dense_features), 102caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower std::numeric_limits<int>::max()))); 1031cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK( 1041cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context, context->GetAttr("num_loss_partitions", &num_loss_partitions)); 1059ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations", 1061cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower &num_inner_iterations)); 1071cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, regularizations.Initialize(context)); 10894a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower } 10994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 1101cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower std::unique_ptr<DualLossUpdater> loss_updater; 1111cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower int num_sparse_features = 0; 1121cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower int num_sparse_features_with_values = 0; 1131cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower int num_dense_features = 0; 1141cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower int num_inner_iterations = 0; 1151cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower int num_loss_partitions = 0; 1161cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower bool adaptative = false; 1171cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower Regularizations regularizations; 1181cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower}; 119caeb1030c6421922798edfa4246d744cfcd2427cA. Unique TensorFlower 1205e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// TODO(shengx): The helper classes/methods are changed to support multiclass 1215e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// SDCA, which lead to changes within this function. Need to revisit the 1225e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower// convergence once the multiclass SDCA is in. 1231cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowervoid DoCompute(const ComputeOptions& options, OpKernelContext* const context) { 1241cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower ModelWeights model_weights; 1251cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, model_weights.Initialize(context)); 1261cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1271cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower Examples examples; 1281cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK( 1291cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context, 1301cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower examples.Initialize(context, model_weights, options.num_sparse_features, 1311cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.num_sparse_features_with_values, 1321cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.num_dense_features)); 1331cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1341cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const Tensor* example_state_data_t; 1351cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, 1361cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower context->input("example_state_data", &example_state_data_t)); 1371cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower TensorShape expected_example_state_shape({examples.num_examples(), 4}); 1381cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES(context, 1391cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data_t->shape() == expected_example_state_shape, 1401cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower errors::InvalidArgument( 1411cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower "Expected shape ", expected_example_state_shape.DebugString(), 1421cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower " for example_state_data, got ", 1431cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data_t->shape().DebugString())); 1441cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1451cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower Tensor mutable_example_state_data_t(*example_state_data_t); 1461cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower auto example_state_data = mutable_example_state_data_t.matrix<float>(); 147bc225bfaa534acc25047fe844f19edc333b7a76aPeter Hawkins OP_REQUIRES_OK(context, context->set_output("out_example_state_data", 148bc225bfaa534acc25047fe844f19edc333b7a76aPeter Hawkins mutable_example_state_data_t)); 1491cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1501cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower if (options.adaptative) { 1517abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower OP_REQUIRES_OK(context, 1527abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower examples.SampleAdaptativeProbabilities( 1537abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower options.num_loss_partitions, options.regularizations, 1547abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower model_weights, example_state_data, options.loss_updater, 1557abbdf2699d0c84e9c9fa2ffe9eb76d848d7be9fA. Unique TensorFlower /*num_weight_vectors =*/1)); 1561cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower } 157db828ab20399697bb97c218ca6c435ee59b2a029A. Unique TensorFlower 1581cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower mutex mu; 1591cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower Status train_step_status GUARDED_BY(mu); 1601cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower std::atomic<std::int64_t> atomic_index(-1); 1611cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower auto train_step = [&](const int64 begin, const int64 end) { 1621cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // The static_cast here is safe since begin and end can be at most 1631cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // num_examples which is an int. 1641cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower for (int id = static_cast<int>(begin); id < end; ++id) { 1651cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const int64 example_index = 1661cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower examples.sampled_index(++atomic_index, options.adaptative); 1671cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const Example& example = examples.example(example_index); 1681cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const float dual = example_state_data(example_index, 0); 1691cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const float example_weight = example.example_weight(); 1701cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower float example_label = example.example_label(); 1711cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const Status conversion_status = 1721cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.loss_updater->ConvertLabel(&example_label); 1731cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower if (!conversion_status.ok()) { 1741cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower mutex_lock l(mu); 1751cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower train_step_status = conversion_status; 1761cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // Return from this worker thread - the calling thread is 1771cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // responsible for checking context status and returning on error. 1781cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower return; 1791cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower } 1801cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1811cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // Compute wx, example norm weighted by regularization, dual loss, 1821cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // primal loss. 1835e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower // For binary SDCA, num_weight_vectors should be one. 1841cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const ExampleStatistics example_statistics = 1855e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower example.ComputeWxAndWeightedExampleNorm( 1865e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower options.num_loss_partitions, model_weights, 1875e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower options.regularizations, 1 /* num_weight_vectors */); 1881cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1891cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const double new_dual = options.loss_updater->ComputeUpdatedDual( 1901cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.num_loss_partitions, example_label, example_weight, dual, 1915e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower example_statistics.wx[0], example_statistics.normalized_squared_norm); 1921cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 1931cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // Compute new weights. 1941cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const double normalized_bounded_dual_delta = 1951cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower (new_dual - dual) * example_weight / 1961cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.regularizations.symmetric_l2(); 1975e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower model_weights.UpdateDeltaWeights( 1985e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower context->eigen_cpu_device(), example, 1995e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower std::vector<double>{normalized_bounded_dual_delta}); 2001cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 2011cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // Update example data. 2021cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data(example_index, 0) = new_dual; 2031cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data(example_index, 1) = 2041cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.loss_updater->ComputePrimalLoss( 2055e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower example_statistics.prev_wx[0], example_label, example_weight); 2061cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data(example_index, 2) = 2071cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower options.loss_updater->ComputeDualLoss(dual, example_label, 2081cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_weight); 2091cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower example_state_data(example_index, 3) = example_weight; 210f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower } 2111cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower }; 2121cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data, 2131cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower // number of cpus, and cost per example. 2141cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const int64 kCostPerUnit = examples.num_features(); 2151cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower const DeviceBase::CpuWorkerThreads& worker_threads = 2161cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower *context->device()->tensorflow_cpu_worker_threads(); 217f7ed48aac2ab7eeab1198c0e7027a9a9afeb7b5dA. Unique TensorFlower 2181cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower Shard(worker_threads.num_threads, worker_threads.workers, 2191cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower examples.num_examples(), kCostPerUnit, train_step); 2201cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower OP_REQUIRES_OK(context, train_step_status); 2211cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower} 22294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 2231cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower} // namespace 2241cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 2251cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerclass SdcaOptimizer : public OpKernel { 2261cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower public: 2271cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower explicit SdcaOptimizer(OpKernelConstruction* const context) 2281cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower : OpKernel(context), options_(context) {} 2291cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower 230db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner void Compute(OpKernelContext* context) override { 2311cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower DoCompute(options_, context); 23294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower } 23394a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 23494a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower private: 235733321f6159f55474f4d7822b5a1c174d30e696dYuan Yu // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and 2369ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower // template the entire class to avoid the virtual table lookup penalty in 2379ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower // the inner loop. 2381cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower ComputeOptions options_; 23994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower}; 2401cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU), 2411cf9120cec318826f5c079f1500cb3d2dc6cc0edA. Unique TensorFlower SdcaOptimizer); 24294a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower 2439ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerclass SdcaShrinkL1 : public OpKernel { 2449ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public: 2456a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower explicit SdcaShrinkL1(OpKernelConstruction* const context) 2466a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower : OpKernel(context) { 247ad2122584430ca4e4cc97fa642c1536357402eb6A. Unique TensorFlower OP_REQUIRES_OK(context, regularizations_.Initialize(context)); 2489ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 2499ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 250db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner void Compute(OpKernelContext* context) override { 251c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower OpMutableInputList weights_inputs; 252c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower OP_REQUIRES_OK(context, 253c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower context->mutable_input_list("weights", &weights_inputs)); 254c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower 255c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower auto do_work = [&](const int64 begin, const int64 end) { 256c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower for (int i = begin; i < end; ++i) { 257c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>(); 2586e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower prox_w.device(context->eigen_cpu_device()) = 2595e48873977dc8827848eeae6cf0db2ef549cce10A. Unique TensorFlower regularizations_.EigenShrinkVector(prox_w); 2606e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower } 2616e6d0dcde7c59f306c6f276ed153128fe8bdd8bfA. Unique TensorFlower }; 2626d6ba9e3d88cef87c4e0b2753ffe40d939901fa4A. Unique TensorFlower 263c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower if (weights_inputs.size() > 0) { 264c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower int64 num_weights = 0; 265c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower for (int i = 0; i < weights_inputs.size(); ++i) { 266c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements(); 267c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower } 268c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower // TODO(sibyl-Aix6ihai): Tune this value. 269c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size(); 270c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower const DeviceBase::CpuWorkerThreads& worker_threads = 271c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower *context->device()->tensorflow_cpu_worker_threads(); 272c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower Shard(worker_threads.num_threads, worker_threads.workers, 273c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower weights_inputs.size(), kCostPerUnit, do_work); 274c955418332cd3c0ff237c0cb56cda8a61deaf3ddA. Unique TensorFlower } 2759ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 2769ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 2779ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower private: 2789ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower Regularizations regularizations_; 2799ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower}; 2809ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1); 2819ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 28201fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// Computes platform independent, compact and unique (with very high 28301fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// probability) representation of an example id. It shouldn't be put in 28401fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// persistent storage, as its implementation may change in the future. 28501fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// 28601fb3ef3652d1040778c45c40b5517c4c474771dA. Unique TensorFlower// The current probability of at least one collision for 1B example_ids is 28787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower// approximately 10^-21 (ie 2^60 / 2^129). 288370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlowerclass SdcaFprint : public OpKernel { 2899ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower public: 2906a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower explicit SdcaFprint(OpKernelConstruction* const context) 2916a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower : OpKernel(context) {} 2929ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 293db43c0961d324c506db65581f2bf0615c25bf68aBenoit Steiner void Compute(OpKernelContext* context) override { 2946a3e47ae7a63985e419a9cd6a620ddb13a0b8721A. Unique TensorFlower const Tensor& input = context->input(0); 29587a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), 29687a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower errors::InvalidArgument("Input must be a vector, got shape ", 29787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower input.shape().DebugString())); 298370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower Tensor* out; 29987a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower const int64 num_elements = input.NumElements(); 30087a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower OP_REQUIRES_OK(context, context->allocate_output( 30187a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower 0, TensorShape({num_elements, 2}), &out)); 3029ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower 303370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlower const auto in_values = input.flat<string>(); 30487a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower auto out_values = out->matrix<int64>(); 30587a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower 30687a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower for (int64 i = 0; i < num_elements; ++i) { 30787a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower const Fprint128 fprint = Fingerprint128(in_values(i)); 30887a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower // Never return 0 or 1 as the first value of the hash to allow these to 30987a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower // safely be used as sentinel values (e.g. dense hash table empty key). 31087a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2) 31187a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower ? fprint.low64 31287a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower : fprint.low64 + ~static_cast<uint64>(1); 31387a61b18f83d4f5ec8a796c2a4d665d3010eac91A. Unique TensorFlower out_values(i, 1) = fprint.high64; 31456f1d64998744ad655fe5c428658a13be35b865eEugene Brevdo } 3159ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower } 3169ba61973849a9ef79104c8295886049634a193b4A. Unique TensorFlower}; 317370a6d4e91ffcaa155dfc72a74ca082c987580f3A. Unique TensorFlowerREGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint); 31856f1d64998744ad655fe5c428658a13be35b865eEugene Brevdo 31994a992cfc3266244b81fe311b805fc1ae3f53f30A. Unique TensorFlower} // namespace tensorflow 320