KokkosBatched::Trsm

Defined in header: KokkosBatched_Trsm_Decl.hpp

/// Triangular solve without a team handle: solves op(A) * X = alpha * B (Side::Left)
/// or X * op(A) = alpha * B (Side::Right). B holds the right-hand side on input and X on output.
/// The template arguments select side, triangle, op(), diagonal type and algorithm
/// (see "Type Requirements").
template <typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag, typename ArgAlgo>
struct SerialTrsm {
  /// alpha: scalar multiplier for B; A: rank-2 triangular matrix view; B: rank-2 non-const view.
  /// Returns an int status (NOTE(review): presumably 0 on success — confirm in the implementation).
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha, const AViewType &A, const BViewType &B);
};

/// Team-level triangular solve: same operation as SerialTrsm, but invoke() additionally
/// takes the Kokkos team member handle `member`.
/// NOTE(review): presumably every member of the team must call invoke() collectively — confirm.
template <typename MemberType, typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag,
          typename ArgAlgo>
struct TeamTrsm {
  /// member: team handle; alpha: scalar multiplier for B; A: rank-2 triangular matrix view;
  /// B: rank-2 non-const view, overwritten with the solution X.
  /// Returns an int status (NOTE(review): presumably 0 on success — confirm in the implementation).
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
                                           const BViewType &B);
};

/// Team/vector-level triangular solve: same operation and invoke() signature as TeamTrsm.
/// NOTE(review): the name suggests it also uses vector-level parallelism within each
/// team member — confirm in the implementation.
template <typename MemberType, typename ArgSide, typename ArgUplo, typename ArgTrans, typename ArgDiag,
          typename ArgAlgo>
struct TeamVectorTrsm {
  /// member: team handle; alpha: scalar multiplier for B; A: rank-2 triangular matrix view;
  /// B: rank-2 non-const view, overwritten with the solution X.
  /// Returns an int status (NOTE(review): presumably 0 on success — confirm in the implementation).
  template <typename ScalarType, typename AViewType, typename BViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member, const ScalarType alpha, const AViewType &A,
                                           const BViewType &B);
};

Solves a system of linear equations \(op(A) \cdot X = \alpha B\) or \(X \cdot op(A) = \alpha B\), where \(\alpha\) is a scalar, \(X\) and \(B\) are m-by-n matrices, \(A\) is a unit or non-unit, upper or lower triangular matrix (m-by-m for \(op(A) \cdot X = \alpha B\), n-by-n for \(X \cdot op(A) = \alpha B\)), and \(op(A)\) is one of \(A\), \(A^T\), or \(A^H\). The solution \(X\) overwrites \(B\).

  1. For a real matrix \(A\), \(op(A)\) is one of \(A\) or \(A^T\). This operation is equivalent to the BLAS routine STRSM or DTRSM for single or double precision.

  2. For a complex matrix \(A\), \(op(A)\) is one of \(A\), \(A^T\), or \(A^H\). This operation is equivalent to the BLAS routine CTRSM or ZTRSM for single or double precision.

Parameters

member:

Kokkos team member handle (only for TeamTrsm and TeamVectorTrsm).

alpha:

Scalar multiplier for \(B\).

A:

Input view containing the upper or lower triangular matrix.

B:

Input/output view containing the right-hand side on input and the solution on output.

Type Requirements

  • MemberType must be a Kokkos team member handle (only for TeamTrsm and TeamVectorTrsm).

  • ArgSide must be one of the following:
    • KokkosBatched::Side::Left to solve a system \(op(A) \cdot X = \alpha B\)

    • KokkosBatched::Side::Right to solve a system \(X \cdot op(A) = \alpha B\)

  • ArgUplo must be one of the following:
    • KokkosBatched::Uplo::Upper for upper triangular solve

    • KokkosBatched::Uplo::Lower for lower triangular solve

  • ArgTrans must be one of the following:
    • KokkosBatched::Trans::NoTranspose for \(op(A) = A\)

    • KokkosBatched::Trans::Transpose for \(op(A) = A^T\)

    • KokkosBatched::Trans::ConjTranspose for \(op(A) = A^H\)

  • ArgDiag must be one of the following:
    • KokkosBatched::Diag::Unit for the unit triangular matrix \(A\)

    • KokkosBatched::Diag::NonUnit for the non-unit triangular matrix \(A\)

  • ArgAlgo must be one of the following:
    • KokkosBatched::Algo::Trsm::Blocked for the blocked algorithm

    • KokkosBatched::Algo::Trsm::Unblocked for the unblocked algorithm

  • ScalarType must be a scalar type such as float, double, Kokkos::complex<float>, or Kokkos::complex<double>.

  • AViewType must be a Kokkos View of rank 2 containing the triangular matrix A.

  • BViewType must be a Kokkos View of rank 2 containing the right-hand side B. Because B is overwritten with the solution, it must be non-const, i.e. std::is_same_v<typename BViewType::value_type, typename BViewType::non_const_value_type> must be true.

Note

Some combinations of template parameters may not be supported yet.

Example

// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright Contributors to the Kokkos project

#include <cstdlib>
#include <iostream>

#include <Kokkos_Core.hpp>
#include <Kokkos_Random.hpp>
#include <KokkosBatched_Trsm_Decl.hpp>

using ExecutionSpace = Kokkos::DefaultExecutionSpace;

/// \brief Example of batched trsm
/// Solves A * X = alpha * B independently for each of Nb batch entries, where
///   A: [[1,  1],
///       [0,  2]]   (upper triangular, non-unit diagonal)
///   B: [[1, 1],
///       [1, 1]]
///   alpha: 1.5
///
/// The expected solution is
///   X: [[3/4, 3/4],
///       [3/4, 3/4]]
///
/// This corresponds to the following system of equations:
///    1 x00 + 1 x10 = 1.5
///    1 x01 + 1 x11 = 1.5
///            2 x10 = 1.5
///            2 x11 = 1.5
///
/// Prints a success message and returns EXIT_SUCCESS when every batch entry matches
/// the expected solution; otherwise prints an error and returns EXIT_FAILURE.
int main(int argc, char** argv) {
  // Forward command-line arguments so Kokkos runtime options (e.g. --kokkos-num-threads) are honored.
  Kokkos::initialize(argc, argv);
  // Declared outside the inner scope so the result survives Kokkos::finalize().
  bool correct = true;
  {
    using View3DType = Kokkos::View<double***, ExecutionSpace>;
    const int Nb = 10, n = 2;

    // Batched matrices A and B, each of extent Nb x n x n
    View3DType A("A", Nb, n, n), B("B", Nb, n, n);

    // Initialize every entry of B to 1
    Kokkos::deep_copy(B, 1.0);

    // Fill each batch entry of A with the same upper triangular matrix on the host,
    // then copy it to the execution space's memory
    auto h_A = Kokkos::create_mirror_view(A);
    for (int ib = 0; ib < Nb; ib++) {
      h_A(ib, 0, 0) = 1.0;
      h_A(ib, 0, 1) = 1.0;
      h_A(ib, 1, 0) = 0.0;
      h_A(ib, 1, 1) = 2.0;
    }
    Kokkos::deep_copy(A, h_A);

    // Solve A * X = alpha * B with trsm, one batch entry per work item;
    // each sub_B is overwritten with its solution X.
    const double alpha = 1.5;
    ExecutionSpace exec;
    using policy_type = Kokkos::RangePolicy<ExecutionSpace, Kokkos::IndexType<int>>;
    policy_type policy{exec, 0, Nb};
    Kokkos::parallel_for(
        "trsm", policy, KOKKOS_LAMBDA(int ib) {
          auto sub_A = Kokkos::subview(A, ib, Kokkos::ALL, Kokkos::ALL);
          auto sub_B = Kokkos::subview(B, ib, Kokkos::ALL, Kokkos::ALL);

          KokkosBatched::SerialTrsm<KokkosBatched::Side::Left, KokkosBatched::Uplo::Upper,
                                    KokkosBatched::Trans::NoTranspose, KokkosBatched::Diag::NonUnit,
                                    KokkosBatched::Algo::Trsm::Unblocked>::invoke(alpha, sub_A, sub_B);
        });

    // Copy the solution back to the host and compare every entry with the expected value 3/4
    auto h_B              = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace{}, B);
    const double eps      = 1.0e-12;
    const double expected = 3.0 / 4.0;
    for (int ib = 0; ib < Nb; ib++) {
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
          if (Kokkos::abs(h_B(ib, i, j) - expected) > eps) correct = false;
        }
      }
    }

    if (correct) {
      std::cout << "trsm works correctly!" << std::endl;
    } else {
      std::cerr << "trsm returned an incorrect solution" << std::endl;
    }
  }
  Kokkos::finalize();
  return correct ? EXIT_SUCCESS : EXIT_FAILURE;
}

output:

trsm works correctly!