/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_STREAM_EXECUTOR_RNG_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_STREAM_EXECUTOR_RNG_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <limits.h>
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <complex>
21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/logging.h"
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/port.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace perftools {
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace gputools {
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Stream;
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename ElemT>
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass DeviceMemory;
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace rng {
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// Random-number-generation support interface -- this can be derived from a GPU
// executor when the underlying platform has an RNG library implementation
// available. See StreamExecutor::AsRng().
// When a seed is not specified, the backing RNG will be initialized with the
// default seed for that implementation.
//
// Thread-hostile: see StreamExecutor class comment for details on
// thread-hostility.
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass RngSupport {
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  static const int kMinSeedBytes = 16;
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  static const int kMaxSeedBytes = INT_MAX;
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Releases any random-number-generation resources associated with this
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // support object in the underlying platform implementation.
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual ~RngSupport() {}
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Populates a GPU memory allocation with random values appropriate for the
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // DeviceMemory element type; i.e. populates DeviceMemory<float> with random
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // float values.
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandUniform(Stream *stream,
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                     DeviceMemory<float> *v) = 0;
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandUniform(Stream *stream,
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                     DeviceMemory<double> *v) = 0;
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandUniform(Stream *stream,
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                     DeviceMemory<std::complex<float>> *v) = 0;
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandUniform(Stream *stream,
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                     DeviceMemory<std::complex<double>> *v) = 0;
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Populates a GPU memory allocation with random values sampled from a
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Gaussian distribution with the given mean and standard deviation.
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev,
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                      DeviceMemory<float> *v) {
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(ERROR)
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        << "platform's random number generator does not support gaussian";
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return false;
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool DoPopulateRandGaussian(Stream *stream, double mean,
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                                      double stddev, DeviceMemory<double> *v) {
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(ERROR)
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur        << "platform's random number generator does not support gaussian";
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    return false;
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Specifies the seed used to initialize the RNG.
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // This call does not transfer ownership of the buffer seed; its data should
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // not be altered for the lifetime of this call. At least 16 bytes of seed
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // data must be provided, but not all seed data will necessarily be used.
82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // seed: Pointer to seed data. Must not be null.
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // seed_bytes: Size of seed buffer in bytes. Must be >= 16.
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool SetSeed(Stream *stream, const uint8 *seed,
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                       uint64 seed_bytes) = 0;
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected:
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  static bool CheckSeed(const uint8 *seed, uint64 seed_bytes);
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace rng
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace gputools
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace perftools
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_STREAM_EXECUTOR_RNG_H_
96