Write a Program for Strassen's Matrix Multiplication in C++
Strassen's matrix multiplication is an efficient divide-and-conquer algorithm for multiplying matrices, offering a better time complexity than the standard approach for large matrices. In this article, you will learn how to implement Strassen's algorithm in C++, comparing it with the traditional method and understanding its underlying principles.
Problem Statement
Matrix multiplication is a fundamental operation in linear algebra, with wide applications in computer graphics, scientific computing, and data analysis. Given two matrices, A (of size $m \times n$) and B (of size $n \times p$), their product C (of size $m \times p$) is defined as: $C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$
The traditional (naive) algorithm computes each element of C by iterating through rows, columns, and intermediate products, resulting in a time complexity of $O(n^3)$ for two square matrices of size $n \times n$. For very large matrices, this cubic complexity can lead to significant computational overhead.
Example
Consider two $2 \times 2$ matrices, A and B:
$A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}$
$B = \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}$
Using the standard matrix multiplication method, their product C would be:
$C_{11} = (1 \times 5) + (2 \times 7) = 5 + 14 = 19$
$C_{12} = (1 \times 6) + (2 \times 8) = 6 + 16 = 22$
$C_{21} = (3 \times 5) + (4 \times 7) = 15 + 28 = 43$
$C_{22} = (3 \times 6) + (4 \times 8) = 18 + 32 = 50$
The resulting matrix C is:
$C = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}$
Background & Knowledge Prerequisites
To understand and implement Strassen's algorithm, you should be familiar with:
- C++ Basics: Fundamental syntax, data types, functions, and control structures.
- Dynamic Memory Allocation: Understanding how to use std::vector for dynamic arrays or raw pointers for matrices.
- Recursion: Strassen's algorithm is inherently recursive, breaking down problems into smaller subproblems.
- Basic Matrix Operations: Addition, subtraction, and partitioning of matrices.
- Divide and Conquer Paradigm: The strategy of breaking a problem into subproblems, solving them independently, and combining their results.
Use Cases or Case Studies
Matrix multiplication is a foundational operation with diverse applications:
- Computer Graphics: Used extensively for transformations (scaling, rotation, translation) of 3D objects.
- Machine Learning: Central to training neural networks, support vector machines, and other models, especially in deep learning architectures.
- Scientific Computing: Essential in simulations, numerical analysis (e.g., solving systems of linear equations, finite element methods), and quantum mechanics.
- Image Processing: Applied in filters, transformations, and compression algorithms for images and videos.
- Optimization Problems: Used in algorithms like the Simplex method and various graph algorithms.
Solution Approaches
We will explore two approaches: the naive matrix multiplication and Strassen's algorithm.
Approach 1: Naive Matrix Multiplication
This is the standard approach, directly implementing the mathematical definition of matrix multiplication.
- One-line summary: Iterates through rows, columns, and intermediate sums to compute each element of the product matrix.
// Naive Matrix Multiplication
#include <iostream>
#include <vector>
// Helper function to print a matrix
void printMatrix(const std::vector<std::vector<int>>& matrix) {
// Step 1: Iterate through each row of the matrix.
for (const auto& row : matrix) {
// Step 2: Iterate through each element in the current row.
for (int val : row) {
// Step 3: Print the element followed by a space.
std::cout << val << " ";
}
// Step 4: Move to the next line after printing all elements in a row.
std::cout << std::endl;
}
}
// Function to perform naive matrix multiplication
std::vector<std::vector<int>> multiplyNaive(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B) {
int n = A.size(); // Assuming square matrices of size n x n
// Step 1: Initialize the result matrix C with dimensions n x n, filled with zeros.
std::vector<std::vector<int>> C(n, std::vector<int>(n, 0));
// Step 2: Iterate through rows of matrix A.
for (int i = 0; i < n; ++i) {
// Step 3: Iterate through columns of matrix B.
for (int j = 0; j < n; ++j) {
// Step 4: Iterate through elements for the sum of products.
for (int k = 0; k < n; ++k) {
// Step 5: Accumulate the product A[i][k] * B[k][j] into C[i][j].
C[i][j] += A[i][k] * B[k][j];
}
}
}
// Step 6: Return the resulting product matrix C.
return C;
}
int main() {
// Step 1: Define matrices A and B (2x2 example).
std::vector<std::vector<int>> A = {{1, 2}, {3, 4}};
std::vector<std::vector<int>> B = {{5, 6}, {7, 8}};
int n = A.size();
if (n == 0 || A[0].size() != n || B.size() != n || B[0].size() != n) {
std::cout << "Invalid matrix dimensions for square matrices." << std::endl;
return 1;
}
// Step 2: Print original matrices for clarity.
std::cout << "Matrix A:" << std::endl;
printMatrix(A);
std::cout << "\nMatrix B:" << std::endl;
printMatrix(B);
// Step 3: Perform naive matrix multiplication.
std::vector<std::vector<int>> C = multiplyNaive(A, B);
// Step 4: Print the result matrix.
std::cout << "\nResult of Naive Multiplication (C = A * B):" << std::endl;
printMatrix(C);
return 0;
}
- Sample Output:
Matrix A:
1 2
3 4
Matrix B:
5 6
7 8
Result of Naive Multiplication (C = A * B):
19 22
43 50
- Stepwise Explanation:
- The printMatrix function iterates through rows and columns to display the matrix elements.
- multiplyNaive initializes a result matrix C of the same dimensions as A and B, filled with zeros.
- It then uses three nested loops:
- The outer loop (i) iterates through the rows of matrix A.
- The middle loop (j) iterates through the columns of matrix B.
- The inner loop (k) calculates the sum of products for the element C[i][j].
- Each C[i][j] element is computed by summing A[i][k] * B[k][j] for all k from 0 to n-1.
- Finally, the computed matrix C is returned.
Approach 2: Strassen's Matrix Multiplication
Strassen's algorithm is a recursive, divide-and-conquer approach that reduces the number of block multiplications from 8 (in the standard method for 2x2 blocks) to 7. Applied at every level of the recursion, this seemingly small reduction leads to a better overall time complexity of $O(n^{\log_2 7})$, which is approximately $O(n^{2.807})$.
- One-line summary: A recursive, divide-and-conquer algorithm that multiplies matrices by cleverly rearranging sub-matrix additions and subtractions to reduce the total number of recursive multiplications from eight to seven.
Key Idea: Strassen's 7 Products for $2 \times 2$ Matrices
To multiply two $2 \times 2$ matrices:
$A = \begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}$
$B = \begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}$
The product $C = A \times B$ would normally require 8 multiplications and 4 additions to get its sub-matrices:
$C_{11} = A_{11}B_{11} + A_{12}B_{21}$
$C_{12} = A_{11}B_{12} + A_{12}B_{22}$
$C_{21} = A_{21}B_{11} + A_{22}B_{21}$
$C_{22} = A_{21}B_{12} + A_{22}B_{22}$
Strassen's algorithm computes the 7 products (M1 to M7) as follows:
$M_1 = (A_{11} + A_{22}) (B_{11} + B_{22})$
$M_2 = (A_{21} + A_{22}) B_{11}$
$M_3 = A_{11} (B_{12} - B_{22})$
$M_4 = A_{22} (B_{21} - B_{11})$
$M_5 = (A_{11} + A_{12}) B_{22}$
$M_6 = (A_{21} - A_{11}) (B_{11} + B_{12})$
$M_7 = (A_{12} - A_{22}) (B_{21} + B_{22})$
Then, the sub-matrices of C are calculated from these 7 products using 8 further additions/subtractions (18 in total, counting the 10 needed to form the products):
$C_{11} = M_1 + M_4 - M_5 + M_7$
$C_{12} = M_3 + M_5$
$C_{21} = M_2 + M_4$
$C_{22} = M_1 - M_2 + M_3 + M_6$
For general $n \times n$ matrices, this process is applied recursively by dividing the matrices into $n/2 \times n/2$ sub-matrices until the base case (e.g., $1 \times 1$ matrices) is reached. The halving only works cleanly when $n$ is a power of 2. If it is not, a simple remedy, and the one used in the code below, is to pad the matrices with zeros up to the next power of 2 and discard the extra rows and columns of the result.
// Strassen's Matrix Multiplication
#include <iostream>
#include <vector>
#include <cmath> // For std::pow, std::ceil
#include <algorithm> // For std::max
// Helper function to print a matrix
void printMatrix(const std::vector<std::vector<int>>& matrix, int n) {
// Step 1: Iterate through each row up to the actual size 'n'.
for (int i = 0; i < n; ++i) {
// Step 2: Iterate through each element in the current row up to 'n'.
for (int j = 0; j < n; ++j) {
// Step 3: Print the element followed by a space.
std::cout << matrix[i][j] << " ";
}
// Step 4: Move to the next line after printing all elements in a row.
std::cout << std::endl;
}
}
// Helper function to add two matrices
std::vector<std::vector<int>> add(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B, int size) {
// Step 1: Initialize result matrix C with dimensions 'size' x 'size'.
std::vector<std::vector<int>> C(size, std::vector<int>(size));
// Step 2: Iterate through rows and columns.
for (int i = 0; i < size; ++i) {
for (int j = 0; j < size; ++j) {
// Step 3: Perform element-wise addition.
C[i][j] = A[i][j] + B[i][j];
}
}
// Step 4: Return the sum matrix.
return C;
}
// Helper function to subtract two matrices
std::vector<std::vector<int>> subtract(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B, int size) {
// Step 1: Initialize result matrix C with dimensions 'size' x 'size'.
std::vector<std::vector<int>> C(size, std::vector<int>(size));
// Step 2: Iterate through rows and columns.
for (int i = 0; i < size; ++i) {
for (int j = 0; j < size; ++j) {
// Step 3: Perform element-wise subtraction.
C[i][j] = A[i][j] - B[i][j];
}
}
// Step 4: Return the difference matrix.
return C;
}
// Helper function to split a matrix into four sub-matrices
void split(const std::vector<std::vector<int>>& P,
std::vector<std::vector<int>>& P11, std::vector<std::vector<int>>& P12,
std::vector<std::vector<int>>& P21, std::vector<std::vector<int>>& P22, int size) {
int halfSize = size / 2;
// Step 1: Iterate through the rows of the parent matrix P.
for (int i = 0; i < halfSize; ++i) {
// Step 2: Iterate through the columns of the parent matrix P.
for (int j = 0; j < halfSize; ++j) {
// Step 3: Assign elements to the top-left sub-matrix P11.
P11[i][j] = P[i][j];
// Step 4: Assign elements to the top-right sub-matrix P12.
P12[i][j] = P[i][j + halfSize];
// Step 5: Assign elements to the bottom-left sub-matrix P21.
P21[i][j] = P[i + halfSize][j];
// Step 6: Assign elements to the bottom-right sub-matrix P22.
P22[i][j] = P[i + halfSize][j + halfSize];
}
}
}
// Helper function to join four sub-matrices into a single matrix
void join(const std::vector<std::vector<int>>& C11, const std::vector<std::vector<int>>& C12,
const std::vector<std::vector<int>>& C21, const std::vector<std::vector<int>>& C22,
std::vector<std::vector<int>>& C, int size) {
int halfSize = size / 2;
// Step 1: Iterate through the rows of the result matrix C.
for (int i = 0; i < halfSize; ++i) {
// Step 2: Iterate through the columns of the result matrix C.
for (int j = 0; j < halfSize; ++j) {
// Step 3: Assign C11 elements to the top-left quadrant of C.
C[i][j] = C11[i][j];
// Step 4: Assign C12 elements to the top-right quadrant of C.
C[i][j + halfSize] = C12[i][j];
// Step 5: Assign C21 elements to the bottom-left quadrant of C.
C[i + halfSize][j] = C21[i][j];
// Step 6: Assign C22 elements to the bottom-right quadrant of C.
C[i + halfSize][j + halfSize] = C22[i][j];
}
}
}
// Recursive Strassen's Matrix Multiplication
std::vector<std::vector<int>> strassen_multiply(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B, int size) {
// Step 1: Base case for recursion: if matrix size is 1x1, perform direct multiplication.
if (size == 1) {
std::vector<std::vector<int>> C(1, std::vector<int>(1));
C[0][0] = A[0][0] * B[0][0];
return C;
}
// Step 2: Calculate half the current matrix size.
int halfSize = size / 2;
// Step 3: Initialize sub-matrices for A and B.
std::vector<std::vector<int>> A11(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A12(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A21(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> A22(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B11(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B12(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B21(halfSize, std::vector<int>(halfSize));
std::vector<std::vector<int>> B22(halfSize, std::vector<int>(halfSize));
// Step 4: Split A and B into their respective sub-matrices.
split(A, A11, A12, A21, A22, size);
split(B, B11, B12, B21, B22, size);
// Step 5: Compute the 7 products (M1 to M7) recursively.
// M1 = (A11 + A22) * (B11 + B22)
std::vector<std::vector<int>> M1 = strassen_multiply(add(A11, A22, halfSize), add(B11, B22, halfSize), halfSize);
// M2 = (A21 + A22) * B11
std::vector<std::vector<int>> M2 = strassen_multiply(add(A21, A22, halfSize), B11, halfSize);
// M3 = A11 * (B12 - B22)
std::vector<std::vector<int>> M3 = strassen_multiply(A11, subtract(B12, B22, halfSize), halfSize);
// M4 = A22 * (B21 - B11)
std::vector<std::vector<int>> M4 = strassen_multiply(A22, subtract(B21, B11, halfSize), halfSize);
// M5 = (A11 + A12) * B22
std::vector<std::vector<int>> M5 = strassen_multiply(add(A11, A12, halfSize), B22, halfSize);
// M6 = (A21 - A11) * (B11 + B12)
std::vector<std::vector<int>> M6 = strassen_multiply(subtract(A21, A11, halfSize), add(B11, B12, halfSize), halfSize);
// M7 = (A12 - A22) * (B21 + B22)
std::vector<std::vector<int>> M7 = strassen_multiply(subtract(A12, A22, halfSize), add(B21, B22, halfSize), halfSize);
// Step 6: Compute the four sub-matrices of the result C (C11, C12, C21, C22).
// C11 = M1 + M4 - M5 + M7
std::vector<std::vector<int>> C11 = add(subtract(add(M1, M4, halfSize), M5, halfSize), M7, halfSize);
// C12 = M3 + M5
std::vector<std::vector<int>> C12 = add(M3, M5, halfSize);
// C21 = M2 + M4
std::vector<std::vector<int>> C21 = add(M2, M4, halfSize);
// C22 = M1 - M2 + M3 + M6
std::vector<std::vector<int>> C22 = add(subtract(add(M1, M3, halfSize), M2, halfSize), M6, halfSize);
// Step 7: Initialize the final result matrix C.
std::vector<std::vector<int>> C(size, std::vector<int>(size));
// Step 8: Join the four computed sub-matrices into C.
join(C11, C12, C21, C22, C, size);
// Step 9: Return the final product matrix C.
return C;
}
// Function to get the next power of 2 for padding
int getNextPowerOf2(int n) {
// Integer doubling avoids the floating-point rounding risk of std::pow and std::log2.
// Returns 1 for n <= 1, and n itself when n is already a power of 2.
int power = 1;
while (power < n) {
power *= 2;
}
return power;
}
// Wrapper function to handle padding for Strassen's algorithm
std::vector<std::vector<int>> strassen_wrapper(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B,
int original_n) {
int new_n = getNextPowerOf2(original_n);
// Step 1: Create padded matrices Ap and Bp.
std::vector<std::vector<int>> Ap(new_n, std::vector<int>(new_n, 0));
std::vector<std::vector<int>> Bp(new_n, std::vector<int>(new_n, 0));
// Step 2: Copy original matrices into padded matrices.
for (int i = 0; i < original_n; ++i) {
for (int j = 0; j < original_n; ++j) {
Ap[i][j] = A[i][j];
Bp[i][j] = B[i][j];
}
}
// Step 3: Perform Strassen's multiplication on padded matrices.
std::vector<std::vector<int>> C_padded = strassen_multiply(Ap, Bp, new_n);
// Step 4: Extract the result of the original dimensions.
std::vector<std::vector<int>> C_original(original_n, std::vector<int>(original_n));
for (int i = 0; i < original_n; ++i) {
for (int j = 0; j < original_n; ++j) {
C_original[i][j] = C_padded[i][j];
}
}
// Step 5: Return the result for the original matrix size.
return C_original;
}
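// Naive multiplication from Approach 1, repeated here so that the comparison in main() compiles.
std::vector<std::vector<int>> multiplyNaive(const std::vector<std::vector<int>>& A,
const std::vector<std::vector<int>>& B) {
int n = A.size(); // Assuming square matrices of size n x n
std::vector<std::vector<int>> C(n, std::vector<int>(n, 0));
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
for (int k = 0; k < n; ++k) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
return C;
}
int main() {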
// Step 1: Define matrices A and B (2x2 example).
std::vector<std::vector<int>> A = {{1, 2}, {3, 4}};
std::vector<std::vector<int>> B = {{5, 6}, {7, 8}};
int n = A.size();
if (n == 0 || A[0].size() != n || B.size() != n || B[0].size() != n) {
std::cout << "Invalid matrix dimensions for square matrices." << std::endl;
return 1;
}
// Step 2: Print original matrices.
std::cout << "Matrix A:" << std::endl;
printMatrix(A, n);
std::cout << "\nMatrix B:" << std::endl;
printMatrix(B, n);
// Step 3: Perform Strassen's multiplication using the wrapper to handle padding.
std::vector<std::vector<int>> C = strassen_wrapper(A, B, n);
// Step 4: Print the result matrix.
std::cout << "\nResult of Strassen's Multiplication (C = A * B):" << std::endl;
printMatrix(C, n);
// Example with 3x3 matrix (will be padded to 4x4)
std::cout << "\n--- 3x3 Matrix Example (padded to 4x4 for Strassen) ---" << std::endl;
std::vector<std::vector<int>> A_3x3 = {
{1, 2, 3},
{4, 5, 6},
{7, 8, 9}
};
std::vector<std::vector<int>> B_3x3 = {
{9, 8, 7},
{6, 5, 4},
{3, 2, 1}
};
int n_3x3 = A_3x3.size();
std::cout << "Matrix A (3x3):" << std::endl;
printMatrix(A_3x3, n_3x3);
std::cout << "\nMatrix B (3x3):" << std::endl;
printMatrix(B_3x3, n_3x3);
std::vector<std::vector<int>> C_3x3_strassen = strassen_wrapper(A_3x3, B_3x3, n_3x3);
std::cout << "\nResult of Strassen's Multiplication (3x3):" << std::endl;
printMatrix(C_3x3_strassen, n_3x3);
// Verify with naive for 3x3. multiplyNaive is the Approach 1 function and must be
// defined in this file; it works for any n, not only powers of 2.
std::vector<std::vector<int>> C_3x3_naive = multiplyNaive(A_3x3, B_3x3);
std::cout << "\nResult of Naive Multiplication (3x3 for comparison):" << std::endl;
printMatrix(C_3x3_naive, n_3x3);
return 0;
}
- Sample Output:
Matrix A:
1 2
3 4
Matrix B:
5 6
7 8
Result of Strassen's Multiplication (C = A * B):
19 22
43 50
--- 3x3 Matrix Example (padded to 4x4 for Strassen) ---
Matrix A (3x3):
1 2 3
4 5 6
7 8 9
Matrix B (3x3):
9 8 7
6 5 4
3 2 1
Result of Strassen's Multiplication (3x3):
30 24 18
84 69 54
138 114 90
Result of Naive Multiplication (3x3 for comparison):
30 24 18
84 69 54
138 114 90
- Stepwise Explanation:
- Helper Functions:
- printMatrix: Utility to display matrix contents.
- add, subtract: Perform element-wise matrix addition and subtraction.
- split: Divides a given matrix into four equally sized sub-matrices ($A_{11}, A_{12}, A_{21}, A_{22}$).
- join: Combines four sub-matrices back into a single larger matrix.
- getNextPowerOf2: Calculates the smallest power of 2 greater than or equal to n, used for padding.
- strassen_wrapper: Handles padding of matrices if their dimensions are not powers of 2. It pads with zeros, calls strassen_multiply, and then extracts the relevant part of the result.
- strassen_multiply (Recursive Function):
- Base Case: If the size of the matrix is 1 (a $1 \times 1$ matrix), it performs a direct multiplication and returns the result. This is the stopping condition for the recursion.
- Otherwise, A and B are split into four halfSize x halfSize sub-matrices each using the split helper function.
- The function then calls strassen_multiply seven times to compute the 7 products ($M_1$ through $M_7$). Each call involves matrix additions or subtractions of the sub-matrices, followed by a recursive multiplication.
- The seven products are combined with additions and subtractions to form the four sub-matrices of C.
- The join helper assembles these into the final C matrix, which is returned.
- main Function:
- Sets up two $2 \times 2$ example matrices A and B.
- Calls strassen_wrapper to perform the multiplication, ensuring correct padding if needed.
Conclusion
Strassen's algorithm provides a theoretically faster approach to matrix multiplication compared to the naive $O(n^3)$ algorithm, achieving $O(n^{\log_2 7})$ complexity. While its asymptotic advantage is significant for very large matrices, the overhead of matrix splitting, joining, and increased number of additions/subtractions can make it slower for smaller matrices due to constant factors. Furthermore, implementing Strassen's algorithm requires careful handling of matrix dimensions (often padding to powers of 2) and memory management for recursive calls. Despite these practical considerations, it remains an important algorithm demonstrating how clever algebraic manipulations can improve computational efficiency.
Summary
- Problem: Naive matrix multiplication has an $O(n^3)$ time complexity, which is inefficient for large matrices.
- Strassen's Solution: A divide-and-conquer algorithm that reduces the number of multiplications from 8 to 7 for $2 \times 2$ blocks, leading to an improved $O(n^{\log_2 7})$ time complexity.
- Key Idea: It cleverly rearranges matrix additions and subtractions to form intermediate products ($M_1$ to $M_7$) that are then combined to get the final result.
- Implementation: Requires recursive function calls, helper functions for matrix addition, subtraction, splitting, and joining.
- Padding: Matrices whose dimensions are not powers of 2 need to be padded with zeros to the next power of 2 to fit the recursive structure.
- Performance: Theoretically faster for large matrices, but practical overheads might make it slower than naive multiplication for small matrix sizes.