block_topk_stream_2d.hpp Source File

block_topk_stream_2d.hpp Source File#

Composable Kernel: block_topk_stream_2d.hpp Source File
block_topk_stream_2d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
7
8namespace ck_tile {
9
10/*
11Simple 2D top-k implementation along each row (dim=1).
12Requirement:
13 1). each row must reside within a single warp
14*/
15template <typename Problem_, typename Policy_ = void>
17{
20
21 using DataType = typename Problem::DataType;
22 using IndexType = typename Problem::IndexType;
23
24 // TODO: if DataType is sub-dword, it needs to be packed into a single dword to use argmax
30
// Streaming top-k along dim 1: runs k rounds; each round reduces the working
// copy of x with an argmax, stores the winning value and its column index
// through the two windows, then masks the winner out of the working copy.
//
// x          - input distributed tensor; it is copied into x_tmp and never written.
// out_window - window that receives the selected values (its DataType must equal
//              DataType, enforced by the static_assert below).
// idx_window - window that receives the selected column indices (its DataType must
//              equal IndexType, enforced by the static_assert below).
// k          - number of argmax rounds to run.
// number<dim> - unnamed tag parameter; not read in the lines visible here.
//
// NOTE(review): this listing is missing original lines 64 and 77-78 (the numbering
// jumps), so the declarations of `t`, `o` and `i` used below are not visible here.
31 template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
32 CK_TILE_DEVICE void operator()(const DistributedTensor& x,
33 const OutWindow& out_window,
34 const IdxWindow& idx_window,
35 index_t k,
36 number<dim> = {})
37 {
// Work on local copies of the windows: they are advanced after every round,
// while the caller's windows are taken by const reference.
38 OutWindow out_window_tmp = out_window;
39 IdxWindow idx_window_tmp = idx_window;
40 static_assert(
41 std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
42 std::is_same_v<typename DistributedTensor::DataType, DataType>);
43 static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
44
// Mutable working copy of the input; winners are overwritten with -infinity below.
45 DistributedTensor x_tmp = x;
// Not referenced in the visible lines; presumably used by the missing lines 77-78 - TODO confirm.
46 constexpr auto dst_dist = typename IdxWindow::TileDstr{};
47
48 // argmax reducer for top-k: keeps the packet with the larger arg; on equal args e1 is returned
49 const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
50 return e0.arg > e1.arg ? e0 : e1;
51 };
52
// One round per requested element.
53 for(index_t i_k = 0; i_k < k; i_k++)
54 {
55 constexpr auto span_2d = DistributedTensor::get_distributed_spans();
// Build a tensor of ArgmaxPacket with x's tile distribution: arg holds the current
// x_tmp element and value holds component 1 of that element's x-indices.
56 auto packet = [&]() {
57 auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
58
59 sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
60 sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
61 const auto tile_idx = get_x_indices_from_distributed_indices(
62 tmp.get_tile_distribution(), make_tuple(idx0, idx1));
63 constexpr auto i_j_idx = make_tuple(idx0, idx1);
65 t.arg = x_tmp(i_j_idx); // !!! we reference x here (note: the code reads the working copy x_tmp)
66 t.value = tile_idx.at(number<1>{});
67 tmp(i_j_idx) = t;
68 });
69 });
70 return tmp;
71 }();
72
// Reduce over dim 1 starting from {-infinity, 0}, then run the xor-sync step with
// the same reducer (presumably combining partial results across lanes - TODO confirm).
73 auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
74 auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
75 block_tile_reduce_xor_sync(r, f_argmax);
76
// Split the reduced packets into the value tensor `o` and the index tensor `i`.
79 sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
80 sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
81 constexpr auto i_j_idx = make_tuple(idx0, idx1);
82 ArgmaxPacket tmp = r(i_j_idx);
83 o(i_j_idx) = tmp.arg;
84 i(i_j_idx) = tmp.value;
85 });
86 });
87
88 // update value: overwrite the element whose column index matches the winner with
// -infinity, so that under f_argmax it loses to any larger remaining value
89 sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
90 sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
91 const auto tile_idx = get_x_indices_from_distributed_indices(
92 x.get_tile_distribution(), make_tuple(idx0, idx1));
93 auto col_id = tile_idx.at(number<1>{});
94
95 constexpr auto i_j_idx = make_tuple(idx0, idx1);
96
97 x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
98 : x_tmp(i_j_idx);
99 });
100 });
101
// Only threads whose threadIdx.x is a multiple of Problem::ColLanes perform the store
// (presumably one writer per row - TODO confirm against Problem's definition).
102 if(threadIdx.x % Problem::ColLanes == 0)
103 {
104 store_tile(out_window_tmp, o);
105 store_tile(idx_window_tmp, i);
106 }
// Advance both windows by {0, 1} for the next round's output.
107 move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
108 move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
109 }
110 }
111};
112
113} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:132
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition static_distributed_tensor.hpp:159
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_topk_stream_2d.hpp:26
DataType arg
Definition block_topk_stream_2d.hpp:27
index_t value
Definition block_topk_stream_2d.hpp:28
Definition block_topk_stream_2d.hpp:17
remove_cvref_t< Policy_ > Policy
Definition block_topk_stream_2d.hpp:19
CK_TILE_DEVICE void operator()(const DistributedTensor &x, const OutWindow &out_window, const IdxWindow &idx_window, index_t k, number< dim >={})
Definition block_topk_stream_2d.hpp:32
remove_cvref_t< Problem_ > Problem
Definition block_topk_stream_2d.hpp:18
typename Problem::IndexType IndexType
Definition block_topk_stream_2d.hpp:22
typename Problem::DataType DataType
Definition block_topk_stream_2d.hpp:21
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38