device_normalization_fwd.hpp Source File

device_normalization_fwd.hpp Source File

Composable Kernel: device_normalization_fwd.hpp Source File
device_normalization_fwd.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

11namespace ck {
12namespace tensor_operation {
13namespace device {
14template <typename XDataType,
15 typename GammaDataType,
16 typename BetaDataType,
17 typename YDataType,
18 typename SaveMeanInvStdDataType,
19 typename YElementwiseOperation,
20 index_t Rank,
21 index_t NumReduceDim>
23{
24 virtual std::unique_ptr<BaseArgument>
25 MakeArgumentPointer(const std::vector<index_t> lengths,
26 const std::vector<index_t> xStrides,
27 const std::vector<index_t> gammaStrides,
28 const std::vector<index_t> betaStrides,
29 const std::vector<index_t> yStrides,
30 const std::vector<index_t> saveMeanStrides,
31 const std::vector<index_t> saveInvStdStrides,
32 const std::vector<index_t> reduceDims,
33 double epsilon,
34 const void* p_x,
35 const void* p_gamma,
36 const void* p_beta,
37 void* p_y,
38 void* p_savedMean,
39 void* p_savedInvVar,
40 YElementwiseOperation y_elementwise_op) = 0;
41
42 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43};
44
45template <typename XDataType,
46 typename GammaDataType,
47 typename BetaDataType,
48 typename YDataType,
49 typename SaveMeanInvStdDataType,
50 typename YElementwiseOperation,
51 index_t Rank,
52 index_t NumReduceDim>
53using DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType,
54 GammaDataType,
55 BetaDataType,
56 YDataType,
57 SaveMeanInvStdDataType,
58 YElementwiseOperation,
59 Rank,
60 NumReduceDim>>;
61
62} // namespace device
63} // namespace tensor_operation
64} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > > DeviceNormalizationFwdPtr
Definition device_normalization_fwd.hpp:53
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_normalization_fwd.hpp:23
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_savedMean, void *p_savedInvVar, YElementwiseOperation y_elementwise_op)=0