add
This commit is contained in:
@@ -0,0 +1,150 @@
|
||||
#include <stdio.h>
|
||||
#include <stdbool.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#define SIZE 4 // 矩阵的大小
|
||||
|
||||
// 检查一个数是否是2的幂次方
|
||||
bool isPowerOfTwo(int num) {
|
||||
return (num & (num - 1)) == 0;
|
||||
}
|
||||
|
||||
// 矩阵加法
|
||||
void add(int* a, int* b, int* c, int length) {
|
||||
for (int i = 0; i < length * length; i++) {
|
||||
c[i] = a[i] + b[i];
|
||||
}
|
||||
}
|
||||
|
||||
// 获取结果矩阵
|
||||
void getResult(int* a, int* b, int* result) {
|
||||
int p1 = a[0] * (b[1] - b[3]);
|
||||
int p2 = (a[0] + a[1]) * b[3];
|
||||
int p3 = (a[2] + a[3]) * b[0];
|
||||
int p4 = a[3] * (b[2] - b[0]);
|
||||
int p5 = (a[0] + a[3]) * (b[0] + b[3]);
|
||||
int p6 = (a[1] - a[3]) * (b[2] + b[3]);
|
||||
int p7 = (a[0] - a[2]) * (b[0] + b[1]);
|
||||
|
||||
result[0] = p5 + p4 - p2 + p6;
|
||||
result[1] = p1 + p2;
|
||||
result[2] = p3 + p4;
|
||||
result[3] = p5 + p1 - p3 - p7;
|
||||
}
|
||||
|
||||
// Strassen 矩阵乘法
|
||||
void sMM(int* a, int* b, int* result, int length) {
|
||||
if (length == 2) {
|
||||
getResult(a, b, result);
|
||||
} else {
|
||||
int tlength = length / 2;
|
||||
int* aa = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* ab = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* ac = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* ad = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* ba = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* bb = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* bc = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* bd = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t1 = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t2 = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t3 = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t4 = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* temp = (int*)malloc(length * length * sizeof(int));
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
for (int j = 0; j < length; j++) {
|
||||
if (i < tlength) {
|
||||
if (j < tlength) {
|
||||
aa[i * tlength + j] = a[i * length + j];
|
||||
ba[i * tlength + j] = b[i * length + j];
|
||||
} else {
|
||||
ab[i * tlength + (j - tlength)] = a[i * length + j];
|
||||
bb[i * tlength + (j - tlength)] = b[i * length + j];
|
||||
}
|
||||
} else {
|
||||
if (j < tlength) {
|
||||
ac[(i - tlength) * tlength + j] = a[i * length + j];
|
||||
bc[(i - tlength) * tlength + j] = b[i * length + j];
|
||||
} else {
|
||||
ad[(i - tlength) * tlength + (j - tlength)] = a[i * length + j];
|
||||
bd[(i - tlength) * tlength + (j - tlength)] = b[i * length + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int* t1_ = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t2_ = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t3_ = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
int* t4_ = (int*)malloc(tlength * tlength * sizeof(int));
|
||||
|
||||
sMM(aa, ba, t1_, tlength);
|
||||
sMM(ab, bc, t2_, tlength);
|
||||
add(t1_, t2_, t1_, tlength);
|
||||
|
||||
sMM(aa, bb, t2_, tlength);
|
||||
sMM(ab, bd, t3_, tlength);
|
||||
add(t2_, t3_, t2_, tlength);
|
||||
|
||||
sMM(ac, ba, t3_, tlength);
|
||||
sMM(ad, bc, t4_, tlength);
|
||||
add(t3_, t4_, t3_, tlength);
|
||||
|
||||
sMM(ac, bb, t4_, tlength);
|
||||
sMM(ad, bd, t1_, tlength);
|
||||
add(t4_, t1_, t4_, tlength);
|
||||
|
||||
for (int i = 0; i < length; i++) {
|
||||
for (int j = 0; j < length; j++) {
|
||||
if (i < tlength) {
|
||||
if (j < tlength)
|
||||
result[i * length + j] = t1[i * tlength + j];
|
||||
else
|
||||
result[i * length + j] = t2[i * tlength + (j - tlength)];
|
||||
} else {
|
||||
if (j < tlength)
|
||||
result[i * length + j] = t3[(i - tlength) * tlength + j];
|
||||
else
|
||||
result[i * length + j] = t4[(i - tlength) * tlength + (j - tlength)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
free(aa);
|
||||
free(ab);
|
||||
free(ac);
|
||||
free(ad);
|
||||
free(ba);
|
||||
free(bb);
|
||||
free(bc);
|
||||
free(bd);
|
||||
free(t1_);
|
||||
free(t2_);
|
||||
free(t3_);
|
||||
free(t4_);
|
||||
free(temp);
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
int a[SIZE * SIZE] = { 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4 };
|
||||
int b[SIZE * SIZE] = { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 };
|
||||
int c[SIZE * SIZE];
|
||||
|
||||
sMM(a, b, c, SIZE);
|
||||
|
||||
for (int i = 0; i < SIZE * SIZE; i++) {
|
||||
printf("%d ", c[i]);
|
||||
if ((i + 1) % SIZE == 0) // 换行
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
printf("%d\n", isPowerOfTwo(1));
|
||||
printf("%d\n", isPowerOfTwo(2));
|
||||
printf("%d\n", isPowerOfTwo(4));
|
||||
printf("%d\n", isPowerOfTwo(6));
|
||||
printf("%d\n", isPowerOfTwo(443));
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user