Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions crates/cuda-backend/cuda/supra/ntt_bitrev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,21 @@
#include "launcher.cuh"
#include "ntt/ntt.cuh"

// Widths, in bits, of the low and high index fields that are kept contiguous
// by choose_index_to_bitrev().
constexpr uint32_t LOW_BITREV_BITS = 5;
constexpr uint32_t HIGH_BITREV_BITS = 5;

/// Reorders the indices `0 .. 2^{log_n} - 1` for the bit-reversal kernel.
///
/// Split the binary representation of each `log_n`-bit number as
///   [A = HIGH_BITREV_BITS bits][B = the remaining middle bits][C = LOW_BITREV_BITS bits]
/// Given `idx` and `log_n`, this returns the `idx`-th number among all numbers
/// from `0` to `2^{log_n} - 1`, sorted lexicographically by (B, A, C).
///
/// Equivalently, `idx` is read as [B][A][C] (B most significant) and the result
/// is [A][B][C]. Consecutive `idx` values therefore first sweep C (a contiguous
/// run of `2^LOW_BITREV_BITS` numbers), then A, and only then move to the next B.
/// Since bit reversal maps [A][B][C] to [rev C][rev B][rev A], the reversed
/// indices are grouped in the same manner. This is meant to go well with the
/// SM's L1 cache.
///
/// The mapping is a bijection on `[0, 2^{log_n})` for every
/// `LOW_BITREV_BITS + HIGH_BITREV_BITS <= log_n <= 32`.
///
/// @param idx    linear index to remap.
/// @param log_n  log2 of the domain size; expected to be at most 32.
/// @return the remapped index; `idx` itself when `log_n` is too small to hold
///         both fields or when `idx` lies outside `[0, 2^{log_n})`.
///
/// Marked `__host__ __device__` so the pure bit manipulation can also be
/// exercised from host-side tests.
__host__ __device__ __forceinline__ uint32_t choose_index_to_bitrev(uint32_t idx, uint32_t log_n) {
    constexpr uint32_t SPLIT_BITS = LOW_BITREV_BITS + HIGH_BITREV_BITS;

    // Too few bits to hold both fields: the shift amounts below would
    // underflow, so leave the index untouched.
    if (log_n < SPLIT_BITS) {
        return idx;
    }
    // Out-of-range indices are passed through unchanged so that they stay out
    // of range and can never alias an in-range element. The `log_n < 32` test
    // avoids an undefined full-width shift.
    if (log_n < 32 && (idx >> log_n) != 0) {
        return idx;
    }

    // Decode `idx` as [B][A][C].
    const uint32_t c = idx & ((1u << LOW_BITREV_BITS) - 1);
    const uint32_t a = (idx >> LOW_BITREV_BITS) & ((1u << HIGH_BITREV_BITS) - 1);
    const uint32_t b = idx >> SPLIT_BITS;

    // Re-encode as [A][B][C].
    return (a << (log_n - HIGH_BITREV_BITS)) | (b << LOW_BITREV_BITS) | c;
}

// Permutes the data in an array such that data[i] = data[bit_reverse(i)]
// and data[bit_reverse(i)] = data[i]
__launch_bounds__(1024) __global__
Expand All @@ -29,6 +44,9 @@ void bit_rev_permutation(fr_t* d_out, const fr_t *d_in, uint32_t lg_domain_size,
d_out[rev] = t;
} else {
index_t idx = threadIdx.x + blockDim.x * (index_t)blockIdx.x;
idx = (lg_domain_size > LOW_BITREV_BITS + HIGH_BITREV_BITS)
? choose_index_to_bitrev(idx, lg_domain_size)
: idx;
index_t rev = bit_rev(idx, lg_domain_size);
bool copy = d_out != d_in && idx == rev;

Expand Down