KokkosBatched::Trsm
Defined in header: `KokkosBatched_Trsm_Decl.hpp`
/// Serial batched triangular solve: called from a single thread of execution,
/// e.g. once per batch entry inside a Kokkos::parallel_for.
///
/// Solves op(A) * X = alpha * B   (ArgSide = Side::Left) or
///        X * op(A) = alpha * B   (ArgSide = Side::Right),
/// where A is triangular; the solution X overwrites B.
///
/// Template parameters:
///   ArgSide  - Side::Left or Side::Right
///   ArgUplo  - Uplo::Upper or Uplo::Lower (which triangle of A is used)
///   ArgTrans - Trans::NoTranspose, Trans::Transpose or Trans::ConjTranspose (op(A))
///   ArgDiag  - Diag::Unit or Diag::NonUnit
///   ArgAlgo  - Algo::Trsm::Blocked or Algo::Trsm::Unblocked
/// Some combinations of template parameters may not be supported yet.
template <typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag, typename ArgAlgo>
struct SerialTrsm {
/// @param alpha  scalar multiplier applied to B
/// @param A      rank-2 view holding the upper or lower triangular matrix
/// @param B      rank-2 view with non-const value type; right-hand side on input,
///               solution X on output
/// @return int status. NOTE(review): the meaning of the returned value is not
///         documented here (presumably 0 on success) — confirm against the implementation.
template <typename ScalarType, typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B);
};
/// Team batched triangular solve: same operation as SerialTrsm, but invoked with a
/// Kokkos team member handle. NOTE(review): presumably the threads of the team
/// cooperate on a single solve — confirm against the implementation.
///
/// Solves op(A) * X = alpha * B (Side::Left) or X * op(A) = alpha * B (Side::Right);
/// the solution X overwrites B.
///
/// Template parameters:
///   MemberType - Kokkos team member handle type
///   ArgSide, ArgUplo, ArgTrans, ArgDiag, ArgAlgo - as for SerialTrsm
template <typename MemberType, typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag,
typename ArgAlgo>
struct TeamTrsm {
/// @param member Kokkos team member handle
/// @param alpha  scalar multiplier applied to B
/// @param A      rank-2 view holding the upper or lower triangular matrix
/// @param B      rank-2 view with non-const value type; right-hand side on input,
///               solution X on output
/// @return int status. NOTE(review): meaning not documented here (presumably 0 on
///         success) — confirm against the implementation.
template <typename ScalarType, typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B);
};
/// Team-vector batched triangular solve: same operation as SerialTrsm, invoked with a
/// Kokkos team member handle. NOTE(review): presumably parallelizes over both the team
/// threads and the vector lanes — confirm against the implementation.
///
/// Solves op(A) * X = alpha * B (Side::Left) or X * op(A) = alpha * B (Side::Right);
/// the solution X overwrites B.
///
/// Template parameters:
///   MemberType - Kokkos team member handle type
///   ArgSide, ArgUplo, ArgTrans, ArgDiag, ArgAlgo - as for SerialTrsm
template <typename MemberType, typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag,
typename ArgAlgo>
struct TeamVectorTrsm {
/// @param member Kokkos team member handle
/// @param alpha  scalar multiplier applied to B
/// @param A      rank-2 view holding the upper or lower triangular matrix
/// @param B      rank-2 view with non-const value type; right-hand side on input,
///               solution X on output
/// @return int status. NOTE(review): meaning not documented here (presumably 0 on
///         success) — confirm against the implementation.
template <typename ScalarType, typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
const BViewType &B);
};
Solves a system of linear equations \(op(A) \cdot X = \alpha B\) or \(X \cdot op(A) = \alpha B\), where \(\alpha\) is a scalar, \(X\) and \(B\) are m-by-n matrices, \(A\) is a unit or non-unit, upper or lower triangular matrix (m-by-m for the left-side solve, n-by-n for the right-side solve), and \(op(A)\) is one of \(A\), \(A^T\), or \(A^H\). On exit, the solution \(X\) overwrites \(B\).
For a real matrix \(A\), \(op(A)\) is one of \(A\) or \(A^T\). This operation is equivalent to the BLAS routine `STRSM` or `DTRSM` for single or double precision.
For a complex matrix \(A\), \(op(A)\) is one of \(A\), \(A^T\), or \(A^H\). This operation is equivalent to the BLAS routine `CTRSM` or `ZTRSM` for single or double precision.
Parameters
- member:
Kokkos team member handle (only for `TeamTrsm` and `TeamVectorTrsm`).
- alpha:
Scalar multiplier for \(B\).
- A:
Input view containing the upper or lower triangular matrix.
- B:
Input/output view containing the right-hand side on input and the solution on output.
Type Requirements
- `MemberType` must be a Kokkos team member handle (only for `TeamTrsm` and `TeamVectorTrsm`).
- `ArgSide` must be one of the following:
  - `KokkosBatched::Side::Left` to solve a system \(op(A) \cdot X = \alpha B\)
  - `KokkosBatched::Side::Right` to solve a system \(X \cdot op(A) = \alpha B\)
- `ArgUplo` must be one of the following:
  - `KokkosBatched::Uplo::Upper` for an upper triangular solve
  - `KokkosBatched::Uplo::Lower` for a lower triangular solve
- `ArgTrans` must be one of the following:
  - `KokkosBatched::Trans::NoTranspose` for \(op(A) = A\)
  - `KokkosBatched::Trans::Transpose` for \(op(A) = A^T\)
  - `KokkosBatched::Trans::ConjTranspose` for \(op(A) = A^H\)
- `ArgDiag` must be one of the following:
  - `KokkosBatched::Diag::Unit` for a unit triangular matrix \(A\)
  - `KokkosBatched::Diag::NonUnit` for a non-unit triangular matrix \(A\)
- `ArgAlgo` must be one of the following:
  - `KokkosBatched::Algo::Trsm::Blocked` for the blocked algorithm
  - `KokkosBatched::Algo::Trsm::Unblocked` for the unblocked algorithm
- `ScalarType` must be a scalar type such as `float`, `double`, `Kokkos::complex<float>`, or `Kokkos::complex<double>`.
- `AViewType` must be a Kokkos View of rank 2 containing the triangular matrix \(A\).
- `BViewType` must be a Kokkos View of rank 2 containing the right-hand side, with a non-const value type, i.e. `std::is_same_v<typename BViewType::value_type, typename BViewType::non_const_value_type> == true`.
Note
Some combinations of template parameters may not be supported yet.
Example
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project
#include <iostream>

#include <Kokkos_Core.hpp>
#include <Kokkos_Random.hpp>
#include <KokkosBatched_Trsm_Decl.hpp>

using ExecutionSpace = Kokkos::DefaultExecutionSpace;

/// \brief Example of batched trsm
/// Solving A * X = alpha * B for each of Nb independent 2x2 systems, where
/// A: [[1, 1],
///     [0, 2]]   (upper triangular, non-unit diagonal)
/// B: [[1, 1],
///     [1, 1]]
/// alpha: 1.5
///
/// X: [[3/4, 3/4],
///     [3/4, 3/4]]
///
/// This corresponds to the following system of equations:
///     1 x00 + 1 x10 = 1.5
///     1 x01 + 1 x11 = 1.5
///             2 x10 = 1.5
///             2 x11 = 1.5
///
/// The solution X overwrites B. The program prints a confirmation and returns 0
/// when every entry matches; otherwise it reports the mismatch and returns 1.
int main(int argc, char** argv) {
  // Forward the command line so Kokkos runtime options are honored.
  Kokkos::initialize(argc, argv);
  int status = 0;
  {
    using View3DType = Kokkos::View<double***, ExecutionSpace>;
    const int Nb = 10, n = 2;

    // Batched matrices: A(ib, :, :) and B(ib, :, :) form the ib-th system.
    View3DType A("A", Nb, n, n), B("B", Nb, n, n);

    // Initialize A and B
    Kokkos::deep_copy(B, 1.0);
    auto h_A = Kokkos::create_mirror_view(A);

    // Upper triangular matrix (the strictly lower entry is set to zero explicitly)
    for (int ib = 0; ib < Nb; ib++) {
      h_A(ib, 0, 0) = 1.0;
      h_A(ib, 0, 1) = 1.0;
      h_A(ib, 1, 0) = 0.0;
      h_A(ib, 1, 1) = 2.0;
    }
    Kokkos::deep_copy(A, h_A);

    // solve A * X = alpha * B with trsm
    const double alpha = 1.5;
    ExecutionSpace exec;
    using policy_type = Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<int>>;
    policy_type policy{exec, 0, Nb};
    Kokkos::parallel_for(
        "trsm", policy, KOKKOS_LAMBDA(int ib) {
          auto sub_A = Kokkos::subview(A, ib, Kokkos::ALL, Kokkos::ALL);
          auto sub_B = Kokkos::subview(B, ib, Kokkos::ALL, Kokkos::ALL);

          // Solve A * X = alpha * B with trsm; X overwrites sub_B
          KokkosBatched::SerialTrsm<KokkosBatched::Side::Left, KokkosBatched::Uplo::Upper,
                                    KokkosBatched::Trans::NoTranspose, KokkosBatched::Diag::NonUnit,
                                    KokkosBatched::Algo::Trsm::Unblocked>::invoke(alpha, sub_A, sub_B);
        });

    // Confirm that the results are correct: every entry of every X should be 3/4.
    auto h_B = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, B);
    const double expected = 3.0 / 4.0;
    const double eps = 1.0e-12;
    bool correct = true;
    for (int ib = 0; ib < Nb; ib++) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          if (Kokkos::abs(h_B(ib, i, j) - expected) > eps) correct = false;
        }
      }
    }

    if (correct) {
      std::cout << "trsm works correctly!" << std::endl;
    } else {
      std::cerr << "trsm produced unexpected results!" << std::endl;
      status = 1;
    }
  }  // Views are released here; Kokkos requires this to happen before finalize().
  Kokkos::finalize();
  return status;
}
Output:
trsm works correctly!