稀疏矩阵是绝大多数元素为 0、只有少数非零元素的矩阵。现在我需要完善我的 `SparseMatrix` 类,使矩阵能够进行加法(`add`)、减法(`subtract`)和乘法(`multiply`)运算。
我使用 COO(坐标列表,Coordinate List)格式来存储矩阵。
// Growable array list with positional insert/delete, used as the backing store for COO triplets.
// Positions are 0-based; invalid positions raise the C-string "Illegal position".
template <class T>
class VecList{
private:
    int capacity;   // number of slots allocated in arr
    int length;     // number of slots currently holding elements
    T* arr;         // heap buffer owned by this object (released in the destructor)

    // Grow the buffer to twice its capacity, keeping the first `length` elements.
    // The new buffer is fully populated before the old one is released, so arr never dangles.
    void doubleListSize(){
        int newCapacity = capacity > 0 ? 2 * capacity : 1;
        T* grown = new T[newCapacity];
        for(int i=0;i<length;i++){
            grown[i] = arr[i];
        }
        delete [] arr;
        arr = grown;
        capacity = newCapacity;
    }
public:
    // Empty list with room for 100 elements before the first reallocation.
    VecList() : capacity(100), length(0), arr(new T[100]) {}

    // List initialised with a copy of a[0..n-1], reserving 100 + 2n slots.
    // Throws "Illegal length" if n is negative. Produces no console output.
    VecList(T* a, int n) : capacity(0), length(0), arr(nullptr) {
        if(n < 0)
            throw "Illegal length";
        int cap = 100 + 2*n;
        T* buf = new T[cap];
        for(int i=0;i<n;i++){
            buf[i] = a[i];
        }
        arr = buf;
        capacity = cap;
        length = n;
    }

    // Deep copy: the new list owns its own buffer, so destroying either copy is safe.
    VecList(const VecList& other)
        : capacity(other.capacity), length(other.length), arr(new T[other.capacity]) {
        for(int i=0;i<length;i++){
            arr[i] = other.arr[i];
        }
    }

    // Deep-copy assignment; the copy is built before the old buffer is freed, and
    // self-assignment is a no-op.
    VecList& operator=(const VecList& other){
        if(this != &other){
            T* fresh = new T[other.capacity];
            for(int i=0;i<other.length;i++){
                fresh[i] = other.arr[i];
            }
            delete [] arr;
            arr = fresh;
            capacity = other.capacity;
            length = other.length;
        }
        return *this;
    }

    ~VecList(){
        delete [] arr;
    }

    // Number of stored elements.
    int getLength() const {
        return length;
    }

    bool isEmpty() const {
        return length==0;
    }

    // Insert x before position i (0 <= i <= length); i == length appends in amortised O(1).
    // Elements at positions >= i shift one slot to the right. Throws "Illegal position".
    void insertEleAtPos(int i, T x){
        if(i > length || i < 0)
            throw "Illegal position";
        if(length==capacity)
            doubleListSize();
        for(int j=length;j>i;j--)
            arr[j] = arr[j-1];
        arr[i] = x;
        length++;
    }

    // Remove and return the element at position i; later elements shift left.
    // Throws "Illegal position" if i is out of range.
    T deleteEleAtPos(int i){
        if(i >= length || i < 0)
            throw "Illegal position";
        T tmp = arr[i];
        for(int j=i;j<length-1;j++)
            arr[j] = arr[j+1];
        length--;
        return tmp;
    }

    // Overwrite the element at position i. Throws "Illegal position" if i is out of range.
    void setEleAtPos(int i, T x){
        if(i >= length || i < 0)
            throw "Illegal position";
        arr[i] = x;
    }

    // Return the element at position i. Throws "Illegal position" if i is out of range.
    T getEleAtPos(int i) const {
        if(i >= length || i < 0)
            throw "Illegal position";
        return arr[i];
    }

    // Index of the first element equal to x, or -1 if there is none (linear scan).
    int locateEle(T x) const {
        for(int i=0;i<length;i++){
            if(arr[i]==x)
                return i;
        }
        return -1;
    }

    // Print every element followed by a space (no trailing newline).
    void printList() const {
        for(int i=0;i<length;i++)
            std::cout << arr[i] << " ";
    }
};
COO 格式使用三个 `VecList` 来存储矩阵,三个列表中相同下标 k 的位置共同描述第 k 个非零元素:
`rowIndex`:该非零元素所在的行号;
`colIndex`:该非零元素所在的列号;
`values`:该非零元素的值。
以下是我的 `SparseMatrix` 类:
template <class T>
class SparseMatrix{
private:
int rows;
int cols;
VecList<int>* rowIndex;
VecList<int>* colIndex;
VecList<T>* values;
public:
SparseMatrix(){ //Create a 10x10 Sparse matrix
rows = 10;
cols = 10;
rowIndex = new VecList<int>();
colIndex = new VecList<int>();
values = new VecList<T>();
}
SparseMatrix(int r, int c){ //Create a rxc Sparse matrix
rows = r;
cols = c;
rowIndex = new VecList<int>();
colIndex = new VecList<int>();
values = new VecList<T>();
}
~SparseMatrix(){
delete rowIndex;
delete colIndex;
delete values;
}
};
如您所见,我只存储非零元素。例如,对于下面的稀疏矩阵:
0 2 0
0 0 1
3 0 0
rows = Exact number of rows, 3
cols = Exact number of columns, 3
rowIndex = 0 1 2
colIndex = 1 2 0
values = 2 1 3
把三个列表按下标对齐来看,每一列(同一下标)描述一个非零元素:
第一列表示 row[0]col[1] 处有一个非零元素 2;
第二列表示 row[1]col[2] 处有一个非零元素 1;
第三列表示 row[2]col[0] 处有一个非零元素 3。
下面是我为实现矩阵运算而编写的几个函数(它们都是 `SparseMatrix` 类的成员函数)。
// If a non-zero element is stored at (a, b), return its position in rowIndex/colIndex/values;
// otherwise return -1. Each iteration probes one position from the front and its mirror from
// the back; once the two indices cross every position has been examined, so the scan stops at
// the midpoint instead of re-checking the second half as the previous loop bound did.
// Still O(n) in the number of stored non-zeros; keeping the triplets sorted by (row, col)
// would allow binary search instead.
int findPos(int a, int b){
    for (int lo = 0, hi = rowIndex->getLength() - 1; lo <= hi; lo++, hi--)
    {
        if (rowIndex->getEleAtPos(lo) == a && colIndex->getEleAtPos(lo) == b) return lo;
        if (rowIndex->getEleAtPos(hi) == a && colIndex->getEleAtPos(hi) == b) return hi;
    }
    return -1;
}
void setEntry(int rPos, int cPos, T x){ // Set (rPos, cPos) = x
int pos = findPos(rPos,cPos);
//Find if there is a non-zero element at (rPos, cPos).
if(x != 0){
//If the origin matrix does not have an element at(rPos, cPos),insert x to the matrix.
if (pos == -1)
{
rowIndex->insertEleAtPos(rowIndex->getLength(),rPos);
colIndex->insertEleAtPos(colIndex->getLength(),cPos);
values->insertEleAtPos(values->getLength(),x);
}
else{
//If the origin matrix has an element at(rPos, cPos),replace it with x.
rowIndex->setEleAtPos(pos,rPos);
colIndex->setEleAtPos(pos,cPos);
values->setEleAtPos(pos,x);
}
}
else{
//If x == 0 and the origin matrix has an element at(rPos, cPos), delete the element.
if(pos != -1){
rowIndex->deleteEleAtPos(pos);
colIndex->deleteEleAtPos(pos);
values->deleteEleAtPos(pos);
}
}
//If x == 0, and the origin matrix does not have an element at(rPos, cPos), nothing changed.
}
// Return the element at (rPos, cPos); coordinates with no stored triplet read as zero.
// The O(n) lookup is performed once (the previous version called findPos twice).
T getEntry(int rPos, int cPos){
    int pos = findPos(rPos, cPos);
    if (pos == -1)
        return T(0);
    return values->getEleAtPos(pos);
}
// Return a newly allocated matrix C = A + B, where A is *this. The caller owns C and must
// delete it. Throws "Matrices have incompatible sizes" when the dimensions differ.
SparseMatrix<T> * add(SparseMatrix<T> * B){
    if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);
    try{
        // A's triplets are unique and non-zero (setEntry maintains that), so they can be
        // copied verbatim instead of going through the O(n) getEntry/setEntry pair.
        int nA = values->getLength();
        for (int i = 0; i < nA; i++)
        {
            C->rowIndex->insertEleAtPos(i, rowIndex->getEleAtPos(i));
            C->colIndex->insertEleAtPos(i, colIndex->getEleAtPos(i));
            C->values->insertEleAtPos(i, values->getEleAtPos(i));
        }
        // B gets its own loop bounded by B's length. The previous single loop used A's
        // length for B as well, dropping B's extra entries or throwing "Illegal position"
        // when B had fewer entries than A.
        int nB = B->values->getLength();
        for (int i = 0; i < nB; i++)
        {
            int r = B->rowIndex->getEleAtPos(i);
            int c = B->colIndex->getEleAtPos(i);
            // setEntry removes the triplet if the sum cancels to zero.
            C->setEntry(r, c, C->getEntry(r, c) + B->values->getEleAtPos(i));
        }
    }
    catch (...){
        delete C;   // do not leak the partial result
        throw;
    }
    return C;
}
// Return a newly allocated matrix C = A - B, where A is *this. The caller owns C and must
// delete it. Throws "Matrices have incompatible sizes" when the dimensions differ.
SparseMatrix<T> * subtract(SparseMatrix<T> * B){
    if(rows != B->rows || cols != B->cols)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,cols);
    try{
        // A contributes with a plus sign; the previous version subtracted A's values too,
        // computing -A - B. A's triplets are unique and non-zero, so copy them verbatim.
        int nA = values->getLength();
        for (int i = 0; i < nA; i++)
        {
            C->rowIndex->insertEleAtPos(i, rowIndex->getEleAtPos(i));
            C->colIndex->insertEleAtPos(i, colIndex->getEleAtPos(i));
            C->values->insertEleAtPos(i, values->getEleAtPos(i));
        }
        // B is walked over its own length (the old shared loop used A's length for both).
        int nB = B->values->getLength();
        for (int i = 0; i < nB; i++)
        {
            int r = B->rowIndex->getEleAtPos(i);
            int c = B->colIndex->getEleAtPos(i);
            // setEntry removes the triplet if the difference is zero.
            C->setEntry(r, c, C->getEntry(r, c) - B->values->getEleAtPos(i));
        }
    }
    catch (...){
        delete C;   // do not leak the partial result
        throw;
    }
    return C;
}
// Return a newly allocated matrix C = A * B, where A (*this) is rows x cols and B is
// B->rows x B->cols; C is rows x B->cols. The caller owns C and must delete it.
// Only the inner dimensions must agree; the previous check also demanded rows == B->cols
// and so rejected valid non-square products.
// Throws "Matrices have incompatible sizes" when cols != B->rows.
SparseMatrix<T> * multiply(SparseMatrix<T> * B){
    if(cols != B->rows)throw "Matrices have incompatible sizes";
    SparseMatrix<T> *C = new SparseMatrix<T>(rows,B->cols);
    try{
        int nA = values->getLength();
        int nB = B->values->getLength();
        for (int i = 0; i < nA; i++)
        {
            // A(r, k) = a contributes a * B(k, c) to C(r, c) for every stored B(k, c).
            int r = rowIndex->getEleAtPos(i);
            int k = colIndex->getEleAtPos(i);
            T a = values->getEleAtPos(i);
            for (int j = 0; j < nB; j++)
            {
                // Compare B's row index directly: O(1), and correct. The old test
                // B->findPos(k, c) also paired A(r,k) with any B(k2,c) sharing a column
                // with an existing B(k,c), adding wrong terms.
                if (B->rowIndex->getEleAtPos(j) != k) continue;
                int c = B->colIndex->getEleAtPos(j);
                C->setEntry(r, c, C->getEntry(r, c) + a * B->values->getEleAtPos(j));
            }
        }
    }
    catch (...){
        delete C;   // do not leak the partial result
        throw;
    }
    return C;
}
// Dense dump of the matrix: one output line per row, each entry followed by a single space.
// Cost is rows * cols lookups, each linear in the number of stored non-zeros.
void printMatrix(){
    int r = 0;
    while (r < rows)
    {
        for (int c = 0; c < cols; ++c)
            cout << getEntry(r, c) << " ";
        cout << endl;
        ++r;
    }
}
我测试了多种情况,结果都表明 `add`、`subtract` 和 `multiply` 运行正常。但有一个使用两个 10000×10000 矩阵(称为 X 和 Y)的测试我无法通过。X 和 Y 的非零元素并不多,测试也只是对它们做加法、减法和乘法。
时间限制是 1 秒(不包括 `printMatrix()`,但包括 `setEntry()`),而我的程序超时了。**而且我只能上传 `SparseMatrix` 类,这意味着我不能使用任何 `#include <...>`,除了(原文此处缺失允许的头文件列表)。** 我该如何减少程序的运行时间?(我还想知道 COO 存储方式是否有问题,以及 `findPos()` 函数是否多余、低效。)谢谢。
我使用的工具是 VSCode 2022,C++11,Windows 11。
下面是测试代码示例。
#include <iostream>
#include <algorithm>
#include <chrono>
using namespace std;
int main(){
auto start = std::chrono::high_resolution_clock::now();
SparseMatrix<int> X,Y;
X.setEntry(1,3,4);
X.setEntry(7,8,2);
Y.setEntry(1,6,4);
Y.setEntry(1,3,4);
Y.setEntry(7,7,2);
X.printMatrix();
cout << endl;
Y.printMatrix();
cout << endl;
X.add(&Y)->printMatrix();
cout << endl;
X.subtract(&Y)->printMatrix();
cout << endl;
Y.multiply(&X)->printMatrix();
cout << "Done" << endl;
auto stop = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
cout << "Running Time:" << duration << "ms\n";
return 0;
}
首先做一个粗略估算:您的 `findPos` 是 O(n) 的,其中 n 是矩阵中存储的非零元素个数。您的乘法有两个嵌套的 for 循环,每个循环都有 10,000 次迭代。内层循环至少调用一次 `findPos`,并且可能还会通过 `getEntry` 和 `setEntry` 再调用两次。也就是说,O(n) 的操作要被调用约 100,000,000 次,这不可能在一秒钟内完成。您需要更高效的算法。
作为第一个改进,可以考虑让 `rowIndex` 和 `colIndex` 始终保持有序:先按行排序,行相同时再按列排序。这样就可以使用一种名为二分查找的"冷门"算法,其时间复杂度为 O(log n)。
此外,有序存储还能让您快速取出属于某一行的所有元素。
第二个优化也很常见:在乘法 A * B 中,A 按行读取,而 B 按列读取。因此,如果先把 B 转置,就可以按行读取它;如上所述,在有序存储下按行读取是很快的。