dune-fem 2.8.0
Loading...
Searching...
No Matches
gmres.hh
Go to the documentation of this file.
1#ifndef DUNE_FEM_GMRES_HH
2#define DUNE_FEM_GMRES_HH
3
#include <cassert>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>
9
11
12namespace Dune
13{
14namespace Fem
15{
16namespace LinearSolver
17{
18
namespace detail {

/** \brief Minimal dense matrix with row-major storage.
 *
 *  Holds n x m entries of type T in a single std::vector, all initialized
 *  to T(0). Index bounds are only checked via assert (debug builds).
 */
template <class T>
class Matrix
{
public:
  /// n x m matrix filled with T(0)
  Matrix(int n, int m) : n_(n), m_(m), data_( n*m, T(0) ) {}

  /// square n x n matrix filled with T(0)
  Matrix(int n) : Matrix(n,n) {}

  /// writable reference to entry (i,j)
  T& operator()(int i, int j) { return data_[ position(i, j) ]; }

  /// copy of entry (i,j)
  T operator()(int i, int j) const { return data_[ position(i, j) ]; }

  /// pointer to the first entry of the row-major storage
  operator T*() { return data_.data(); }

  /// read-only pointer to the first entry of the row-major storage
  operator const T *() const { return data_.data(); }

protected:
  const int n_, m_;
  std::vector< T > data_;

private:
  /// row-major offset of entry (i,j); asserts that the index pair is in range
  int position(int i, int j) const
  {
    assert(i>=0 && i<n_ && j>=0 && j<m_);
    return i*m_ + j;
  }
};

} // namespace detail
50
/** \brief Process-local dot product of the first \p dim entries of \p x and \p y.
 *
 *  Accumulates from index 0 upwards starting with FieldType(0); performs no
 *  communication. A non-positive \p dim yields 0.
 */
template <class FieldType>
FieldType scalarProduct( const int dim, const FieldType *x, const FieldType* y )
{
  FieldType sum = 0;
  // walk both arrays in lock-step; the counter only controls the trip count
  for( int remaining = dim; remaining > 0; --remaining )
  {
    sum += (*x) * (*y);
    ++x;
    ++y;
  }
  return sum;
}
62
/** \brief Computes the inner products y[l] = <v[l], vjp> for l = 0, ..., m-1.
 *
 *  Locally, only dofs that are NOT listed in vjp.space().auxiliaryDofs() are
 *  summed; the per-process sums are then combined with comm.sum( y, m ).
 *  Auxiliary dofs are presumably copies of dofs owned by another process, so
 *  skipping them avoids counting a shared dof more than once
 *  (NOTE(review): confirm against the AuxiliaryDofs documentation).
 *
 *  The loop assumes auxiliaryDofs is sorted ascending and that its last entry
 *  is a sentinel >= the number of dofs; dofs behind the last entry are not
 *  visited.
 *
 *  \param comm  communication object providing sum( FieldType*, int )
 *  \param m     number of basis vectors v[0], ..., v[m-1] to test against
 *  \param v     Krylov basis vectors (only read)
 *  \param vjp   vector whose projections are computed
 *  \param y     output array with at least m entries; overwritten
 */
template <class Communication, class FieldType, class DiscreteFunction>
void gemv(const Communication& comm,
          const int m, // j+1
          std::vector< DiscreteFunction >& v,
          const DiscreteFunction& vjp,
          FieldType *y // global_dot
          )
{
  for(int l=0; l<m; ++l)
  {
    y[ l ] = 0;
  }

  const auto& auxiliaryDofs = vjp.space().auxiliaryDofs();
  const auto& vj = vjp.dofVector();

  const std::size_t numAuxiliarys = auxiliaryDofs.size();
  std::size_t i = 0;
  for( std::size_t auxiliary = 0; auxiliary < numAuxiliarys; ++auxiliary )
  {
    const std::size_t nextAuxiliary = auxiliaryDofs[ auxiliary ];
    // accumulate all regular dofs in front of the next auxiliary dof
    for(; i < nextAuxiliary; ++i )
    {
      for(int l=0; l<m; ++l)
      {
        y[ l ] += (vj[ i ] * v[ l ].dofVector()[ i ]);
      }
    }
    // step over the auxiliary dof itself; without this it would be counted
    // as the first dof of the next segment
    ++i;
  }

  // communicate sum
  comm.sum( y, m );
}
96
/** \brief Applies a plane (Givens) rotation to the entry pairs of x and y.
 *
 *  Like BLAS drot with unit stride, for k = 0, ..., dim-1:
 *    x[k] <- c*x[k] + s*y[k],   y[k] <- c*y[k] - s*x[k]
 *  where both right-hand sides use the values from before the update.
 *  A non-positive \p dim leaves x and y untouched (the previous
 *  `while (i--)` loop ran away and wrote out of bounds for negative dim).
 *
 *  \param dim  number of entry pairs to rotate
 *  \param x    first array (at least dim entries), modified in place
 *  \param y    second array (at least dim entries), modified in place
 *  \param c    cosine of the rotation
 *  \param s    sine of the rotation
 */
template<class FieldType>
void rotate( const int dim,
             FieldType* x, FieldType* y,
             const FieldType c, const FieldType s)
{
  for( int k = 0; k < dim; ++k )
  {
    const FieldType xOld = x[ k ];
    const FieldType yOld = y[ k ];
    x[ k ] = c*xOld + s*yOld;
    y[ k ] = c*yOld - s*xOld;
  }
}
114
115 // Saad, Youcef; Schultz, Martin H.
116 // GMRES: A generalized minimal residual algorithm for solving nonsymmetric
117 // linear systems. (English)
118 // [J] SIAM J. Sci. Stat. Comput. 7, 856-869 (1986). [ISSN 0196-5204]
119 template <class Operator, class Preconditioner, class DiscreteFunction>
120 inline int gmres( Operator& op, Preconditioner* preconditioner,
121 std::vector< DiscreteFunction >& v,
122 DiscreteFunction& u,
123 const DiscreteFunction& b,
124 const int m, // gmres inner iterations
125 const double tolerance,
126 const int maxIterations,
127 const int toleranceCriteria,
128 std::ostream* os = nullptr )
129 {
130 typedef typename DiscreteFunction :: RangeFieldType FieldType;
131
132 const auto& comm = u.space().gridPart().comm();
133
134 detail::Matrix< FieldType > H( m+1, m ); // \in \R^{m+1 \times m}
135 std::vector< FieldType > g_( 6*m, 0.0 );
136
137 FieldType* g = g_.data();
138 FieldType* s = g + (m+1);
139 FieldType* c = s + m;
140 FieldType* y = c + m;
141
142 DiscreteFunction& v0 = v[ 0 ];
143
144 std::vector< FieldType > global_dot( m+1, FieldType(0) );
145
146 // relative or absolute tolerance
147 double _tolerance = tolerance;
148 if (toleranceCriteria == ToleranceCriteria::relative)
149 {
150 global_dot[ 0 ] = b.scalarProductDofs( b );
151 _tolerance *= std::sqrt(global_dot[0]);
152 }
153
154 int iterations = 0;
155 while (true)
156 {
157 // start
158 op(u, v0);
159
160 v0 -= b ;
161
162 // cblas_daxpy(dim, -1.0, b, 1, v, 1);
163 //for( int i=0; i<dim; ++i )
164 // v[ i ] -= b[ i ];
165
166 // scalarProduct( dim, v, v );
167 global_dot[ 0 ] = v0.scalarProductDofs( v0 );
168
169 //comm.allreduce(1, local_dot, global_dot, MPI_SUM);
170 FieldType res = std::sqrt(global_dot[0]);
171
172 if (toleranceCriteria == ToleranceCriteria::residualReduction && iterations==0)
173 {
174 _tolerance *= res;
175 }
176
177 if (os)
178 {
179 (*os) << "Fem::GMRES outer iteration : " << res << std::endl;
180 }
181
182 if (res < _tolerance) break;
183
184 g[0] = -res;
185 for(int i=1; i<=m; i++) g[i] = 0.0;
186
187 // cblas_dscal(dim, 1.0/res, v, 1);
188 v0 *= (1.0/res);
189
190 //scale( dim, 1.0/res, v );
191
192 // iterate
193 for(int j=0; j<m; j++)
194 {
195 DiscreteFunction& vj = v[ j ];
196 DiscreteFunction& vjp = v[ j + 1 ];
197
198 // apply the linear operator (perhaps in combination with the
199 // preconditioner)
200 if (preconditioner)
201 {
202 DiscreteFunction& z = v[ m+1 ];
203 (*preconditioner)(vj, z );
204 op( z, vjp);
205 }
206 else
207 {
208 op(vj, vjp);
209 }
210
211 //cblas_dgemv(CblasRowMajor, CblasNoTrans,
212 // j+1, dim, 1.0, v, dim, vjp, 1, 0.0, global_dot, 1);
213 //j+1, dim, 1.0, v, dim, vjp, 1, 0.0, local_dot, 1);
214 gemv(comm, j+1, v, vjp, global_dot.data());
215
216 for(int i=0; i<=j; i++) H(i,j) = global_dot[i];
217
218 //cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
219 // 1, dim, j+1, -1.0, global_dot, m, v, dim, 1.0, vjp, dim);
220 // gemm(1, dim, j+1, -1.0, global_dot, m, v, dim, 1.0, vjp, dim);
221
222 // assuming beta == 1.0
223 for(int l=0; l<j+1; ++l)
224 {
225 vjp.axpy( -global_dot[l], v[l] );
226 }
227
228 global_dot[ 0 ] = vjp.scalarProductDofs( vjp );
229
230 H(j+1,j) = std::sqrt(global_dot[0]);
231 // cblas_dscal(dim, 1.0/H(j+1,j), vjp, 1);
232
233 vjp *= 1.0/H(j+1,j);
234 // scale(dim, 1.0/H(j+1,j), vjp );
235
236 // perform Givens rotation
237 for(int i=0; i<j; i++)
238 {
239 rotate(1, &H(i+1,j), &H(i,j), c[i], s[i]);
240 }
241
242 const FieldType h_j_j = H(j,j);
243 const FieldType h_jp_j = H(j+1,j);
244 const FieldType norm = std::sqrt(h_j_j*h_j_j + h_jp_j*h_jp_j);
245 c[j] = h_j_j / norm;
246 s[j] = -h_jp_j / norm;
247 rotate(1, &H(j+1,j), &H(j,j), c[j], s[j]);
248 rotate(1, &g[j+1], &g[j], c[j], s[j]);
249
250 if ( os )
251 {
252 (*os) << "Fem::GMRES it: " << iterations << " : " << std::abs(g[j+1]) << std::endl;
253 }
254
255 ++iterations;
256 if (std::abs(g[j+1]) < _tolerance
257 || iterations >= maxIterations ) break;
258 }
259
260 //
261 // form the approximate solution
262 //
263
264 int last = iterations%m;
265 if (last == 0) last = m;
266
267 // compute y via backsubstitution
268 for(int i=last-1; i>=0; --i)
269 {
270 const FieldType dot = scalarProduct( last-(i+1), &H(i,i)+1, &y[i+1] );
271 y[i] = (g[i] - dot)/ H(i,i);
272 }
273
274 // update the approx. solution
275 if (preconditioner)
276 {
277 // u += M^{-1} (v[0], ..., v[last-1]) y
278 DiscreteFunction& u_tmp = v[ m ]; // we don't need this vector anymore
279 DiscreteFunction& z = v[ m+1 ];
280 u_tmp.clear();
281
282 // u += (v[0], ..., v[last-1]) y
283 for(int i=0; i<last; ++i)
284 {
285 u_tmp.axpy( y[ i ], v[ i ] );
286 }
287
288 (*preconditioner)(u_tmp, z);
289 u += z;
290 }
291 else{
292 // u += (v[0], ..., v[last-1]) y
293 for(int i=0; i<last; ++i)
294 {
295 u.axpy( y[ i ], v[ i ] );
296 }
297 }
298
299 if (std::abs(g[last]) < _tolerance) break;
300 }
301
302 // output
303 if ( os ) {
304 (*os) << "Fem::GMRES: number of iterations: "
305 << iterations
306 << std::endl;
307 }
308
309 return (iterations < maxIterations) ? iterations : -iterations;
310 }
311
312
} // end namespace LinearSolver
314
315} // end namespace Fem
316
317} // end namespace Dune
318
319#endif
double sqrt(const Dune::Fem::Double &v)
Definition: double.hh:977
Dune::Fem::Double abs(const Dune::Fem::Double &a)
Definition: double.hh:942
Definition: bindguard.hh:11
void gemv(const Communication &comm, const int m, std::vector< DiscreteFunction > &v, const DiscreteFunction &vjp, FieldType *y)
Definition: gmres.hh:65
int gmres(Operator &op, Preconditioner *preconditioner, std::vector< DiscreteFunction > &v, DiscreteFunction &u, const DiscreteFunction &b, const int m, const double tolerance, const int maxIterations, const int toleranceCriteria, std::ostream *os=nullptr)
Definition: gmres.hh:120
void rotate(const int dim, FieldType *x, FieldType *y, const FieldType c, const FieldType s)
applies a plane (Givens) rotation to pairs of entries, like BLAS drot with unit stride (inc = 1)
Definition: gmres.hh:99
FieldType scalarProduct(const int dim, const FieldType *x, const FieldType *y)
returns the dot product x · y of the first dim entries
Definition: gmres.hh:53
abstract operator
Definition: operator.hh:34
static const int relative
Definition: cg.hh:18
static const int residualReduction
Definition: cg.hh:19