Strassen's Matrix Multiplication Program In C++
Matrix multiplication is a fundamental operation in various computational fields, from computer graphics to machine learning. While the standard algorithm is straightforward, its cubic time complexity becomes a bottleneck for large matrices. In this article, you will learn about Strassen's algorithm, an efficient divide-and-conquer approach that significantly reduces the computational cost of matrix multiplication.
Problem Statement
Multiplying two $N \times N$ matrices, say A and B, to produce a resulting matrix C, typically involves calculating each element $C_{ij}$ as the sum of products of elements from row $i$ of A and column $j$ of B. This process requires $N$ multiplications and $N-1$ additions for each of the $N^2$ elements in the result matrix. Consequently, the standard algorithm has a time complexity of $O(N^3)$, which can be computationally intensive and slow for very large matrices encountered in scientific simulations or big data analysis.
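In symbols, each entry of the product is given by $C_{ij} = \sum_{k=1}^{N} A_{ik} B_{kj}$ for $1 \le i, j \le N$.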
Example
Consider multiplying two $2 \times 2$ matrices using the standard algorithm:
Given matrices: A = $\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}$ B = $\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}$
The result C is calculated as: $C_{11} = (1 \times 5) + (2 \times 7) = 5 + 14 = 19$ $C_{12} = (1 \times 6) + (2 \times 8) = 6 + 16 = 22$ $C_{21} = (3 \times 5) + (4 \times 7) = 15 + 28 = 43$ $C_{22} = (3 \times 6) + (4 \times 8) = 18 + 32 = 50$
The resulting matrix C is: C = $\begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}$
This standard approach performs 8 multiplications and 4 additions for a $2 \times 2$ matrix product.
Background & Knowledge Prerequisites
To understand Strassen's algorithm, readers should be familiar with:
- Basic Matrix Operations: Understanding of matrix addition, subtraction, and the standard matrix multiplication definition.
- C++ Fundamentals: Variables, loops, arrays (or std::vector for dynamic arrays), and functions.
- Recursion: The concept of a function calling itself, including base cases and recursive steps.
- Divide and Conquer Paradigm: Breaking down a problem into smaller subproblems, solving them independently, and combining their results.
Use Cases or Case Studies
Matrix multiplication is a cornerstone in numerous domains:
- Computer Graphics: Essential for transformations like translation, rotation, and scaling of objects in 2D and 3D space.
- Scientific Computing and Simulations: Used in solving systems of linear equations, finite element analysis, and simulating complex physical phenomena.
- Machine Learning and Deep Learning: Forms the core of neural network computations, especially in layers like fully connected layers and convolutional layers, for tasks such as image recognition and natural language processing.
- Image Processing: Applied in filters, image transformations, and various algorithms for image manipulation and analysis.
- Data Analysis and Statistics: Used in statistical models, covariance matrices, and principal component analysis (PCA).
Solution Approaches
Approach 1: Naive Matrix Multiplication
This is the standard, most intuitive method, involving three nested loops.
- One-line summary: Calculates each element of the result matrix by summing the products of corresponding row and column elements.
- Complexity: $O(N^3)$ for $N \times N$ matrices.
While simple to implement, its cubic complexity makes it inefficient for large matrices.
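For reference, here is a minimal sketch of this triple-loop approach, using the same vector-of-vectors matrix representation as the Strassen implementation later in the article (the function name naiveMultiply is introduced here only for illustration):
// Naive O(N^3) matrix multiplication (illustrative sketch)
#include <vector>
using namespace std;
// Computes C = A * B for square n x n matrices with three nested loops.
vector<vector<int>> naiveMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    vector<vector<int>> C(n, vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int p = 0; p < n; ++p) {
                C[i][j] += A[i][p] * B[p][j];
            }
        }
    }
    return C;
}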
Approach 2: Strassen's Matrix Multiplication
Strassen's algorithm is a divide-and-conquer algorithm that improves upon the naive method by reducing the number of recursive multiplications from 8 to 7.
- One-line summary: A recursive algorithm that divides matrices into sub-matrices, performs 7 sub-matrix multiplications, and then combines results using additions/subtractions.
- Complexity: $O(N^{\log_2 7})$, which is approximately $O(N^{2.807})$. This is asymptotically faster than the $O(N^3)$ naive algorithm.
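This bound comes from the recurrence that the recursion satisfies: each call performs 7 recursive multiplications on half-size matrices plus a constant number of $N/2 \times N/2$ additions and subtractions, so $T(N) = 7\,T(N/2) + \Theta(N^2)$, which the Master Theorem resolves to $T(N) = \Theta(N^{\log_2 7}) \approx \Theta(N^{2.807})$.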
Algorithm Overview:
- Padding (if necessary): If the matrix dimensions are not a power of 2, pad the matrices with zeros so that each dimension becomes the next power of 2.
- Divide: Split each of the two $N \times N$ matrices A and B into four $N/2 \times N/2$ sub-matrices: $A_{11}, A_{12}, A_{21}, A_{22}$ and $B_{11}, B_{12}, B_{21}, B_{22}$.
- Conquer (Calculate 7 products): Compute seven $N/2 \times N/2$ matrices (P1 to P7) using recursive calls to Strassen's algorithm and matrix additions/subtractions:
- $P_1 = (A_{11} + A_{22}) \times (B_{11} + B_{22})$
- $P_2 = (A_{21} + A_{22}) \times B_{11}$
- $P_3 = A_{11} \times (B_{12} - B_{22})$
- $P_4 = A_{22} \times (B_{21} - B_{11})$
- $P_5 = (A_{11} + A_{12}) \times B_{22}$
- $P_6 = (A_{21} - A_{11}) \times (B_{11} + B_{12})$
- $P_7 = (A_{12} - A_{22}) \times (B_{21} + B_{22})$
- Combine (Calculate result sub-matrices): The four sub-matrices of the result C are:
- $C_{11} = P_1 + P_4 - P_5 + P_7$
- $C_{12} = P_3 + P_5$
- $C_{21} = P_2 + P_4$
- $C_{22} = P_1 - P_2 + P_3 + P_6$
- Base Case: If the matrix size is $1 \times 1$ (or a small constant size), perform standard scalar multiplication.
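As a quick check, applying these formulas to the $2 \times 2$ example above, where each sub-matrix is a single scalar: $P_1 = (1+4)(5+8) = 65$, $P_2 = (3+4) \times 5 = 35$, $P_3 = 1 \times (6-8) = -2$, $P_4 = 4 \times (7-5) = 8$, $P_5 = (1+2) \times 8 = 24$, $P_6 = (3-1)(5+6) = 22$, and $P_7 = (2-4)(7+8) = -30$. Then $C_{11} = 65 + 8 - 24 - 30 = 19$, $C_{12} = -2 + 24 = 22$, $C_{21} = 35 + 8 = 43$, and $C_{22} = 65 - 35 - 2 + 22 = 50$, matching the result computed earlier while using only 7 scalar multiplications instead of 8.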
Code Example:
This C++ implementation includes helper functions for matrix operations (addition, subtraction, printing) and handles padding for matrices that are not powers of 2.
// Strassen's Matrix Multiplication
#include <iostream>
#include <vector>
#include <cmath> // For std::pow, std::ceil, std::log2
using namespace std;
// Function to print a matrix
void printMatrix(const vector<vector<int>>& matrix) {
    for (const auto& row : matrix) {
        for (int val : row) {
            cout << val << "\t";
        }
        cout << endl;
    }
}
// Function to add two matrices
vector<vector<int>> addMatrices(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    vector<vector<int>> C(n, vector<int>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}
// Function to subtract two matrices
vector<vector<int>> subtractMatrices(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    vector<vector<int>> C(n, vector<int>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}
// Strassen's matrix multiplication function
vector<vector<int>> strassenMultiply(const vector<vector<int>>& A, const vector<vector<int>>& B) {
    int n = A.size();
    // Base case: If matrix is 1x1, perform scalar multiplication
    if (n == 1) {
        return {{A[0][0] * B[0][0]}};
    }
    // Divide matrices into 4 sub-matrices
    int k = n / 2;
    vector<vector<int>> A11(k, vector<int>(k));
    vector<vector<int>> A12(k, vector<int>(k));
    vector<vector<int>> A21(k, vector<int>(k));
    vector<vector<int>> A22(k, vector<int>(k));
    vector<vector<int>> B11(k, vector<int>(k));
    vector<vector<int>> B12(k, vector<int>(k));
    vector<vector<int>> B21(k, vector<int>(k));
    vector<vector<int>> B22(k, vector<int>(k));
    for (int i = 0; i < k; ++i) {
        for (int j = 0; j < k; ++j) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + k];
            A21[i][j] = A[i + k][j];
            A22[i][j] = A[i + k][j + k];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + k];
            B21[i][j] = B[i + k][j];
            B22[i][j] = B[i + k][j + k];
        }
    }
    // Calculate 7 products (P1 to P7) recursively
    vector<vector<int>> P1 = strassenMultiply(addMatrices(A11, A22), addMatrices(B11, B22));
    vector<vector<int>> P2 = strassenMultiply(addMatrices(A21, A22), B11);
    vector<vector<int>> P3 = strassenMultiply(A11, subtractMatrices(B12, B22));
    vector<vector<int>> P4 = strassenMultiply(A22, subtractMatrices(B21, B11));
    vector<vector<int>> P5 = strassenMultiply(addMatrices(A11, A12), B22);
    vector<vector<int>> P6 = strassenMultiply(subtractMatrices(A21, A11), addMatrices(B11, B12));
    vector<vector<int>> P7 = strassenMultiply(subtractMatrices(A12, A22), addMatrices(B21, B22));
    // Combine products to form the result sub-matrices
    vector<vector<int>> C11 = addMatrices(subtractMatrices(addMatrices(P1, P4), P5), P7);
    vector<vector<int>> C12 = addMatrices(P3, P5);
    vector<vector<int>> C21 = addMatrices(P2, P4);
    vector<vector<int>> C22 = addMatrices(subtractMatrices(addMatrices(P1, P3), P2), P6);
    // Assemble the final result matrix C
    vector<vector<int>> C(n, vector<int>(n));
    for (int i = 0; i < k; ++i) {
        for (int j = 0; j < k; ++j) {
            C[i][j] = C11[i][j];
            C[i][j + k] = C12[i][j];
            C[i + k][j] = C21[i][j];
            C[i + k][j + k] = C22[i][j];
        }
    }
    return C;
}
// Function to get the next power of 2
int getNextPowerOf2(int n) {
    if (n == 0) return 1;
    return static_cast<int>(pow(2, ceil(log2(n))));
}
int main() {
    // Step 1: Define input matrices
    vector<vector<int>> A_orig = {{1, 2}, {3, 4}};
    vector<vector<int>> B_orig = {{5, 6}, {7, 8}};
    // vector<vector<int>> A_orig = {{1, 1, 1, 1},
    //                               {2, 2, 2, 2},
    //                               {3, 3, 3, 3},
    //                               {4, 4, 4, 4}};
    // vector<vector<int>> B_orig = {{1, 1, 1, 1},
    //                               {2, 2, 2, 2},
    //                               {3, 3, 3, 3},
    //                               {4, 4, 4, 4}};
    cout << "Matrix A:" << endl;
    printMatrix(A_orig);
    cout << "\nMatrix B:" << endl;
    printMatrix(B_orig);
    // Step 2: Determine actual size and padded size
    int n_orig = A_orig.size();
    int padded_size = getNextPowerOf2(n_orig);
    // Step 3: Pad matrices with zeros if necessary
    vector<vector<int>> A_padded(padded_size, vector<int>(padded_size, 0));
    vector<vector<int>> B_padded(padded_size, vector<int>(padded_size, 0));
    for (int i = 0; i < n_orig; ++i) {
        for (int j = 0; j < n_orig; ++j) {
            A_padded[i][j] = A_orig[i][j];
            B_padded[i][j] = B_orig[i][j];
        }
    }
    // Step 4: Perform Strassen's multiplication on padded matrices
    vector<vector<int>> C_padded = strassenMultiply(A_padded, B_padded);
    // Step 5: Extract the original sized result from the padded result
    vector<vector<int>> C_result(n_orig, vector<int>(n_orig));
    for (int i = 0; i < n_orig; ++i) {
        for (int j = 0; j < n_orig; ++j) {
            C_result[i][j] = C_padded[i][j];
        }
    }
    cout << "\nResult C (using Strassen's algorithm):" << endl;
    printMatrix(C_result);
    return 0;
}
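The listing above recurses all the way down to $1 \times 1$ matrices, which keeps the code simple but adds considerable overhead. A common refinement, not shown in the listing, is to stop the recursion at a small cutoff and fall back to the naive triple loop below it. A minimal sketch of how the base case in strassenMultiply could be adapted is shown here; the cutoff value of 64 is illustrative only and should be tuned for the target machine:
// Hypothetical cutoff-based base case for strassenMultiply
const int CUTOFF = 64; // illustrative value, not part of the original listing
if (n <= CUTOFF) {
    // Fall back to the naive triple loop for small sub-matrices
    vector<vector<int>> C(n, vector<int>(n, 0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int p = 0; p < n; ++p)
                C[i][j] += A[i][p] * B[p][j];
    return C;
}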
Sample Output:
For the input matrices A = {{1, 2}, {3, 4}} and B = {{5, 6}, {7, 8}}:
Matrix A:
1 2
3 4
Matrix B:
5 6
7 8
Result C (using Strassen's algorithm):
19 22
43 50
Stepwise Explanation:
- Helper Functions: printMatrix, addMatrices, and subtractMatrices are defined to facilitate matrix operations and visualization. getNextPowerOf2 ensures that the matrix dimensions are powers of 2, which is required for the recursive halving in Strassen's algorithm; if they are not, the matrices are padded with zeros.
- Base Case: The strassenMultiply function checks whether the input matrix size n is 1. If so, it performs a direct scalar multiplication and returns the $1 \times 1$ result. This prevents infinite recursion.
- Matrix Division: For n > 1, the function divides each $N \times N$ matrix (A and B) into four $N/2 \times N/2$ sub-matrices (A11, A12, A21, A22 and B11, B12, B21, B22).
- Recursive Products (P1-P7): Seven products (P1 through P7) are calculated. Each product is formed by at most two matrix additions/subtractions on the sub-matrices, followed by one recursive call to strassenMultiply. This is where the core optimization lies, reducing 8 multiplications to 7.
- Combining Results (C11-C22): The four sub-matrices of the result C (C11, C12, C21, C22) are formed by combining the P matrices using a series of matrix additions and subtractions.
- Assembling C: The calculated sub-matrices C11, C12, C21, and C22 are then copied into the corresponding quadrants of the complete $N \times N$ result matrix C.
- Main Function: The main function demonstrates how to use strassenMultiply. It takes the original matrices, pads them to the next power of 2, calls Strassen's algorithm, and then extracts the relevant portion of the result to match the original dimensions.
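For example, if the inputs were $3 \times 3$ matrices, getNextPowerOf2(3) would return 4, both inputs would be embedded in the top-left corner of $4 \times 4$ zero matrices, strassenMultiply would run on the $4 \times 4$ operands, and the top-left $3 \times 3$ block of C_padded would be copied into C_result. Padding with zeros is safe because the extra rows and columns contribute nothing to the products in the original block.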
Conclusion
Strassen's matrix multiplication algorithm offers a significant asymptotic improvement over the naive $O(N^3)$ method by reducing the number of recursive multiplications. While its practical benefits typically appear for very large matrices due to the overhead of recursive calls and matrix manipulations, it fundamentally altered the understanding of computational limits for matrix multiplication. It stands as a classic example of the power of the divide-and-conquer paradigm in algorithm design.
Summary
- Problem: Standard matrix multiplication has $O(N^3)$ complexity, slow for large matrices.
- Strassen's Algorithm: A divide-and-conquer approach.
- Key Idea: Reduces 8 sub-matrix multiplications to 7, using more additions/subtractions.
- Time Complexity: $O(N^{\log_2 7}) \approx O(N^{2.807})$, asymptotically faster than $O(N^3)$.
- Implementation: Involves recursive calls, matrix partitioning, and helper functions for arithmetic.
- Practicality: Most beneficial for very large matrices; overhead can make it slower for small matrices.