14template <
typename InDataType,
15 typename ComputeDataType,
17 typename IndexDataType,
21 bool OutputIndex =
false>
49 auto f = [&](
auto n,
auto ho,
auto wo,
auto c) {
50 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
52 IndexDataType current_index = 0;
64 if(hi >= 0 && hi < H && wi >= 0 && wi < W)
68 if constexpr(OutputIndex)
72 v_acc = reduce_op(v_acc, v_in, changed);
75 current_index = flat_index;
80 v_acc = reduce_op(v_acc, v_in);
89 if constexpr(OutputIndex)
91 output_index(n, ho, wo, c) = current_index;
99template <
typename InDataType,
100 typename ComputeDataType,
101 typename OutDataType,
102 typename IndexDataType,
104 typename TensorShape,
105 typename WindowShape,
106 bool OutputIndex =
false>
140 auto f = [&](
auto n,
auto do_,
auto ho,
auto wo,
auto c) {
141 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
143 IndexDataType current_index = 0;
160 if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
162 const ComputeDataType v_in =
165 if constexpr(OutputIndex)
167 IndexDataType flat_index =
169 bool changed =
false;
170 v_acc = reduce_op(v_acc, v_in, changed);
173 current_index = flat_index;
178 v_acc = reduce_op(v_acc, v_in);
188 if constexpr(OutputIndex)
191 output_index(n, do_, ho, wo, c) = current_index;
#define CK_TILE_HOST
Definition config.hpp:40
Definition reduce_operator.hpp:11
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition reference_pool.hpp:22
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition reference_pool.hpp:107
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
Definition tile/host/host_tensor.hpp:336
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition tile/host/host_tensor.hpp:531
Kernel arguments for pooling operations.
Definition pool_kernel.hpp:63
TensorShape output_shape
Definition pool_kernel.hpp:68
WindowShape window_lengths
Definition pool_kernel.hpp:71
WindowShape window_dilations
Definition pool_kernel.hpp:73
WindowShape input_left_pads
Definition pool_kernel.hpp:74
TensorShape input_shape
Definition pool_kernel.hpp:67
WindowShape window_strides
Definition pool_kernel.hpp:72