1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_STREAM_EXECUTOR_RNG_H_ 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_STREAM_EXECUTOR_RNG_H_ 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <limits.h> 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <complex> 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/logging.h" 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/stream_executor/platform/port.h" 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace perftools { 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace gputools { 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Stream; 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurtemplate <typename ElemT> 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass DeviceMemory; 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace rng { 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Random-number-generation support interface -- this can be derived from a GPU 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// executor when the underlying platform has an RNG library implementation 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// available. See StreamExecutor::AsRng(). 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// When a seed is not specified, the backing RNG will be initialized with the 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// default seed for that implementation. 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-hostile: see StreamExecutor class comment for details on 41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// thread-hostility. 42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass RngSupport { 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMinSeedBytes = 16; 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static const int kMaxSeedBytes = INT_MAX; 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Releases any random-number-generation resources associated with this 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // support object in the underlying platform implementation. 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual ~RngSupport() {} 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Populates a GPU memory allocation with random values appropriate for the 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // DeviceMemory element type; i.e. populates DeviceMemory<float> with random 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // float values. 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandUniform(Stream *stream, 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<float> *v) = 0; 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandUniform(Stream *stream, 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<double> *v) = 0; 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandUniform(Stream *stream, 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<std::complex<float>> *v) = 0; 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandUniform(Stream *stream, 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<std::complex<double>> *v) = 0; 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Populates a GPU memory allocation with random values sampled from a 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Gaussian distribution with the given mean and standard deviation. 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandGaussian(Stream *stream, float mean, float stddev, 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DeviceMemory<float> *v) { 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur LOG(ERROR) 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur << "platform's random number generator does not support gaussian"; 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return false; 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool DoPopulateRandGaussian(Stream *stream, double mean, 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur double stddev, DeviceMemory<double> *v) { 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur LOG(ERROR) 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur << "platform's random number generator does not support gaussian"; 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return false; 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Specifies the seed used to initialize the RNG. 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // This call does not transfer ownership of the buffer seed; its data should 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // not be altered for the lifetime of this call. At least 16 bytes of seed 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // data must be provided, but not all seed data will necessarily be used. 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // seed: Pointer to seed data. Must not be null. 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // seed_bytes: Size of seed buffer in bytes. Must be >= 16. 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool SetSeed(Stream *stream, const uint8 *seed, 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur uint64 seed_bytes) = 0; 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected: 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static bool CheckSeed(const uint8 *seed, uint64 seed_bytes); 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace rng 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace gputools 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace perftools 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_STREAM_EXECUTOR_RNG_H_ 96