~*~ GPU Programming ~*~

  _____  _____   _    _
 / ____||  __ \ | |  | |
| |  __ | |__) || |  | |
| | |_ ||  ___/ | |  | |
| |__| || |     | |__| |
 \_____||_|      \____/
-=-=-=-=-=-=-=-=-=-=-=-=-=-

My Latest CUDA Adventures

tokens_can_float.cu
// Converts each fp32 value to bfloat16 and stores its 16 bits as two byte tokens.
//
// Byte layout written for element i (2 * size bytes are written in total):
//   output[2 * i]     : mantissa token - sign bit in bit 7, 7 mantissa bits in bits 0-6
//   output[2 * i + 1] : exponent token - the 8 exponent bits
//
// Parameters:
//   input  - `size` fp32 values to encode.
//   output - destination for the 2 * size token bytes.
//   size   - number of input elements.
//
// One thread handles one element of a 1D launch; threads whose global index
// is >= size do nothing.
__global__ void fp32_to_byte_pair_token_kernel(const float* input, uint8_t* output, uint32_t size)
{
    constexpr uint16_t bf16_mantissa_size = 7;
    constexpr uint16_t bf16_exponent_size = 8;
    constexpr uint16_t bf16_sign_position = bf16_mantissa_size + bf16_exponent_size;
    constexpr uint16_t mantissa_mask = (1u << bf16_mantissa_size) - 1u;
    constexpr uint16_t exponent_mask = (1u << bf16_exponent_size) - 1u;
    constexpr uint16_t sign_mask = 1u << bf16_sign_position;

    // 64-bit index: neither the block offset nor idx * 2 can wrap around.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if(idx < size)
    {
        const __nv_bfloat16 inter = __float2bfloat16(input[idx]);
        // Raw bit pattern of the bfloat16 value: [sign:1][exponent:8][mantissa:7].
        const uint16_t unified_token = __bfloat16_as_ushort(inter);

        // Low 7 bits are the mantissa.
        const uint16_t mantissa_bits = unified_token & mantissa_mask;
        // Bits 7..14 are the exponent, moved down to bits 0..7.
        const uint16_t exponent_bits = (unified_token >> bf16_mantissa_size) & exponent_mask;
        // Bit 15 is the sign, moved down to bit 7 so it shares a byte with the mantissa.
        const uint16_t sign_bits = (unified_token & sign_mask) >> bf16_exponent_size;

        const size_t out_idx = idx * 2;
        output[out_idx] = static_cast<uint8_t>(sign_bits | mantissa_bits);
        output[out_idx + 1] = static_cast<uint8_t>(exponent_bits);
    }
}

// Rebuilds a bfloat16 value from a pair of byte tokens and widens it to fp32.
//
// Byte layout read for element i (2 * size bytes are read in total):
//   input[2 * i]     : mantissa token - sign bit in bit 7, 7 mantissa bits in bits 0-6
//   input[2 * i + 1] : exponent token - the 8 exponent bits
//
// Parameters:
//   input  - 2 * size token bytes.
//   output - destination for the `size` decoded fp32 values.
//   size   - number of elements to decode.
//
// One thread handles one element of a 1D launch; threads whose global index
// is >= size do nothing.
__global__ void byte_pair_token_to_fp32_kernel(const uint8_t* input, float* output, uint32_t size)
{
    constexpr uint16_t bf16_mantissa_size = 7;
    constexpr uint16_t bf16_exponent_size = 8;
    constexpr uint16_t mantissa_mask = (1u << bf16_mantissa_size) - 1u;
    // The sign bit sits directly above the mantissa bits inside the mantissa token.
    constexpr uint16_t token_sign_mask = 1u << bf16_mantissa_size;

    // 64-bit index: neither the block offset nor idx * 2 can wrap around.
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if(idx < size)
    {
        const size_t in_idx = idx * 2;
        const uint16_t mantissa_token = input[in_idx];
        const uint16_t exponent_token = input[in_idx + 1];

        // Mantissa stays in bits 0..6.
        const uint16_t mantissa_bits = mantissa_token & mantissa_mask;
        // Exponent byte moves up to bits 7..14.
        const uint16_t exponent_bits = static_cast<uint16_t>(exponent_token << bf16_mantissa_size);
        // Sign moves from bit 7 of the mantissa token up to bit 15.
        const uint16_t sign_bits = static_cast<uint16_t>((mantissa_token & token_sign_mask) << bf16_exponent_size);

        const uint16_t unified_token = static_cast<uint16_t>(sign_bits | exponent_bits | mantissa_bits);
        output[idx] = __bfloat162float(__ushort_as_bfloat16(unified_token));
    }
}
dizzy.cu
// Expands each 3-component input vector (x, y, z) into 16 output coefficients.
//
// Element i reads input[3 * i .. 3 * i + 2] and writes output[16 * i .. 16 * i + 15].
// The coefficients are polynomials in x, y, z of total degree 0 through 3
// (1, 3, 5 and 7 terms respectively), scaled by fixed constants.
// NOTE(review): the constants look like the real spherical-harmonic basis, in
// which case (x, y, z) is presumably expected to be unit length - confirm
// against the caller. The kernel itself does not normalize its input.
//
// One thread handles one vector of a 1D launch; threads whose global index
// is >= n return immediately. `input` is only read, `output` is only written.
__global__ void fourth_degree_sh_encoding(
    uint32_t n,
    // 3-dimensional vector.
    float* __restrict__ input,
    // 16-dimensional vector.
    float* __restrict__ output
){
    // 64-bit index: avoids the signed/unsigned comparison with n and keeps
    // i * 3 and i * 16 from overflowing for large n.
    const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if(i >= n) { return; }

    const size_t input_offset = i * 3;
    const float x = input[input_offset];
    const float y = input[input_offset + 1];
    const float z = input[input_offset + 2];

    const float x_squared = x * x;
    const float y_squared = y * y;
    const float z_squared = z * z;

    const size_t output_offset = i * 16;

    // Degree 0: constant term.
    output[output_offset] = 0.28209479177387814f;

    // Degree 1: linear in y, z, x.
    output[output_offset + 1] = 0.4886025119029199f * y;
    output[output_offset + 2] = 0.4886025119029199f * z;
    output[output_offset + 3] = 0.4886025119029199f * x;

    // Degree 2: quadratic terms.
    output[output_offset + 4] = 1.0925484305920792f * x * y;
    output[output_offset + 5] = 1.0925484305920792f * y * z;
    output[output_offset + 6] = 0.9461746957575601f * z_squared - 0.31539156525251999f;
    output[output_offset + 7] = 1.0925484305920792f * x * z;
    output[output_offset + 8] = 0.5462742152960396f * (x_squared - y_squared);

    // Degree 3: cubic terms.
    output[output_offset + 9] = 0.5900435899266435f * y * (3 * x_squared - y_squared);
    output[output_offset + 10] = 2.890611442640554f * x * y * z;
    output[output_offset + 11] = 0.4570457994644658f * y * (5 * z_squared - 1);
    output[output_offset + 12] = 0.3731763325901154f * z * (5 * z_squared - 3);
    output[output_offset + 13] = 0.4570457994644658f * x * (5 * z_squared - 1);
    output[output_offset + 14] = 1.445305721320277f * z * (x_squared - y_squared);
    output[output_offset + 15] = 0.5900435899266435f * x * (x_squared - 3 * y_squared);
}
-=-=-=-=-=-=-=-=-=-=-=-=-=-