1bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
3bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFloweryou may not use this file except in compliance with the License.
5bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerYou may obtain a copy of the License at
6bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
7bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
9bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerSee the License for the specific language governing permissions and
13bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerlimitations under the License.
14bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower==============================================================================*/
15bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
16bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include <stdlib.h>
17bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include <initializer_list>
18bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include <iterator>
19bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include <vector>
20bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
21bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/framework/bfloat16.h"
22bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h"
23bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/framework/tensor_testutil.h"
24bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/framework/types.pb.h"
25bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/graph/node_builder.h"
26bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/lib/strings/stringprintf.h"
27bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/platform/test.h"
28bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower#include "tensorflow/core/platform/test_benchmark.h"
29bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
30bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowernamespace tensorflow {
31bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
32bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower// Generate "count" random positive integers (not including zero) with sum
33bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower// "sum". Technique based on one from https://math.stackexchange.com/a/1276225
34bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower// but simplified (especially for zero-based indexing).
35bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerstatic std::vector<int64> GenerateRandomIntsWithSum(int64 sum, int count) {
36bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  CHECK_GE(count, 1);
37bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  CHECK_GE(sum, count);
38bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  std::vector<int64> temp(count);
39bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  for (int i = 0; i + 1 < count; ++i) {
40bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower    temp[i] = lrand48() % (sum - count);
41bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  }
42bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  temp[count - 1] = sum - count;
43bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  std::sort(temp.begin(), std::prev(temp.end()));
44bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  std::vector<int64> result(count);
45bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  std::adjacent_difference(temp.begin(), temp.end(), result.begin());
46bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  for (int i = 0; i < count; ++i) {
47bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower    ++result[i];
48bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  }
49bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  CHECK(std::all_of(result.begin(), result.end(),
50bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                    [sum](int64 x) { return x >= 1 && x <= sum; }));
51bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  CHECK_EQ(std::accumulate(result.begin(), result.end(), static_cast<int64>(0)),
52bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower           sum);
53bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  CHECK_EQ(result.size(), count);
54bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  return result;
55bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower}
56bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
57bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerstatic Graph* MakeGraph(int split_dim, const std::vector<int64>& size_splits,
58bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                        std::initializer_list<int64> total_size) {
59bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  Graph* g = new Graph(OpRegistry::Global());
60bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  TensorShape in_shape(total_size);
61bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  Tensor in(DataTypeToEnum<float>::value, in_shape);
62bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  in.flat<float>().setRandom();
63bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  Tensor split_dim_tensor = test::AsScalar<int32>(split_dim);
64bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  Tensor size_splits_tensor = test::AsTensor<int64>(size_splits);
65bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  Node* splitv;
66bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  TF_CHECK_OK(NodeBuilder(g->NewName("splitv"), "SplitV")
67bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                  .Input(test::graph::Constant(g, in))
68bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                  .Input(test::graph::Constant(g, size_splits_tensor))
69bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                  .Input(test::graph::Constant(g, split_dim_tensor))
70bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                  .Attr("num_split", static_cast<int64>(size_splits.size()))
71bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower                  .Finalize(g, &splitv));
72bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower  return g;
73bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower}
74bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
// Defines a benchmark function BM_SplitV_1d_<num_split>_<total_size> and
// registers it with BENCHMARK. The benchmark splits a 1-D tensor of
// "total_size" elements along dimension 0 into "num_split" randomly sized
// chunks and runs the resulting graph on "cpu" for "iters" iterations. All
// setup, including graph construction, happens while timing is stopped.
#define BM_SPLITV_1D(num_split, total_size)                                   \
  static void BM_SplitV_1d_##num_split##_##total_size(int iters) {            \
    testing::StopTiming();                                                    \
    testing::UseRealTime();                                                   \
    testing::SetLabel(                                                        \
        strings::Printf("1-D %d chunks totaling %d", num_split, total_size)); \
    testing::ItemsProcessed(static_cast<int64>(iters) * total_size);          \
    const std::vector<int64> chunk_sizes =                                    \
        GenerateRandomIntsWithSum(total_size, num_split);                     \
    Graph* graph = MakeGraph(/* split_dim = */ 0, chunk_sizes, {total_size}); \
    testing::StartTiming();                                                   \
    test::Benchmark("cpu", graph).Run(iters);                                 \
  }                                                                           \
  BENCHMARK(BM_SplitV_1d_##num_split##_##total_size);
90bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
// Defines a benchmark function
// BM_SplitV_2d_<split_dim>_<num_split>_<total_size0>_<total_size1> and
// registers it with BENCHMARK. The benchmark splits a 2-D tensor of shape
// (total_size0, total_size1) along dimension "split_dim" into "num_split"
// randomly sized chunks and runs the resulting graph on "cpu" for "iters"
// iterations. All setup, including graph construction, happens while timing
// is stopped.
#define BM_SPLITV_2D(split_dim, num_split, total_size0, total_size1)          \
  static void                                                                 \
      BM_SplitV_2d_##split_dim##_##num_split##_##total_size0##_##total_size1( \
          int iters) {                                                        \
    testing::StopTiming();                                                    \
    testing::UseRealTime();                                                   \
    testing::SetLabel(                                                        \
        strings::Printf("2-D %d chunks in dim %d totaling (%d * %d)",         \
                        num_split, split_dim, total_size0, total_size1));     \
    testing::ItemsProcessed(static_cast<int64>(iters) * total_size0 *         \
                            total_size1);                                     \
    const std::vector<int64> dim_sizes{total_size0, total_size1};             \
    const std::vector<int64> chunk_sizes =                                    \
        GenerateRandomIntsWithSum(dim_sizes[split_dim], num_split);           \
    Graph* graph =                                                            \
        MakeGraph(split_dim, chunk_sizes, {total_size0, total_size1});        \
    testing::StartTiming();                                                   \
    test::Benchmark("cpu", graph).Run(iters);                                 \
  }                                                                           \
  BENCHMARK(                                                                  \
      BM_SplitV_2d_##split_dim##_##num_split##_##total_size0##_##total_size1);
113bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
114bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(5, 20);
115bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(262144, 1000000);
116bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(1, 100000);
117bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(5, 100000);
118bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(5, 250000);
119bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(5, 500000);
120bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(5, 1000000);
121bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(10, 4194304);
122bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(2, 4194304);
123bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(100, 10240);
124bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_1D(32768, 1048576);
125bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
126bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 1024, 10247, 10);
127bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 1024, 100000, 10);
128bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 512, 1024, 256);
129bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 20, 100000, 5);
130bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 2, 7, 524288);
131bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(0, 100, 4096, 512);
132bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
133bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 1024, 15, 10240);
134bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 1024, 10, 100000);
135bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 512, 1024, 2563);
136bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 20, 100000, 52);
137bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 2, 3, 524288);
138bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlowerBM_SPLITV_2D(1, 100, 4096, 512);
139bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower
140bb80ae7f57260f35358eb777ba1511fa23f13609A. Unique TensorFlower}  // namespace tensorflow
141