Program Listing for File rocrand_discrete.h

Return to documentation for file (library/include/rocrand/rocrand_discrete.h)

// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

#ifndef ROCRAND_DISCRETE_H_
#define ROCRAND_DISCRETE_H_

#ifndef FQUALIFIERS
#define FQUALIFIERS __forceinline__ __device__
#endif // FQUALIFIERS

#include <math.h>

#include "rocrand/rocrand_lfsr113.h"
#include "rocrand/rocrand_mrg31k3p.h"
#include "rocrand/rocrand_mrg32k3a.h"
#include "rocrand/rocrand_mtgp32.h"
#include "rocrand/rocrand_philox4x32_10.h"
#include "rocrand/rocrand_scrambled_sobol32.h"
#include "rocrand/rocrand_scrambled_sobol64.h"
#include "rocrand/rocrand_sobol32.h"
#include "rocrand/rocrand_sobol64.h"
#include "rocrand/rocrand_threefry2x32_20.h"
#include "rocrand/rocrand_threefry2x64_20.h"
#include "rocrand/rocrand_threefry4x32_20.h"
#include "rocrand/rocrand_threefry4x64_20.h"
#include "rocrand/rocrand_xorwow.h"

#include "rocrand/rocrand_discrete_types.h"
#include "rocrand/rocrand_normal.h"
#include "rocrand/rocrand_uniform.h"

// Alias method
//
// Walker, A. J.
// An Efficient Method for Generating Discrete Random Variables with General Distributions, 1977
//
// Vose M. D.
// A Linear Algorithm For Generating Random Numbers With a Given Distribution, 1991

namespace rocrand_device {
namespace detail {

FQUALIFIERS unsigned int discrete_alias(const double       x,
                                        const unsigned int size,
                                        const unsigned int offset,
                                        const unsigned int* __restrict__ alias,
                                        const double* __restrict__ probability)
{
    // Calculate value using Alias table

    // x is [0, 1)
    const double       nx  = size * x;
    const double       fnx = floor(nx);
    const double       y   = nx - fnx;
    const unsigned int i   = static_cast<unsigned int>(fnx);
    return offset + (y < probability[i] ? i : alias[i]);
}

FQUALIFIERS unsigned int discrete_alias(const double x, const rocrand_discrete_distribution_st& dis)
{
    return discrete_alias(x, dis.size, dis.offset, dis.alias, dis.probability);
}

FQUALIFIERS
unsigned int discrete_alias(const unsigned int r, const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_32 = ROCRAND_2POW32_INV_DOUBLE;
    const double x = r * inv_double_32;
    return discrete_alias(x, dis);
}

// To prevent ambiguity compile error when compiler is facing the type "unsigned long"!!!
FQUALIFIERS unsigned int discrete_alias(const unsigned long                     r,
                                        const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_32 = ROCRAND_2POW32_INV_DOUBLE;
    const double x = r * inv_double_32;
    return discrete_alias(x, dis);
}

FQUALIFIERS unsigned int discrete_alias(const unsigned long long int            r,
                                        const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_64 = ROCRAND_2POW64_INV_DOUBLE;
    const double x = r * inv_double_64;
    return discrete_alias(x, dis);
}

FQUALIFIERS unsigned int discrete_cdf(const double       x,
                                      const unsigned int size,
                                      const unsigned int offset,
                                      const double* __restrict__ cdf)
{
    // Calculate value using binary search in CDF

    unsigned int min = 0;
    unsigned int max = size - 1;
    do
    {
        const unsigned int center = (min + max) / 2;
        const double       p      = cdf[center];
        if(x > p)
        {
            min = center + 1;
        }
        else
        {
            max = center;
        }
    }
    while(min != max);

    return offset + min;
}

FQUALIFIERS unsigned int discrete_cdf(const double x, const rocrand_discrete_distribution_st& dis)
{
    return discrete_cdf(x, dis.size, dis.offset, dis.cdf);
}

FQUALIFIERS
unsigned int discrete_cdf(const unsigned int r, const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_32 = ROCRAND_2POW32_INV_DOUBLE;
    const double x = r * inv_double_32;
    return discrete_cdf(x, dis);
}

// To prevent ambiguity compile error when compiler is facing the type "unsigned long"!!!
FQUALIFIERS unsigned int discrete_cdf(const unsigned long                     r,
                                      const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_32 = ROCRAND_2POW32_INV_DOUBLE;
    const double x = r * inv_double_32;
    return discrete_cdf(x, dis);
}

FQUALIFIERS unsigned int discrete_cdf(const unsigned long long int            r,
                                      const rocrand_discrete_distribution_st& dis)
{
    constexpr double inv_double_64 = ROCRAND_2POW64_INV_DOUBLE;
    const double x = r * inv_double_64;
    return discrete_cdf(x, dis);
}

} // end namespace detail
} // end namespace rocrand_device

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_philox4x32_10 * state, const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_alias(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
uint4 rocrand_discrete4(rocrand_state_philox4x32_10 * state, const rocrand_discrete_distribution discrete_distribution)
{
    const uint4 u4 = rocrand4(state);
    return uint4 {
        rocrand_device::detail::discrete_alias(u4.x, *discrete_distribution),
        rocrand_device::detail::discrete_alias(u4.y, *discrete_distribution),
        rocrand_device::detail::discrete_alias(u4.z, *discrete_distribution),
        rocrand_device::detail::discrete_alias(u4.w, *discrete_distribution)
    };
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_mrg31k3p*             state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_alias(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_mrg32k3a * state, const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_alias(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_xorwow * state, const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_alias(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_mtgp32 * state, const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_sobol32 * state, const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_scrambled_sobol32*    state,
                              const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_sobol64*              state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_scrambled_sobol64*    state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS
unsigned int rocrand_discrete(rocrand_state_lfsr113*              state,
                              const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_threefry2x32_20*      state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_threefry2x64_20*      state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_threefry4x32_20*      state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

FQUALIFIERS unsigned int rocrand_discrete(rocrand_state_threefry4x64_20*      state,
                                          const rocrand_discrete_distribution discrete_distribution)
{
    return rocrand_device::detail::discrete_cdf(rocrand(state), *discrete_distribution);
}

 // end of group rocranddevice
#endif // ROCRAND_DISCRETE_H_