batched_contraction_kernel.hpp File Reference

Composable Kernel: batched_contraction_kernel.hpp File Reference

Batched Tensor Contraction Operations.


Classes

struct  BatchedContractionHostArgs< NumDTensor >
struct  BatchedContractionKernelArgs< NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor >
 Kernel arguments for batched tensor contraction operations.
struct  BatchedContractionKernel< Problem_, TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
 GPU kernel for batched tensor contraction operations.

Detailed Description

Batched Tensor Contraction Operations.

What is Batched Tensor Contraction with Multiple D?

Tensor contraction is a fundamental operation that generalizes matrix multiplication to multi-dimensional tensors. It multiplies elements of the two input tensors and sums the products over their shared dimensions.

Beyond pure contraction, this kernel supports multiple auxiliary input tensors (D tensors) that are fused with the contraction result through configurable epilogue operations, enabling efficient computation of complex tensor expressions in a single kernel launch.

Mathematical Formulation

For tensors A and B with arbitrary dimensionalities, the complete operation computes:

E[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = epilogue_op(C, D₀, D₁, D₂, ...)

Here C is the raw contraction result:

C[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = Σ_{K₀,K₁,...} A[G₀,G₁,...,M₀,M₁,...,K₀,K₁,...] × B[G₀,G₁,...,N₀,N₁,...,K₀,K₁,...]

The dimension groups are:

  • G dimensions: Batch dimensions (shared across A, B, and output E)
  • M dimensions: Row dimensions of the output matrix (from tensor A)
  • N dimensions: Column dimensions of the output matrix (from tensor B)
  • K dimensions: Contraction dimensions (summed over, present in both A and B)
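The formulation above can be illustrated with a naive CPU reference. This is only a sketch under simplifying assumptions: one dimension per group (G, M, N, K), packed row-major storage, a single D tensor, and an epilogue that simply adds D₀ to C. The function name `reference_contraction_add` is hypothetical and is not part of the kernel's API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive CPU reference for E = epilogue_op(C, D0) with a single D tensor.
// Assumed shapes (packed row-major): A: [G, M, K], B: [G, N, K], D0 and E: [G, M, N].
// The epilogue is a plain add (C + D0), chosen only for illustration.
std::vector<float> reference_contraction_add(const std::vector<float>& a,
                                             const std::vector<float>& b,
                                             const std::vector<float>& d0,
                                             std::size_t G,
                                             std::size_t M,
                                             std::size_t N,
                                             std::size_t K)
{
    assert(a.size() == G * M * K);
    assert(b.size() == G * N * K);
    assert(d0.size() == G * M * N);

    std::vector<float> e(G * M * N, 0.0f);
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float c = 0.0f; // contraction result C[g, m, n]
                for(std::size_t k = 0; k < K; ++k)
                    c += a[(g * M + m) * K + k] * b[(g * N + n) * K + k];

                const std::size_t idx = (g * M + m) * N + n;
                e[idx]                = c + d0[idx]; // epilogue_op(C, D0) = C + D0
            }
    return e;
}
```

A reference of this kind is mainly useful for validating GPU results on small problem sizes; it makes no attempt at performance.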

Why Tensor Contraction Can Be Implemented Using GEMM

Mathematical Equivalence: Tensor contraction is fundamentally equivalent to matrix multiplication when dimensions are appropriately flattened. The key insight is that the summation operation over shared dimensions (K dimensions) in tensor contraction is mathematically identical to the dot product computation in matrix multiplication.

Dimension Flattening Strategy:

  • M dimensions (from tensor A) → Flattened into matrix rows (M_total)
  • N dimensions (from tensor B) → Flattened into matrix columns (N_total)
  • K dimensions (contraction dims) → Flattened into inner dimension (K_total)
  • G dimensions (batch dims) → Handled through batch processing

Mathematical Transformation:

```
Original:  E[g,m₀,m₁,n₀,n₁] = Σ_{k₀,k₁} A[g,m₀,m₁,k₀,k₁] × B[g,n₀,n₁,k₀,k₁]
Flattened: E[g,M,N] = Σ_K A[g,M,K] × B[g,N,K]   (where M=m₀×m₁, N=n₀×n₁, K=k₀×k₁)
GEMM Form: E = A × Bᵀ
```

Why This Approach Is Optimal: Rather than implementing tensor contraction from scratch, this kernel leverages the highly optimized UniversalGemmKernel as its computational backend.
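The equivalence between the multi-index contraction and the flattened GEMM form can be checked on the CPU. The sketch below assumes packed row-major buffers, one batch dimension, and two dimensions each for M, N, and K. Both function names are hypothetical, and neither reflects how UniversalGemmKernel is implemented.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Contraction written with explicit multi-indices:
// E[g,m0,m1,n0,n1] = sum over k0,k1 of A[g,m0,m1,k0,k1] * B[g,n0,n1,k0,k1].
std::vector<float> contract_multi_index(const std::vector<float>& a,
                                        const std::vector<float>& b,
                                        std::size_t G,
                                        std::size_t M0, std::size_t M1,
                                        std::size_t N0, std::size_t N1,
                                        std::size_t K0, std::size_t K1)
{
    assert(a.size() == G * M0 * M1 * K0 * K1);
    assert(b.size() == G * N0 * N1 * K0 * K1);
    std::vector<float> e(G * M0 * M1 * N0 * N1, 0.0f);
    for(std::size_t g = 0; g < G; ++g)
    for(std::size_t m0 = 0; m0 < M0; ++m0)
    for(std::size_t m1 = 0; m1 < M1; ++m1)
    for(std::size_t n0 = 0; n0 < N0; ++n0)
    for(std::size_t n1 = 0; n1 < N1; ++n1)
    {
        float acc = 0.0f;
        for(std::size_t k0 = 0; k0 < K0; ++k0)
        for(std::size_t k1 = 0; k1 < K1; ++k1)
        {
            // Packed row-major offsets of A[g,m0,m1,k0,k1] and B[g,n0,n1,k0,k1].
            const std::size_t ia = (((g * M0 + m0) * M1 + m1) * K0 + k0) * K1 + k1;
            const std::size_t ib = (((g * N0 + n0) * N1 + n1) * K0 + k0) * K1 + k1;
            acc += a[ia] * b[ib];
        }
        e[(((g * M0 + m0) * M1 + m1) * N0 + n0) * N1 + n1] = acc;
    }
    return e;
}

// The same buffers viewed as a batched GEMM, E[g] = A[g] x B[g]^T,
// with M = M0*M1, N = N0*N1, K = K0*K1.
std::vector<float> contract_flattened(const std::vector<float>& a,
                                      const std::vector<float>& b,
                                      std::size_t G, std::size_t M,
                                      std::size_t N, std::size_t K)
{
    assert(a.size() == G * M * K);
    assert(b.size() == G * N * K);
    std::vector<float> e(G * M * N, 0.0f);
    for(std::size_t g = 0; g < G; ++g)
    for(std::size_t m = 0; m < M; ++m)
    for(std::size_t n = 0; n < N; ++n)
    {
        float acc = 0.0f;
        for(std::size_t k = 0; k < K; ++k)
            acc += a[(g * M + m) * K + k] * b[(g * N + n) * K + k];
        e[(g * M + m) * N + n] = acc;
    }
    return e;
}
```

Because the tensors are packed and row-major, the flattened row index equals m₀ × M₁ + m₁ (and likewise for N and K), so no data movement is needed: the two functions read the same memory in the same order and produce identical outputs.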

Current Kernel Limitations

Layout Restrictions:

  • Row-Major Only: All tensors must use row-major memory layout
  • Packed Tensors: Only contiguous/packed tensor layouts supported
  • Hardcoded Strides: stride_A = K_total, stride_B = K_total, stride_E = N_total
  • D Tensor Layout: All D tensors must match E tensor layout (stride_Ds = N_total)
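The stride rules listed above can be sketched as a small host-side helper. The struct `FlattenedProblem` and the function `flatten_problem` are hypothetical names used only for illustration. The per-batch strides are an assumption that follows from the packed layout and are not stated in this reference.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical container for the flattened sizes and strides.
struct FlattenedProblem
{
    std::size_t G_total, M_total, N_total, K_total;
    std::size_t stride_A, stride_B, stride_E;                   // row strides
    std::size_t batch_stride_A, batch_stride_B, batch_stride_E; // assumed: packed batches
};

// Product of all extents in one dimension group (1 for an empty group).
std::size_t product(const std::vector<std::size_t>& dims)
{
    return std::accumulate(dims.begin(), dims.end(), std::size_t{1},
                           std::multiplies<std::size_t>());
}

FlattenedProblem flatten_problem(const std::vector<std::size_t>& g_dims,
                                 const std::vector<std::size_t>& m_dims,
                                 const std::vector<std::size_t>& n_dims,
                                 const std::vector<std::size_t>& k_dims)
{
    FlattenedProblem p{};
    p.G_total = product(g_dims);
    p.M_total = product(m_dims);
    p.N_total = product(n_dims);
    p.K_total = product(k_dims);
    assert(p.M_total > 0 && p.N_total > 0 && p.K_total > 0);

    // Row-major, packed: these match the hardcoded strides listed above.
    p.stride_A = p.K_total;
    p.stride_B = p.K_total;
    p.stride_E = p.N_total; // D tensors share this stride

    // Assumption: consecutive batches are stored back to back.
    p.batch_stride_A = p.M_total * p.K_total;
    p.batch_stride_B = p.N_total * p.K_total;
    p.batch_stride_E = p.M_total * p.N_total;
    return p;
}
```

For example, with G = [2, 3], M = [4, 5], N = [6], and K = [7, 8], the helper yields M_total = 20, N_total = 6, K_total = 56, and therefore stride_A = stride_B = 56 and stride_E = 6.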

Implementation Constraints:

  • Fixed Stride Calculation: Strides are automatically calculated and cannot be customized
  • No Column-Major: Column-major or custom stride patterns not supported
  • No Strided Access: Non-contiguous tensor slicing not supported

Future Enhancements:

  • Support for arbitrary stride patterns
  • Column-major and mixed layout support
  • Non-contiguous tensor operation support

Host arguments for batched tensor contraction operations (declared in namespace ck_tile).

Overview
This structure encapsulates all host-side arguments required for batched tensor contraction. It supports an arbitrary number of batch (G), M, N, and K dimensions.
Tensor Layout Assumptions
  • A tensor: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
  • B tensor: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
  • D tensors: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (auxiliary input tensors)
  • E tensor: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (output tensor)
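The layout assumptions above imply consistency conditions between the dimension lists of A, B, and E. They can be sketched as a standalone check. The function `layouts_are_consistent` and its parameters are hypothetical; they do not correspond to members of BatchedContractionHostArgs. Each D tensor would be checked against the same dimension list as E.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Returns true when the dimension lists follow the assumed ordering:
//   A: [G..., M..., K...], B: [G..., N..., K...], E: [G..., M..., N...].
bool layouts_are_consistent(const std::vector<std::size_t>& a_dims,
                            const std::vector<std::size_t>& b_dims,
                            const std::vector<std::size_t>& e_dims,
                            std::size_t num_g,
                            std::size_t num_m,
                            std::size_t num_n,
                            std::size_t num_k)
{
    if(a_dims.size() != num_g + num_m + num_k) return false;
    if(b_dims.size() != num_g + num_n + num_k) return false;
    if(e_dims.size() != num_g + num_m + num_n) return false;

    const auto g = static_cast<std::ptrdiff_t>(num_g);
    const auto m = static_cast<std::ptrdiff_t>(num_m);
    const auto n = static_cast<std::ptrdiff_t>(num_n);
    const auto k = static_cast<std::ptrdiff_t>(num_k);

    // G dims lead all three tensors and must agree.
    const bool g_ok = std::equal(a_dims.begin(), a_dims.begin() + g, b_dims.begin()) &&
                      std::equal(a_dims.begin(), a_dims.begin() + g, e_dims.begin());
    // K dims trail both A and B and must agree.
    const bool k_ok = std::equal(a_dims.end() - k, a_dims.end(), b_dims.end() - k);
    // M dims follow G in both A and E.
    const bool m_ok =
        std::equal(a_dims.begin() + g, a_dims.begin() + g + m, e_dims.begin() + g);
    // N dims follow G in B and follow M in E.
    const bool n_ok =
        std::equal(b_dims.begin() + g, b_dims.begin() + g + n, e_dims.begin() + g + m);

    return g_ok && k_ok && m_ok && n_ok;
}
```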
Template Parameters
NumDTensor  Number of D (auxiliary input) tensors. Default is 0.