MoeSortingKernel< Problem_ > Struct Template Reference

MoeSortingKernel< Problem_ > Struct Template Reference

Composable Kernel: ck_tile::MoeSortingKernel< Problem_ > Struct Template Reference
ck_tile::MoeSortingKernel< Problem_ > Struct Template Reference

#include <moe_sorting_kernel.hpp>

Classes

struct  Kargs
struct  simple_smem_indexer

Public Types

using Problem = remove_cvref_t<Problem_>
using IndexType = typename Problem::IndexType
using WeightType = typename Problem::WeightType
typedef MoeSortingHostArgs MoeSortingKargs
using Hargs = MoeSortingHostArgs

Public Member Functions

template<typename data_t, int wave_size>
__device__ void wave_cumsum (data_t &thread_data) const
CK_TILE_DEVICE index_t calc_index (index_t total_col, index_t row, index_t col) const
CK_TILE_DEVICE void moe_buf_set_zero_kernel (uint8x16_t *buf, long_index_t buf_bytes) const
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d (void *buf, index_t row, index_t col, index_t elem_bytes) const
CK_TILE_DEVICE void moe_align_block_size_kernel (const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens_per_thread, const index_t numel, const mdiv unit_size_mdiv, const mdiv topk_mdiv, void *smem) const
CK_TILE_DEVICE void moe_align_block_size_kernel_ex (const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, const IndexType *__restrict__ local_expert_mask, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens, const mdiv unit_size_mdiv, const mdiv topk_mdiv, const mdiv expert_mdiv, const index_t smem_rows, void *smem) const
CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST constexpr auto get_num_cu ()
static CK_TILE_HOST constexpr auto GridSize (const Hargs &h)
static CK_TILE_HOST constexpr auto BlockSize (const Hargs &h)
static CK_TILE_HOST constexpr auto GetSmemSize (const Hargs &h)
static CK_TILE_HOST constexpr auto MakeKargs (const Hargs &h)
template<typename T, typename F, index_t wave_size_ = get_warp_size()>
static __device__ constexpr T wave_reduce (T local, F reduce_f, number< wave_size_ >={})

Static Public Attributes

static constexpr index_t kBlockSize = 256
static constexpr index_t OCCUPANCY = 2

Member Typedef Documentation

◆ Hargs

template<typename Problem_>
using ck_tile::MoeSortingKernel< Problem_ >::Hargs = MoeSortingHostArgs

◆ IndexType

template<typename Problem_>
using ck_tile::MoeSortingKernel< Problem_ >::IndexType = typename Problem::IndexType

◆ MoeSortingKargs

template<typename Problem_>
typedef MoeSortingHostArgs ck_tile::MoeSortingKernel< Problem_ >::MoeSortingKargs

◆ Problem

template<typename Problem_>
using ck_tile::MoeSortingKernel< Problem_ >::Problem = remove_cvref_t<Problem_>

◆ WeightType

template<typename Problem_>
using ck_tile::MoeSortingKernel< Problem_ >::WeightType = typename Problem::WeightType

Member Function Documentation

◆ BlockSize()

template<typename Problem_>
CK_TILE_HOST constexpr auto ck_tile::MoeSortingKernel< Problem_ >::BlockSize ( const Hargs & h)
inline static constexpr

◆ calc_index()

template<typename Problem_>
CK_TILE_DEVICE index_t ck_tile::MoeSortingKernel< Problem_ >::calc_index ( index_t total_col,
index_t row,
index_t col ) const
inline

◆ get_num_cu()

template<typename Problem_>
CK_TILE_HOST constexpr auto ck_tile::MoeSortingKernel< Problem_ >::get_num_cu ( )
inline static constexpr

◆ GetSmemSize()

template<typename Problem_>
CK_TILE_HOST constexpr auto ck_tile::MoeSortingKernel< Problem_ >::GetSmemSize ( const Hargs & h)
inline static constexpr

◆ GridSize()

template<typename Problem_>
CK_TILE_HOST constexpr auto ck_tile::MoeSortingKernel< Problem_ >::GridSize ( const Hargs & h)
inline static constexpr

◆ MakeKargs()

template<typename Problem_>
CK_TILE_HOST constexpr auto ck_tile::MoeSortingKernel< Problem_ >::MakeKargs ( const Hargs & h)
inline static constexpr

◆ moe_align_block_size_kernel()

template<typename Problem_>
CK_TILE_DEVICE void ck_tile::MoeSortingKernel< Problem_ >::moe_align_block_size_kernel ( const IndexType *__restrict__ topk_id,
const WeightType *__restrict__ weights,
index_t * p_sorted_token_ids,
WeightType * p_sorted_weights,
index_t * p_sorted_expert_ids,
index_t * p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens_per_thread,
const index_t numel,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
void * smem ) const
inline

◆ moe_align_block_size_kernel_ex()

template<typename Problem_>
CK_TILE_DEVICE void ck_tile::MoeSortingKernel< Problem_ >::moe_align_block_size_kernel_ex ( const IndexType *__restrict__ topk_id,
const WeightType *__restrict__ weights,
const IndexType *__restrict__ local_expert_mask,
index_t * p_sorted_token_ids,
WeightType * p_sorted_weights,
index_t * p_sorted_expert_ids,
index_t * p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
const index_t smem_rows,
void * smem ) const
inline

◆ moe_buf_set_zero_kernel()

template<typename Problem_>
CK_TILE_DEVICE void ck_tile::MoeSortingKernel< Problem_ >::moe_buf_set_zero_kernel ( uint8x16_t * buf,
long_index_t buf_bytes ) const
inline

◆ moe_buf_set_zero_kernel_2d()

template<typename Problem_>
CK_TILE_DEVICE void ck_tile::MoeSortingKernel< Problem_ >::moe_buf_set_zero_kernel_2d ( void * buf,
index_t row,
index_t col,
index_t elem_bytes ) const
inline

◆ operator()()

template<typename Problem_>
CK_TILE_DEVICE void ck_tile::MoeSortingKernel< Problem_ >::operator() ( Kargs kargs) const
inline

◆ wave_cumsum()

template<typename Problem_>
template<typename data_t, int wave_size>
__device__ void ck_tile::MoeSortingKernel< Problem_ >::wave_cumsum ( data_t & thread_data) const
inline

◆ wave_reduce()

template<typename Problem_>
template<typename T, typename F, index_t wave_size_ = get_warp_size()>
__device__ constexpr T ck_tile::MoeSortingKernel< Problem_ >::wave_reduce ( T local,
F reduce_f,
number< wave_size_ > = {} )
inline static constexpr

Member Data Documentation

◆ kBlockSize

template<typename Problem_>
index_t ck_tile::MoeSortingKernel< Problem_ >::kBlockSize = 256
static constexpr

◆ OCCUPANCY

template<typename Problem_>
index_t ck_tile::MoeSortingKernel< Problem_ >::OCCUPANCY = 2
static constexpr

The documentation for this struct was generated from the following file: