| 1 | /*!\file PlapackInvertMatrix.cpp
|
|---|
| 2 | * \brief invert petsc matrix using Plapack
|
|---|
| 3 | */
|
|---|
| 4 | #include "../../../include/include.h"
|
|---|
| 5 |
|
|---|
| 6 | /* petsc: */
|
|---|
| 7 | #include "../../petsc/petscincludes.h"
|
|---|
| 8 |
|
|---|
| 9 | /*plapack: */
|
|---|
| 10 | #include "../plapackincludes.h"
|
|---|
| 11 |
|
|---|
| 12 | /* Some fortran routines: */
|
|---|
| 13 | #include "../../scalapack/FortranMapping.h"
|
|---|
| 14 |
|
|---|
| 15 | void PlapackInvertMatrixLocalCleanup(PLA_Obj* pa,PLA_Template* ptempl,double** parrayA,int** pidxnA,MPI_Comm* pcomm_2d);
|
|---|
| 16 |
|
|---|
| 17 | int PlapackInvertMatrix(Mat* A,Mat* inv_A,int status,int con){
|
|---|
| 18 | /*inv_A does not yet exist, inv_A was just allocated, that's all*/
|
|---|
| 19 |
|
|---|
| 20 | /*Error management*/
|
|---|
| 21 | int i,j;
|
|---|
| 22 |
|
|---|
| 23 | /*input*/
|
|---|
| 24 | int mA,nA;
|
|---|
| 25 | int local_mA,local_nA;
|
|---|
| 26 | int lower_row,upper_row,range,this_range,this_lower_row;
|
|---|
| 27 | MatType type;
|
|---|
| 28 |
|
|---|
| 29 | /*Plapack: */
|
|---|
| 30 | MPI_Datatype datatype;
|
|---|
| 31 | MPI_Comm comm_2d;
|
|---|
| 32 | PLA_Obj a=NULL;
|
|---|
| 33 | PLA_Template templ;
|
|---|
| 34 | double one=1.0;
|
|---|
| 35 | int ierror;
|
|---|
| 36 | int nb,nb_alg;
|
|---|
| 37 | int nprows,npcols;
|
|---|
| 38 | int initialized=0;
|
|---|
| 39 |
|
|---|
| 40 | /*Petsc to Plapack: */
|
|---|
| 41 | double *arrayA=NULL;
|
|---|
| 42 | int* idxnA=NULL;
|
|---|
| 43 | int d_nz,o_nz;
|
|---|
| 44 |
|
|---|
| 45 | /*Feedback to client*/
|
|---|
| 46 | int computation_status;
|
|---|
| 47 |
|
|---|
| 48 | /*Verify that A is square*/
|
|---|
| 49 | MatGetSize(*A,&mA,&nA);
|
|---|
| 50 | MatGetLocalSize(*A,&local_mA,&local_nA);
|
|---|
| 51 |
|
|---|
| 52 | /*Some dimensions checks: */
|
|---|
| 53 | if (mA!=nA) _error2_("trying to take the invert of a non-square matrix!");
|
|---|
| 54 |
|
|---|
| 55 | /* Set default Plapack parameters */
|
|---|
| 56 | //First find nprows*npcols=num_procs;
|
|---|
| 57 | CyclicalFactorization(&nprows,&npcols,num_procs);
|
|---|
| 58 | //nprows=num_procs;
|
|---|
| 59 | //npcols=1;
|
|---|
| 60 | ierror = 0;
|
|---|
| 61 | nb = nA/num_procs;
|
|---|
| 62 | if(nA - nb*num_procs) nb++; /* without cyclic distribution */
|
|---|
| 63 |
|
|---|
| 64 | if (ierror){
|
|---|
| 65 | PLA_Set_error_checking(ierror,PETSC_TRUE,PETSC_TRUE,PETSC_FALSE );
|
|---|
| 66 | }
|
|---|
| 67 | else {
|
|---|
| 68 | PLA_Set_error_checking(ierror,PETSC_FALSE,PETSC_FALSE,PETSC_FALSE );
|
|---|
| 69 | }
|
|---|
| 70 | nb_alg = 0;
|
|---|
| 71 | if (nb_alg){
|
|---|
| 72 | pla_Environ_set_nb_alg (PLA_OP_ALL_ALG,nb_alg);
|
|---|
| 73 | }
|
|---|
| 74 |
|
|---|
| 75 | /*Verify that plapack is not already initialized: */
|
|---|
| 76 | if(PLA_Initialized(NULL)==TRUE)PLA_Finalize();
|
|---|
| 77 | /* Create a 2D communicator */
|
|---|
| 78 | PLA_Comm_1D_to_2D(MPI_COMM_WORLD,nprows,npcols,&comm_2d);
|
|---|
| 79 |
|
|---|
| 80 | /*Initlialize plapack: */
|
|---|
| 81 | PLA_Init(comm_2d);
|
|---|
| 82 |
|
|---|
| 83 | templ = NULL;
|
|---|
| 84 | PLA_Temp_create(nb, 0, &templ);
|
|---|
| 85 |
|
|---|
| 86 | /* Use suggested nb_alg if it is not provided by user */
|
|---|
| 87 | if (nb_alg == 0){
|
|---|
| 88 | PLA_Environ_nb_alg(PLA_OP_PAN_PAN,templ,&nb_alg);
|
|---|
| 89 | pla_Environ_set_nb_alg(PLA_OP_ALL_ALG,nb_alg);
|
|---|
| 90 | }
|
|---|
| 91 |
|
|---|
| 92 | /* Set the datatype */
|
|---|
| 93 | datatype = MPI_DOUBLE;
|
|---|
| 94 |
|
|---|
| 95 | /* Copy A into a*/
|
|---|
| 96 | PLA_Matrix_create(datatype,mA,nA,templ,PLA_ALIGN_FIRST,PLA_ALIGN_FIRST,&a);
|
|---|
| 97 | PLA_Obj_set_to_zero(a);
|
|---|
| 98 | /*Take array from A: use MatGetValues, because we are sure this routine works with
|
|---|
| 99 | any matrix type.*/
|
|---|
| 100 | MatGetOwnershipRange(*A,&lower_row,&upper_row);
|
|---|
| 101 | upper_row--;
|
|---|
| 102 | range=upper_row-lower_row+1;
|
|---|
| 103 | arrayA = xNew<double>(nA);
|
|---|
| 104 | idxnA = xNew<int>(nA);
|
|---|
| 105 | for (i=0;i<nA;i++){
|
|---|
| 106 | *(idxnA+i)=i;
|
|---|
| 107 | }
|
|---|
| 108 | PLA_API_begin();
|
|---|
| 109 | PLA_Obj_API_open(a);
|
|---|
| 110 | for (i=lower_row;i<=upper_row;i++){
|
|---|
| 111 | MatGetValues(*A,1,&i,nA,idxnA,arrayA);
|
|---|
| 112 | PLA_API_axpy_matrix_to_global(1,nA, &one,(void *)arrayA,1,a,i,0);
|
|---|
| 113 | }
|
|---|
| 114 | PLA_Obj_API_close(a);
|
|---|
| 115 | PLA_API_end();
|
|---|
| 116 |
|
|---|
| 117 | /*Call the plapack invert routine*/
|
|---|
| 118 | PLA_General_invert(PLA_METHOD_INV,a);
|
|---|
| 119 |
|
|---|
| 120 | /*Translate Plapack a into Petsc invA*/
|
|---|
| 121 | MatGetType(*A,&type);
|
|---|
| 122 | PlapackToPetsc(inv_A,local_mA,local_nA,mA,nA,type,a,templ,nprows,npcols,nb);
|
|---|
| 123 |
|
|---|
| 124 | /*Free ressources:*/
|
|---|
| 125 | PLA_Obj_free(&a);
|
|---|
| 126 | PLA_Temp_free(&templ);
|
|---|
| 127 | xDelete<double>(arrayA);
|
|---|
| 128 | xDelete<int>(idxnA);
|
|---|
| 129 |
|
|---|
| 130 | /*Finalize PLAPACK*/
|
|---|
| 131 | PLA_Finalize();
|
|---|
| 132 | MPI_Comm_free(&comm_2d);
|
|---|
| 133 | }
|
|---|