10 throw std::runtime_error(
"Host tensor is not rank 2 tensor.");
14 if(aqk_ % block_aq_k != 0)
16 throw std::runtime_error(
"shuffle_aq needs a aqk of multiple times of block_aq_k.");
19 std::copy(t->
begin(), t->
end(), t_view.begin());
23template <
typename GemmConfig,
typename T>
29 constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
31 GemmConfig::N_Warp_Tile,
32 k_ / GemmConfig::K_Warp_Tile,
34 GemmConfig::K_Warp_Tile / divisor});
35 std::copy(t.
begin(), t.
end(), t_view.begin());
39template <
typename GemmConfig,
typename T>
46 constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
49 {n_ / GemmConfig::N_Tile, GemmConfig::N_Warp, GemmConfig::N_Warp_Tile, NRepeat, bqk_});
50 std::copy(t.
begin(), t.
end(), t_view.begin());
54template <
typename GemmConfig,
typename T>
61 constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
62 constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
66 GemmConfig::N_Warp_Tile,
68 k_ / GemmConfig::K_Warp_Tile,
70 GemmConfig::K_Warp_Tile / divisor});
72 std::copy(t.
begin(), t.
end(), t_view.begin());
Definition tile/core/algorithm/cluster_descriptor.hpp:13
auto shuffle_b(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:24
auto shuffle_bq_permuteN(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:40
auto shuffle_b_permuteN(const ck_tile::HostTensor< T > &t)
Definition tensor_shuffle_utils.hpp:55
auto shuffle_aq(const ck_tile::HostTensor< T > *t, int block_aq_k)
Definition tensor_shuffle_utils.hpp:6
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition reference_permute.hpp:19
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
Data::iterator end()
Definition tile/host/host_tensor.hpp:589
Data::iterator begin()
Definition tile/host/host_tensor.hpp:587