BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference


GPU kernel for batched tensor contraction operations.

#include <batched_contraction_kernel.hpp>

Public Types

using Problem = ck_tile::remove_cvref_t<Problem_>
 Tensor contraction problem specification.
using ADataType
 Data type for input tensor A.
using BDataType
 Data type for input tensor B.
using DsDataType
 Data types for auxiliary input tensors D.
using EDataType
 Data type for output tensor E.
using TilePartitioner
 Tile partitioning strategy for workload distribution.
using GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_>
 GEMM computation pipeline.
using EpiloguePipeline
 Epilogue pipeline for post-GEMM operations.
using UniversalGemmKernel
 The underlying Universal GEMM kernel template.
using KernelArgs
 Kernel argument structure.

Public Member Functions

CK_TILE_DEVICE void operator() (const KernelArgs &kargs) const

Static Public Member Functions

static CK_TILE_HOST constexpr auto GetKernelName ()
 Returns the kernel name for debugging and profiling purposes.
static CK_TILE_HOST constexpr bool IsSupportedArguments (const KernelArgs &kargs)
 Validates whether the given kernel arguments are supported.
static CK_TILE_HOST constexpr ck_tile::index_t GetSmemSize ()
 Returns the shared memory size required by the kernel.
static CK_TILE_HOST constexpr auto GetBlockSize ()
 Returns the GPU block size for kernel launch.
static CK_TILE_HOST constexpr auto GridSize (const KernelArgs &kargs)
static CK_TILE_HOST constexpr KernelArgs MakeKernelArgs (const BatchedContractionHostArgs< NumDTensor > &host_args)

Static Public Attributes

static constexpr ck_tile::index_t NumDimG = Problem::NumDimG
 Number of batch dimensions.
static constexpr ck_tile::index_t NumDimM
 Number of M (output row) dimensions.
static constexpr ck_tile::index_t NumDimN
 Number of N (output column) dimensions.
static constexpr ck_tile::index_t NumDimK
 Number of K (contraction) dimensions.
static constexpr ck_tile::index_t NumDTensor
 Number of auxiliary input D tensors.
static constexpr ck_tile::index_t kBlockSize
 GPU block size inherited from GEMM kernel.

Detailed Description

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >

GPU kernel for batched tensor contraction operations.

Overview
This kernel performs batched tensor contractions on top of the underlying UniversalGemmKernel. It supports an arbitrary number of batch (G), row (M), column (N), and contraction (K) dimensions and processes multiple batch instances in parallel. Each batch instance computes E = epilogue_op(contraction(A, B), D0, D1, ...).
Template Parameters
Problem_: Tensor contraction problem specification defining data types and dimensions
TilePartitioner_: Tile partitioning strategy for workload distribution
GemmPipeline_: GEMM computation pipeline for core matrix operations
EpiloguePipeline_: Epilogue pipeline for post-GEMM operations and tensor fusion
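
The per-batch formula in the overview can be illustrated with a plain CPU reference. This is only a sketch and not part of ck_tile: the function name `reference_batched_contraction`, the row-major layouts, the pre-flattened G, M, N, K extents, and the choice of "add one D tensor" as the epilogue are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, not part of ck_tile.
// Assumed row-major layouts: A[G][M][K], B[G][N][K], D0[G][M][N], E[G][M][N].
// All G, M, N and K dimensions are assumed to be flattened to one extent each.
inline std::vector<float> reference_batched_contraction(const std::vector<float>& a,
                                                        const std::vector<float>& b,
                                                        const std::vector<float>& d0,
                                                        std::size_t G,
                                                        std::size_t M,
                                                        std::size_t N,
                                                        std::size_t K)
{
    assert(a.size() == G * M * K);
    assert(b.size() == G * N * K);
    assert(d0.size() == G * M * N);

    std::vector<float> e(G * M * N, 0.0f);
    for(std::size_t g = 0; g < G; ++g)
    {
        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t n = 0; n < N; ++n)
            {
                // contraction(A, B): sum over the K index.
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                {
                    acc += a[(g * M + m) * K + k] * b[(g * N + n) * K + k];
                }
                // Assumed epilogue_op: add the single auxiliary tensor D0.
                e[(g * M + m) * N + n] = acc + d0[(g * M + m) * N + n];
            }
        }
    }
    return e;
}
```

A reference of this kind is typically useful for checking GPU results on small problem sizes; the actual epilogue operation depends on the EpiloguePipeline_ supplied to the kernel.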

Member Typedef Documentation

◆ ADataType

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType

Data type for input tensor A.

◆ BDataType

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType

Data type for input tensor B.

◆ DsDataType

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::DsDataType

Data types for auxiliary input tensors D

◆ EDataType

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EDataType

Data type for output tensor E.

◆ EpiloguePipeline

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline

Epilogue pipeline for post-GEMM operations.

◆ GemmPipeline

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = ck_tile::remove_cvref_t<GemmPipeline_>

GEMM computation pipeline.

◆ KernelArgs

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::KernelArgs

Kernel argument structure for batched tensor contraction operations.
Definition: batched_contraction_kernel.hpp:189

◆ Problem

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Problem = ck_tile::remove_cvref_t<Problem_>

Tensor contraction problem specification.

◆ TilePartitioner

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner

Tile partitioning strategy for workload distribution

◆ UniversalGemmKernel

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
using BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::UniversalGemmKernel

The underlying Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:154

Member Function Documentation

◆ GetBlockSize()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetBlockSize ( )
inline static constexpr

Returns the GPU block size for kernel launch.

Returns
3D block dimensions for GPU kernel execution

◆ GetKernelName()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetKernelName ( )
inline static constexpr

Returns the kernel name for debugging and profiling purposes.

Returns
Constant string identifier for this kernel

◆ GetSmemSize()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

Returns the shared memory size required by the kernel.

Returns
Shared memory size in bytes

Delegates to the underlying GEMM kernel's shared memory requirements.
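
The delegation described here can be sketched with two stand-in types. `MockGemmKernel` and `MockContractionKernel` are hypothetical names and the numeric values are placeholders, not those of any real ck_tile configuration; the sketch only shows a wrapper forwarding its resource queries to the kernel it is built on.

```cpp
#include <cassert>

// Hypothetical stand-in for the underlying GEMM kernel.
// The values below are placeholders chosen for this sketch.
struct MockGemmKernel
{
    static constexpr int kBlockSize = 256;
    static constexpr int GetSmemSize() { return 16 * 1024; }
};

// Hypothetical wrapper: it owns no shared memory of its own in this sketch,
// so both queries simply forward to the wrapped GEMM kernel.
template <typename GemmKernel>
struct MockContractionKernel
{
    static constexpr int kBlockSize = GemmKernel::kBlockSize;
    static constexpr int GetSmemSize() { return GemmKernel::GetSmemSize(); }
};

// Both queries are usable in constant expressions.
static_assert(MockContractionKernel<MockGemmKernel>::GetSmemSize() == 16 * 1024,
              "wrapper should report the wrapped kernel's shared memory size");
```

Because both members are constexpr, a caller can check the requested shared memory against a device limit at compile time when that limit is known.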

◆ GridSize()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GridSize ( const KernelArgs & kargs)
inline static constexpr

◆ IsSupportedArguments()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr bool BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::IsSupportedArguments ( const KernelArgs & kargs)
inline static constexpr

Validates whether the given kernel arguments are supported.

Parameters
kargs: Kernel arguments to validate
Returns
True if arguments are supported, false otherwise

Checks that the underlying GEMM kernel supports the arguments and that the batch dimensions are valid.
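
A host-side call sequence built from the static functions on this page might look like the sketch below. Everything prefixed with `Mock` is hypothetical, as are the tile sizes, the validity rule, and the grid formula; the real KernelArgs and BatchedContractionHostArgs carry more fields, and the real checks are those of the underlying GEMM kernel. Only the order of calls (build arguments, validate, then size the grid) is the point.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, simplified argument types for this sketch.
struct MockHostArgs
{
    std::int64_t G, M, N, K;
};

struct MockKernelArgs
{
    std::int64_t G, M, N, K;
};

struct MockKernel
{
    // Assumed tile sizes, chosen only for illustration.
    static constexpr std::int64_t kTileM = 128;
    static constexpr std::int64_t kTileN = 128;

    static constexpr MockKernelArgs MakeKernelArgs(const MockHostArgs& h)
    {
        return MockKernelArgs{h.G, h.M, h.N, h.K};
    }

    // Assumed rule for this sketch: every flattened extent must be positive.
    static constexpr bool IsSupportedArguments(const MockKernelArgs& k)
    {
        return k.G > 0 && k.M > 0 && k.N > 0 && k.K > 0;
    }

    // Assumed grid: one workgroup per output tile per batch instance.
    static constexpr std::int64_t GridSize(const MockKernelArgs& k)
    {
        return k.G * ((k.M + kTileM - 1) / kTileM) * ((k.N + kTileN - 1) / kTileN);
    }
};

// Returns the number of workgroups to launch, or -1 if the arguments are rejected.
inline std::int64_t plan_launch(const MockHostArgs& host)
{
    const MockKernelArgs kargs = MockKernel::MakeKernelArgs(host);
    if(!MockKernel::IsSupportedArguments(kargs))
    {
        return -1; // do not launch with unsupported arguments
    }
    return MockKernel::GridSize(kargs);
}
```

Validating before computing the grid keeps unsupported shapes from reaching the device, where a failure would be harder to diagnose.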

◆ MakeKernelArgs()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr KernelArgs BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MakeKernelArgs ( const BatchedContractionHostArgs< NumDTensor > & host_args)
inline static constexpr

◆ operator()()

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( const KernelArgs & kargs) const
inline

Member Data Documentation

◆ kBlockSize

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::kBlockSize
static constexpr
Initial value: the kBlockSize member of the underlying GEMM kernel (static constexpr index_t kBlockSize, definition universal_gemm_kernel.hpp:202).

GPU block size inherited from GEMM kernel.

◆ NumDimG

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::NumDimG = Problem::NumDimG
static constexpr

Number of batch dimensions.

◆ NumDimK

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::NumDimK
static constexpr
Initial value:
= Problem::NumDimK

Number of K (contraction) dimensions.

◆ NumDimM

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::NumDimM
static constexpr
Initial value:
= Problem::NumDimM

Number of M (output row) dimensions.

◆ NumDimN

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::NumDimN
static constexpr
Initial value:
= Problem::NumDimN

Number of N (output column) dimensions.

◆ NumDTensor

template<typename Problem_, typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
ck_tile::index_t BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::NumDTensor
static constexpr
Initial value:
= Problem::NumDTensor

Number of auxiliary input D tensors.
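
One way to read the dimension counts above is as the number of consecutive extents that fold into each GEMM-level extent. The sketch below assumes that tensor A lists its lengths in the order [G..., M..., K...]; that ordering, the `MockProblem` type, and the helpers `flatten_extent` and `flatten_a` are hypothetical and are not taken from ck_tile.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical problem description with 1 batch, 2 row and 1 contraction dimension.
struct MockProblem
{
    static constexpr std::size_t NumDimG = 1;
    static constexpr std::size_t NumDimM = 2;
    static constexpr std::size_t NumDimK = 1;
};

// Multiplies `count` consecutive extents starting at index `first`.
template <std::size_t Rank>
std::int64_t flatten_extent(const std::array<std::int64_t, Rank>& lengths,
                            std::size_t first,
                            std::size_t count)
{
    std::int64_t total = 1;
    for(std::size_t i = 0; i < count; ++i)
    {
        total *= lengths[first + i];
    }
    return total;
}

struct FlatExtents
{
    std::int64_t G, M, K;
};

// Assumed ordering of A's lengths: [G0, ..., M0, ..., K0, ...].
template <typename Problem>
FlatExtents flatten_a(
    const std::array<std::int64_t,
                     Problem::NumDimG + Problem::NumDimM + Problem::NumDimK>& a_lengths)
{
    const std::size_t g_first = 0;
    const std::size_t m_first = g_first + Problem::NumDimG;
    const std::size_t k_first = m_first + Problem::NumDimM;
    return FlatExtents{flatten_extent(a_lengths, g_first, Problem::NumDimG),
                       flatten_extent(a_lengths, m_first, Problem::NumDimM),
                       flatten_extent(a_lengths, k_first, Problem::NumDimK)};
}
```

Under these assumptions, an A tensor of shape [4, 8, 16, 32] is treated as 4 batch instances of a 128 x 32 matrix.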


The documentation for this struct was generated from the following file: