Write A Program For Strassen's Matrix Multiplication In C++
In this article, you will learn how to implement Strassen's matrix multiplication algorithm in C++. It has better asymptotic time complexity than the standard approach, which makes it attractive for large matrices.
Problem Statement
Matrix multiplication is a fundamental operation in linear algebra, with widespread applications across various scientific and engineering fields. The standard algorithm for multiplying two $N \times N$ matrices has a time complexity of O($N^3$). For very large matrices, this cubic complexity can lead to significant computational overhead, making the process prohibitively slow. The challenge is to find a more efficient algorithm to reduce the number of arithmetic operations required.
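The cubic cost follows directly from the definition of the product (standard analysis, assuming square $N \times N$ inputs). Each entry is a dot product of length $N$:

$$C_{ij} = \sum_{k=1}^{N} A_{ik} B_{kj}, \qquad 1 \le i, j \le N$$

$$\underbrace{N^2}_{\text{entries}} \times \underbrace{N}_{\text{mults per entry}} = N^3 \text{ multiplications}, \qquad N^2 (N - 1) \text{ additions}$$

For $N = 1000$, that already amounts to $10^9$ multiplications.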
Example
Consider two 2x2 matrices, A and B:
A = [[1, 2],
     [3, 4]]
B = [[5, 6],
     [7, 8]]
The product C = A * B using standard multiplication would be:
C = [[(1*5 + 2*7), (1*6 + 2*8)],
     [(3*5 + 4*7), (3*6 + 4*8)]]
C = [[(5 + 14), (6 + 16)],
     [(15 + 28), (18 + 32)]]
C = [[19, 22],
     [43, 50]]
This output serves as a target for our algorithms.
Background & Knowledge Prerequisites
To understand and implement Strassen's algorithm, familiarity with the following concepts is essential:
- C++ Basics: Understanding of data structures like std::vector, functions, loops, and basic input/output operations.
- Matrices: Knowledge of how matrices are represented and basic operations like addition and subtraction.
- Standard Matrix Multiplication: Understanding the O($N^3$) algorithm for multiplying two matrices.
- Divide and Conquer: This algorithmic paradigm is central to Strassen's method, involving breaking down a problem into smaller sub-problems, solving them, and combining their results.
- Recursion: Strassen's algorithm is naturally implemented recursively.
Use Cases or Case Studies
Matrix multiplication is a core operation in the fields below. Strassen's algorithm becomes a candidate only where the matrices involved are large, because on small inputs its overhead outweighs its savings:
- Image Processing: Operations like image filtering, transformations, and feature extraction often involve matrix manipulations.
- Computer Graphics: Rendering 3D graphics, transformations (scaling, rotation, translation), and projections rely heavily on matrix multiplication. These typically use small 4x4 matrices, however, and the standard method suits them better.
- Scientific Computing: Simulations in physics, chemistry, and engineering frequently involve solving systems of linear equations or performing matrix operations on large datasets.
- Machine Learning: Training and inference of neural networks, especially deep ones, involve large numbers of matrix multiplications.
- Graph Algorithms: Some graph algorithms can be expressed through matrix products and therefore benefit from faster matrix multiplication. Examples include transitive closure and all-pairs shortest paths in unweighted graphs.
Solution Approaches
Approach 1: Standard Matrix Multiplication
The standard algorithm is straightforward, involving three nested loops.
- One-line summary: Computes the product of two matrices by iterating through rows of the first, columns of the second, and summing intermediate products.
- Code example:
// Standard Matrix Multiplication
#include <iostream>
#include <vector>

// Function to print a matrix
void printMatrix(const std::vector<std::vector<int>>& matrix) {
    for (const auto& row : matrix) {
        for (int val : row) {
            std::cout << val << " ";
        }
        std::cout << std::endl;
    }
}

// Function for standard matrix multiplication
std::vector<std::vector<int>> standardMultiply(
    const std::vector<std::vector<int>>& A,
    const std::vector<std::vector<int>>& B) {
    // Guard against empty inputs before reading A[0] or B[0]
    if (A.empty() || B.empty()) {
        return {};
    }
    int rowsA = A.size();
    int colsA = A[0].size();
    int rowsB = B.size();
    int colsB = B[0].size();
    if (colsA != rowsB) {
        std::cerr << "Error: Matrix dimensions are not compatible for multiplication." << std::endl;
        return {}; // Return empty matrix
    }
    std::vector<std::vector<int>> C(rowsA, std::vector<int>(colsB, 0));
    // Step 1: Iterate through rows of matrix A
    for (int i = 0; i < rowsA; ++i) {
        // Step 2: Iterate through columns of matrix B
        for (int j = 0; j < colsB; ++j) {
            // Step 3: Iterate through columns of A (or rows of B)
            for (int k = 0; k < colsA; ++k) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return C;
}

int main() {
    std::vector<std::vector<int>> A = {{1, 2}, {3, 4}};
    std::vector<std::vector<int>> B = {{5, 6}, {7, 8}};
    std::cout << "Matrix A:" << std::endl;
    printMatrix(A);
    std::cout << "\nMatrix B:" << std::endl;
    printMatrix(B);
    std::vector<std::vector<int>> C = standardMultiply(A, B);
    std::cout << "\nResult of Standard Multiplication (C = A * B):" << std::endl;
    printMatrix(C);
    return 0;
}
- Sample output:
Matrix A:
1 2
3 4

Matrix B:
5 6
7 8

Result of Standard Multiplication (C = A * B):
19 22
43 50
- Stepwise explanation:
  - Initialize a result matrix C with dimensions rowsA x colsB, filled with zeros.
  - For each element C[i][j] in the result matrix:
    - It is calculated as the sum of products of corresponding elements from the i-th row of A and the j-th column of B.
    - This involves an inner loop that iterates colsA (or rowsB) times.
  - The three nested loops give the algorithm its O($N^3$) time complexity for $N \times N$ matrices.
Approach 2: Strassen's Matrix Multiplication
Strassen's algorithm reduces the number of multiplications required for a $2 \times 2$ product from 8 to 7. This leads to an overall complexity of O($N^{\log_2 7}$), which is approximately O($N^{2.807}$). It achieves this using a divide-and-conquer strategy.
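The exponent comes from a standard recurrence (textbook analysis, not part of the original program). Each split of an $n \times n$ problem costs 7 recursive products plus 18 additions/subtractions of $\tfrac{n}{2} \times \tfrac{n}{2}$ blocks: 10 to form the operands of the seven products and 8 to combine them into the four result blocks.

$$T(n) = 7\,T\!\left(\frac{n}{2}\right) + 18\left(\frac{n}{2}\right)^2 = 7\,T\!\left(\frac{n}{2}\right) + \Theta(n^2)$$

Because $\log_2 7 \approx 2.807 > 2$, the recursive term dominates (case 1 of the master theorem), so $T(n) = \Theta(n^{\log_2 7})$. Counting multiplications alone, with full recursion down to $1 \times 1$, makes this exact. The recurrence $M(n) = 7\,M(n/2)$ with $M(1) = 1$ gives $M(n) = 7^{\log_2 n} = n^{\log_2 7}$. For example, $M(1024) = 7^{10} = 282{,}475{,}249$, compared with $1024^3 = 1{,}073{,}741{,}824$ for the standard method.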
- One-line summary: A recursive divide-and-conquer algorithm that multiplies matrices by cleverly rearranging sub-matrix additions and subtractions to reduce the number of recursive multiplications from 8 to 7, achieving sub-cubic time complexity.
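- Detailed Explanation (7 Products):
Consider two 2x2 matrices (their entries may themselves be square sub-matrices):
A = [[a, b],    B = [[e, f],
     [c, d]]         [g, h]]
The product C = A * B is:
C = [[ae + bg, af + bh],
     [ce + dg, cf + dh]]
Computed directly, this takes 8 multiplications.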
Strassen's algorithm instead calculates 7 intermediate products:
P1 = a(f - h)
P2 = (a + b)h
P3 = (c + d)e
P4 = d(g - e)
P5 = (a + d)(e + h)
P6 = (b - d)(g + h)
P7 = (a - c)(e + f)
Then, the elements of the result matrix C are:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
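When applied recursively, the same formulas extend to larger matrices. Each of a through h becomes an n/2 x n/2 sub-matrix, and each of the seven products is computed by a recursive call. This is valid because the formulas never rely on commutativity, which matrix multiplication lacks: every product keeps a term from A on the left and a term from B on the right.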
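The sketch below applies the seven formulas once to scalar 2x2 matrices, so they can be checked by hand. The names Mat2 and strassen2x2 are hypothetical and are not part of the article's program. The code uses long long only to leave headroom for intermediate sums.

```cpp
#include <array>

// Hypothetical 2x2 matrix type used only for this check.
using Mat2 = std::array<std::array<long long, 2>, 2>;

// Multiplies two 2x2 matrices with Strassen's seven products (P1..P7)
// instead of the usual eight multiplications.
Mat2 strassen2x2(const Mat2& A, const Mat2& B) {
    const long long a = A[0][0], b = A[0][1], c = A[1][0], d = A[1][1];
    const long long e = B[0][0], f = B[0][1], g = B[1][0], h = B[1][1];

    // The seven products: exactly 7 scalar multiplications.
    const long long P1 = a * (f - h);
    const long long P2 = (a + b) * h;
    const long long P3 = (c + d) * e;
    const long long P4 = d * (g - e);
    const long long P5 = (a + d) * (e + h);
    const long long P6 = (b - d) * (g + h);
    const long long P7 = (a - c) * (e + f);

    // Recombine using additions and subtractions only.
    Mat2 C{};
    C[0][0] = P5 + P4 - P2 + P6;
    C[0][1] = P1 + P2;
    C[1][0] = P3 + P4;
    C[1][1] = P5 + P1 - P3 - P7;
    return C;
}
```

For the Example matrices (a through h equal to 1 through 8), the products are:

- P1 = -2, P2 = 24, P3 = 35, P4 = 8
- P5 = 65, P6 = -30, P7 = -22

These recombine to [[19, 22], [43, 50]], which matches the standard result.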
- Code example:
// Strassen's Matrix Multiplication
#include <iostream>
#include <vector>

// Function to print a matrix
void printMatrix(const std::vector<std::vector<int>>& matrix) {
    for (const auto& row : matrix) {
        for (int val : row) {
            std::cout << val << " ";
        }
        std::cout << std::endl;
    }
}

// Function to add two matrices
std::vector<std::vector<int>> addMatrices(
    const std::vector<std::vector<int>>& A,
    const std::vector<std::vector<int>>& B) {
    int n = A.size();
    std::vector<std::vector<int>> C(n, std::vector<int>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}

// Function to subtract two matrices
std::vector<std::vector<int>> subtractMatrices(
    const std::vector<std::vector<int>>& A,
    const std::vector<std::vector<int>>& B) {
    int n = A.size();
    std::vector<std::vector<int>> C(n, std::vector<int>(n));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}

// Function for standard matrix multiplication (base case for Strassen)
std::vector<std::vector<int>> baseCaseMultiply(
    const std::vector<std::vector<int>>& A,
    const std::vector<std::vector<int>>& B) {
    int n = A.size();
    std::vector<std::vector<int>> C(n, std::vector<int>(n, 0));
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            for (int k = 0; k < n; ++k) {
                C[i][j] += A[i][k] * B[k][j];
            }
        }
    }
    return C;
}

// Strassen's Matrix Multiplication
// threshold: at or below this size, switch to the standard algorithm
std::vector<std::vector<int>> strassenMultiply(
    const std::vector<std::vector<int>>& A,
    const std::vector<std::vector<int>>& B,
    int threshold = 16) {
    int n = A.size();
    // Base case: small matrices use standard multiplication (the threshold can
    // be tuned for performance). Odd sizes cannot be split into equal halves,
    // so they also fall back here instead of silently dropping a row/column.
    if (n <= threshold || n % 2 != 0) {
        return baseCaseMultiply(A, B);
    }
    // Step 1: Divide matrices into 4 sub-matrices
    int newSize = n / 2;
    std::vector<std::vector<int>> A11(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> A12(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> A21(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> A22(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> B11(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> B12(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> B21(newSize, std::vector<int>(newSize));
    std::vector<std::vector<int>> B22(newSize, std::vector<int>(newSize));
    for (int i = 0; i < newSize; ++i) {
        for (int j = 0; j < newSize; ++j) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + newSize];
            A21[i][j] = A[i + newSize][j];
            A22[i][j] = A[i + newSize][j + newSize];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + newSize];
            B21[i][j] = B[i + newSize][j];
            B22[i][j] = B[i + newSize][j + newSize];
        }
    }
    // Step 2: Compute the 7 products recursively (threshold is passed down)
    std::vector<std::vector<int>> P1 = strassenMultiply(A11, subtractMatrices(B12, B22), threshold);
    std::vector<std::vector<int>> P2 = strassenMultiply(addMatrices(A11, A12), B22, threshold);
    std::vector<std::vector<int>> P3 = strassenMultiply(addMatrices(A21, A22), B11, threshold);
    std::vector<std::vector<int>> P4 = strassenMultiply(A22, subtractMatrices(B21, B11), threshold);
    std::vector<std::vector<int>> P5 = strassenMultiply(addMatrices(A11, A22), addMatrices(B11, B22), threshold);
    std::vector<std::vector<int>> P6 = strassenMultiply(subtractMatrices(A12, A22), addMatrices(B21, B22), threshold);
    std::vector<std::vector<int>> P7 = strassenMultiply(subtractMatrices(A11, A21), addMatrices(B11, B12), threshold);
    // Step 3: Compute the 4 resulting sub-matrices
    std::vector<std::vector<int>> C11 = addMatrices(subtractMatrices(addMatrices(P5, P4), P2), P6);
    std::vector<std::vector<int>> C12 = addMatrices(P1, P2);
    std::vector<std::vector<int>> C21 = addMatrices(P3, P4);
    std::vector<std::vector<int>> C22 = subtractMatrices(subtractMatrices(addMatrices(P5, P1), P3), P7);
    // Step 4: Combine sub-matrices into the final result
    std::vector<std::vector<int>> C(n, std::vector<int>(n));
    for (int i = 0; i < newSize; ++i) {
        for (int j = 0; j < newSize; ++j) {
            C[i][j] = C11[i][j];
            C[i][j + newSize] = C12[i][j];
            C[i + newSize][j] = C21[i][j];
            C[i + newSize][j + newSize] = C22[i][j];
        }
    }
    return C;
}

int main() {
    // The recursion splits evenly only while n is even, so dimensions that are
    // powers of 2 get the full benefit. For other sizes, pad with zeros to the
    // next power of 2.
    std::vector<std::vector<int>> A = {{1, 2, 1, 2},
                                       {3, 4, 3, 4},
                                       {1, 2, 1, 2},
                                       {3, 4, 3, 4}};
    std::vector<std::vector<int>> B = {{5, 6, 5, 6},
                                       {7, 8, 7, 8},
                                       {5, 6, 5, 6},
                                       {7, 8, 7, 8}};
    std::cout << "Matrix A (4x4):" << std::endl;
    printMatrix(A);
    std::cout << "\nMatrix B (4x4):" << std::endl;
    printMatrix(B);
    // A threshold of 1 forces full recursion, so these small examples exercise
    // the seven products instead of going straight to the base case.
    std::vector<std::vector<int>> C = strassenMultiply(A, B, 1);
    std::cout << "\nResult of Strassen's Multiplication (C = A * B):" << std::endl;
    printMatrix(C);
    // Another example (2x2)
    std::vector<std::vector<int>> A_small = {{1, 2}, {3, 4}};
    std::vector<std::vector<int>> B_small = {{5, 6}, {7, 8}};
    std::cout << "\nMatrix A (2x2):" << std::endl;
    printMatrix(A_small);
    std::cout << "\nMatrix B (2x2):" << std::endl;
    printMatrix(B_small);
    std::vector<std::vector<int>> C_small = strassenMultiply(A_small, B_small, 1);
    std::cout << "\nResult of Strassen's Multiplication (C = A * B) for 2x2:" << std::endl;
    printMatrix(C_small);
    return 0;
}
- Sample output:
Matrix A (4x4):
1 2 1 2
3 4 3 4
1 2 1 2
3 4 3 4

Matrix B (4x4):
5 6 5 6
7 8 7 8
5 6 5 6
7 8 7 8

Result of Strassen's Multiplication (C = A * B):
38 44 38 44
86 100 86 100
38 44 38 44
86 100 86 100

Matrix A (2x2):
1 2
3 4

Matrix B (2x2):
5 6
7 8

Result of Strassen's Multiplication (C = A * B) for 2x2:
19 22
43 50
- Stepwise explanation:
  - Base Case: If the matrix size n is at or below a certain threshold (e.g., 16), use standard matrix multiplication instead. This is a common optimization. Strassen's algorithm has higher constant factors, which dominate for small matrices because of overheads like memory allocation and copying.
  - Divide: The input matrices A and B are divided into four n/2 x n/2 sub-matrices: A11, A12, A21, A22 and B11, B12, B21, B22. This implementation assumes n is a power of 2. For other sizes, pad the matrices with zeros to the next power of 2.
  - Conquer (7 Recursive Products): Seven products (P1 to P7) are computed recursively by calling strassenMultiply on combinations of these sub-matrices, as described in the detailed explanation above. Each product involves one matrix multiplication and one or two matrix additions/subtractions.
  - Combine: The four resulting sub-matrices (C11, C12, C21, C22) are calculated using additions and subtractions of the seven products.
  - Assemble: The final result matrix C is constructed by combining these four sub-matrices.
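The Divide step relies on zero padding when the size is not a power of 2. The sketch below shows one way to do this; it is one possible design, not the article's own code. The names nextPowerOfTwo, padTo, and multiplyPadded are hypothetical, and the wrapper accepts any square multiply routine as a callable.

Padding works because the added zeros contribute only zero terms to every dot product. As a result, the top-left rows x cols block of the padded product equals A * B.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Smallest power of two that is >= n (n >= 1).
std::size_t nextPowerOfTwo(std::size_t n) {
    std::size_t p = 1;
    while (p < n) {
        p <<= 1;
    }
    return p;
}

// Copy M into the top-left corner of a size x size matrix of zeros.
Matrix padTo(const Matrix& M, std::size_t size) {
    Matrix P(size, std::vector<int>(size, 0));
    for (std::size_t i = 0; i < M.size(); ++i) {
        for (std::size_t j = 0; j < M[i].size(); ++j) {
            P[i][j] = M[i][j];
        }
    }
    return P;
}

// Multiply an r x k matrix A by a non-empty k x c matrix B using `mul`, a
// routine that only handles square power-of-two sizes. Assumes A's column
// count equals B.size(); no dimension check is performed here.
template <typename SquareMultiply>
Matrix multiplyPadded(const Matrix& A, const Matrix& B, SquareMultiply mul) {
    const std::size_t rows = A.size();
    const std::size_t inner = B.size();
    const std::size_t cols = B[0].size();
    const std::size_t size = nextPowerOfTwo(std::max({rows, inner, cols}));
    Matrix C = mul(padTo(A, size), padTo(B, size));
    // Crop the zero padding: keep only the top-left rows x cols block.
    C.resize(rows);
    for (auto& row : C) {
        row.resize(cols);
    }
    return C;
}
```

With the program above, you could call it as multiplyPadded(A, B, [](const Matrix& X, const Matrix& Y) { return strassenMultiply(X, Y); }). The lambda keeps the call valid even if strassenMultiply has additional parameters with default values. Keep in mind that padding can nearly double each dimension (for example, 513 becomes 1024).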
Conclusion
Strassen's matrix multiplication algorithm provides a theoretical improvement over the standard O($N^3$) approach by achieving O($N^{\log_2 7}$) complexity. It introduces higher constant factors and memory overhead because of recursive calls and sub-matrix operations, but it becomes advantageous for sufficiently large matrices. The crossover point at which Strassen's algorithm outperforms standard multiplication depends on hardware specifics and implementation details, and it often requires matrix dimensions in the hundreds or thousands.
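The sketch below makes the constant-factor argument concrete by counting scalar operations under a simplified model. The model is an assumption, not a benchmark: every addition, subtraction, or multiplication costs 1, and memory allocation and copying are ignored. The names standardOps and strassenOps are hypothetical, and cutoff plays the role of the base-case threshold.

```cpp
#include <cstdint>

// Standard algorithm on an n x n input:
// n^3 multiplications plus n^2 (n - 1) additions.
std::uint64_t standardOps(std::uint64_t n) {
    return n * n * n + n * n * (n - 1);
}

// Hybrid Strassen cost: sizes <= cutoff use the standard algorithm.
// Larger sizes pay for 7 half-size products plus 18 half-size
// additions/subtractions (10 to form operands, 8 to combine results).
// Assumes n is a power of two and cutoff >= 1.
std::uint64_t strassenOps(std::uint64_t n, std::uint64_t cutoff) {
    if (n <= cutoff) {
        return standardOps(n);
    }
    const std::uint64_t half = n / 2;
    return 7 * strassenOps(half, cutoff) + 18 * half * half;
}
```

Under this model, full recursion (cutoff = 1) behaves as follows:

- A 2x2 product costs 25 operations, compared with 12 for the standard method.
- Among powers of two, full recursion first comes out ahead at n = 1024.

A cutoff of 8 already wins at n = 16 (7,872 versus 7,936 operations), which illustrates why a hybrid scheme is preferred. Real crossover points also depend on the memory costs this model ignores.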
Summary
- Standard matrix multiplication has O($N^3$) time complexity.
- Strassen's algorithm is a divide-and-conquer approach.
- It reduces multiplications for $2 \times 2$ matrices from 8 to 7.
- This leads to an overall time complexity of O($N^{\log_2 7}$) $\approx$ O($N^{2.807}$).
- Key steps: Divide matrices into sub-matrices, compute 7 intermediate products recursively, and combine results.
- Practical considerations: Higher constant factors mean it's only faster for large matrices; a hybrid approach (using standard multiplication for small sub-problems) is often optimal.
- Prerequisites: Understanding of C++, matrices, recursion, and divide and conquer.