gemm_validation.hpp Source File

gemm_validation.hpp Source File

Composable Kernel: gemm_validation.hpp Source File
gemm_validation.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <string>
7#include <stdexcept>
9
10namespace ck_tile {
11
/// @brief Checks that a stride is large enough for a 2-D matrix in the given layout.
///
/// For a column-major matrix ("C") the stride must be at least @p M, the first
/// (row) dimension. For a row-major matrix ("R") the stride must be at least
/// @p N, the second (column) dimension.
///
/// NOTE(review): any layout tag other than "R" or "C" is accepted without a
/// check. Confirm that callers only ever pass these two tags.
///
/// @param Layout      Layout tag: "R" (row-major) or "C" (column-major).
/// @param M           Size of the first (row) dimension.
/// @param N           Size of the second (column) dimension.
/// @param stride      Stride value to validate.
/// @param stride_name Name of the stride, used in the error message (e.g. "Stride_A").
/// @throws std::runtime_error if @p stride is smaller than the dimension required by @p Layout.
inline void
validate_stride(const std::string& Layout, int M, int N, int stride, const std::string& stride_name)
{
    // The two layouts are mutually exclusive, so at most one check applies.
    if(Layout == "C")
    {
        if(stride < M)
        {
            throw std::runtime_error("For ColumnMajor layout, " + stride_name + "(" +
                                     std::to_string(stride) + ") must be greater or equal to dim " +
                                     std::to_string(M));
        }
    }
    else if(Layout == "R")
    {
        if(stride < N)
        {
            throw std::runtime_error("For RowMajor layout, " + stride_name + "(" +
                                     std::to_string(stride) + ") must be greater or equal to dim " +
                                     std::to_string(N));
        }
    }
}
28
29inline void validate_gemm_stride(std::string a_layout,
30 std::string b_layout,
31 std::string c_layout,
32 int M,
33 int N,
34 int K,
35 int Stride_A,
36 int Stride_B,
37 int Stride_C)
38{
39 // set default stride
40 if(Stride_A <= 0)
41 Stride_A = (a_layout == "R") ? K : M;
42 if(Stride_B <= 0)
43 Stride_B = (b_layout == "R") ? N : K;
44 if(Stride_C <= 0)
45 Stride_C = (c_layout == "R") ? N : M;
46
47 validate_stride(a_layout, M, K, Stride_A, "Stride_A");
48 validate_stride(b_layout, K, N, Stride_B, "Stride_B");
49 validate_stride(c_layout, M, N, Stride_C, "Stride_C");
50}
51} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
void validate_stride(std::string Layout, int M, int N, int stride, const std::string &stride_name)
Definition gemm_validation.hpp:13
void validate_gemm_stride(std::string a_layout, std::string b_layout, std::string c_layout, int M, int N, int K, int Stride_A, int Stride_B, int Stride_C)
Definition gemm_validation.hpp:29
Layout wrapper that performs the tensor descriptor logic.
Definition layout.hpp:24