SYCL: Avoid using no_init read accessor in rocFFT
[alexxy/gromacs.git] / src / gromacs / applied_forces / awh / biasgrid.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2015,2016,2017,2018,2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  * \brief
38  * Implements functions in grid.h.
39  *
40  * \author Viveca Lindahl
41  * \author Berk Hess <hess@kth.se>
42  * \ingroup module_awh
43  */
44
45 #include "gmxpre.h"
46
47 #include "biasgrid.h"
48
49 #include <cassert>
50 #include <cmath>
51 #include <cstring>
52
53 #include <algorithm>
54 #include <optional>
55
56 #include "gromacs/math/functions.h"
57 #include "gromacs/math/utilities.h"
58 #include "gromacs/mdtypes/awh_params.h"
59 #include "gromacs/utility/cstringutil.h"
60 #include "gromacs/utility/exceptions.h"
61 #include "gromacs/utility/gmxassert.h"
62 #include "gromacs/utility/smalloc.h"
63 #include "gromacs/utility/stringutil.h"
64
65 namespace gmx
66 {
67
68 namespace
69 {
70
71 /*! \brief
72  * Return x so that it is periodic in [-period/2, +period/2).
73  *
74  * x is modified by shifting its value by a +/- a period if
75  * needed. Thus, it is assumed that x is at most one period
76  * away from this interval. For period = 0, x is not modified.
77  *
78  * \param[in] x       Pointer to the value to modify.
79  * \param[in] period  The period, or 0 if not periodic.
80  * \returns   Value that is within the period.
81  */
82 double centerPeriodicValueAroundZero(const double x, double period)
83 {
84     GMX_ASSERT(period >= 0, "Periodic should not be negative");
85
86     const double halfPeriod = period * 0.5;
87
88     double valueInPeriod = x;
89
90     if (valueInPeriod >= halfPeriod)
91     {
92         valueInPeriod -= period;
93     }
94     else if (valueInPeriod < -halfPeriod)
95     {
96         valueInPeriod += period;
97     }
98     return valueInPeriod;
99 }
100
101 /*! \brief
102  * If period>0, retrun x so that it is periodic in [0, period), else return x.
103  *
104  * Return x is shifted its value by a +/- a period, if
105  * needed. Thus, it is assumed that x is at most one period
106  * away from this interval. For this domain and period > 0
107  * this is equivalent to x = x % period. For period = 0,
108  * x is not modified.
109  *
110  * \param[in,out] x       Pointer to the value to modify, should be >= 0.
111  * \param[in]     period  The period, or 0 if not periodic.
112  * \returns for period>0: index value witin [0, period), otherwise: \p x.
113  */
114 int indexWithinPeriod(int x, int period)
115 {
116     GMX_ASSERT(period >= 0, "Periodic should not be negative");
117
118     if (period == 0)
119     {
120         return x;
121     }
122
123     GMX_ASSERT(x > -period && x < 2 * period,
124                "x should not be more shifted by more than one period");
125
126     if (x >= period)
127     {
128         return x - period;
129     }
130     else if (x < 0)
131     {
132         return x + period;
133     }
134     else
135     {
136         return x;
137     }
138 }
139
140 /*! \brief
141  * Get the length of the interval (origin, end).
142  *
143  * This returns the distance obtained by connecting the origin point to
144  * the end point in the positive direction. Note that this is generally
145  * not the shortest distance. For period > 0, both origin and
146  * end are expected to take values in the same periodic interval,
147  * ie. |origin - end| < period.
148  *
149  * \param[in] origin    Start value of the interval.
150  * \param[in] end       End value of the interval.
151  * \param[in] period    The period, or 0 if not periodic.
152  * \returns the interval length from origin to end.
153  */
154 double getIntervalLengthPeriodic(double origin, double end, double period)
155 {
156     double length = end - origin;
157     if (length < 0)
158     {
159         /* The interval wraps around the +/- boundary which has a discontinuous jump of -period. */
160         length += period;
161     }
162
163     GMX_RELEASE_ASSERT(length >= 0, "Negative AWH grid axis length.");
164     GMX_RELEASE_ASSERT(period == 0 || length <= period, "Interval length longer than period.");
165
166     return length;
167 }
168
169 /*! \brief
170  * Get the deviation x - x0.
171  *
172  * For period > 0, the deviation with minimum absolute value is returned,
173  * i.e. with a value in the interval [-period/2, +period/2).
174  * Also for period > 0, it is assumed that |x - x0| < period.
175  *
176  * \param[in] x        From value.
177  * \param[in] x0       To value.
178  * \param[in] period   The period, or 0 if not periodic.
179  * \returns the deviation from x to x0.
180  */
181 double getDeviationPeriodic(double x, double x0, double period)
182 {
183     double dev = x - x0;
184
185     if (period > 0)
186     {
187         dev = centerPeriodicValueAroundZero(dev, period);
188     }
189
190     return dev;
191 }
192
193 } // namespace
194
195 double getDeviationFromPointAlongGridAxis(const BiasGrid& grid, int dimIndex, int pointIndex, double value)
196 {
197     double coordValue = grid.point(pointIndex).coordValue[dimIndex];
198
199     return getDeviationPeriodic(value, coordValue, grid.axis(dimIndex).period());
200 }
201
202 double getDeviationFromPointAlongGridAxis(const BiasGrid& grid, int dimIndex, int pointIndex1, int pointIndex2)
203 {
204     double coordValue1 = grid.point(pointIndex1).coordValue[dimIndex];
205     double coordValue2 = grid.point(pointIndex2).coordValue[dimIndex];
206
207     return getDeviationPeriodic(coordValue1, coordValue2, grid.axis(dimIndex).period());
208 }
209
210 bool pointsAlongLambdaAxis(const BiasGrid& grid, int pointIndex1, int pointIndex2)
211 {
212     if (!grid.hasLambdaAxis())
213     {
214         return false;
215     }
216     if (pointIndex1 == pointIndex2)
217     {
218         return true;
219     }
220     const int numDimensions = grid.numDimensions();
221     for (int d = 0; d < numDimensions; d++)
222     {
223         if (grid.axis(d).isFepLambdaAxis())
224         {
225             if (getDeviationFromPointAlongGridAxis(grid, d, pointIndex1, pointIndex2) == 0)
226             {
227                 return false;
228             }
229         }
230         else
231         {
232             if (getDeviationFromPointAlongGridAxis(grid, d, pointIndex1, pointIndex2) != 0)
233             {
234                 return false;
235             }
236         }
237     }
238     return true;
239 }
240
241 bool pointsHaveDifferentLambda(const BiasGrid& grid, int pointIndex1, int pointIndex2)
242 {
243     if (!grid.hasLambdaAxis())
244     {
245         return false;
246     }
247     if (pointIndex1 == pointIndex2)
248     {
249         return false;
250     }
251     const int numDimensions = grid.numDimensions();
252     for (int d = 0; d < numDimensions; d++)
253     {
254         if (grid.axis(d).isFepLambdaAxis())
255         {
256             if (getDeviationFromPointAlongGridAxis(grid, d, pointIndex1, pointIndex2) != 0)
257             {
258                 return true;
259             }
260         }
261     }
262     return false;
263 }
264
265 void linearArrayIndexToMultiDim(int indexLinear, int numDimensions, const awh_ivec numPointsDim, awh_ivec indexMulti)
266 {
267     for (int d = 0; d < numDimensions; d++)
268     {
269         int stride = 1;
270
271         for (int k = d + 1; k < numDimensions; k++)
272         {
273             stride *= numPointsDim[k];
274         }
275
276         indexMulti[d] = indexLinear / stride;
277         indexLinear -= indexMulti[d] * stride;
278     }
279 }
280
281 void linearGridindexToMultiDim(const BiasGrid& grid, int indexLinear, awh_ivec indexMulti)
282 {
283     awh_ivec  numPointsDim;
284     const int numDimensions = grid.numDimensions();
285     for (int d = 0; d < numDimensions; d++)
286     {
287         numPointsDim[d] = grid.axis(d).numPoints();
288     }
289
290     linearArrayIndexToMultiDim(indexLinear, numDimensions, numPointsDim, indexMulti);
291 }
292
293
294 int multiDimArrayIndexToLinear(const awh_ivec indexMulti, int numDimensions, const awh_ivec numPointsDim)
295 {
296     int stride      = 1;
297     int indexLinear = 0;
298     for (int d = numDimensions - 1; d >= 0; d--)
299     {
300         indexLinear += stride * indexMulti[d];
301         stride *= numPointsDim[d];
302     }
303
304     return indexLinear;
305 }
306
307 namespace
308 {
309
310 /*! \brief Convert a multidimensional grid point index to a linear one.
311  *
312  * \param[in] axis       The grid axes.
313  * \param[in] indexMulti Multidimensional grid point index to convert to a linear one.
314  * \returns the linear index.
315  */
316 int multiDimGridIndexToLinear(ArrayRef<const GridAxis> axis, const awh_ivec indexMulti)
317 {
318     awh_ivec numPointsDim = { 0 };
319
320     for (size_t d = 0; d < axis.size(); d++)
321     {
322         numPointsDim[d] = axis[d].numPoints();
323     }
324
325     return multiDimArrayIndexToLinear(indexMulti, axis.size(), numPointsDim);
326 }
327
328 } // namespace
329
330 int multiDimGridIndexToLinear(const BiasGrid& grid, const awh_ivec indexMulti)
331 {
332     return multiDimGridIndexToLinear(grid.axis(), indexMulti);
333 }
334
335 namespace
336 {
337
338 /*! \brief
339  * Take a step in a multidimensional array.
340  *
341  * The multidimensional index gives the starting point to step from. Dimensions are
342  * stepped through in order of decreasing dimensional index such that the index is
343  * incremented in the highest dimension possible. If the starting point is the end
344  * of the array, a step cannot be taken and the index is not modified.
345  *
346  * \param[in] numDim        Number of dimensions of the array.
347  * \param[in] numPoints     Vector with the number of points along each dimension.
348  * \param[in,out] indexDim  Multidimensional index, each with values in [0, numPoints[d] - 1].
349  * \returns true if a step was taken, false if not.
350  */
351 bool stepInMultiDimArray(int numDim, const awh_ivec numPoints, awh_ivec indexDim)
352 {
353     bool haveStepped = false;
354
355     for (int d = numDim - 1; d >= 0 && !haveStepped; d--)
356     {
357         if (indexDim[d] < numPoints[d] - 1)
358         {
359             /* Not at a boundary, just increase by 1. */
360             indexDim[d]++;
361             haveStepped = true;
362         }
363         else
364         {
365             /* At a boundary. If we are not at the end of the array,
366                reset the index and check if we can step in higher dimensions */
367             if (d > 0)
368             {
369                 indexDim[d] = 0;
370             }
371         }
372     }
373
374     return haveStepped;
375 }
376
377 /*! \brief
378  * Transforms a grid point index to to the multidimensional index of a subgrid.
379  *
380  * The subgrid is defined by the location of its origin and the number of points
381  * along each dimension. The index transformation thus consists of a projection
382  * of the linear index onto each dimension, followed by a translation of the origin.
383  * The subgrid may have parts that don't overlap with the grid. E.g. the origin
384  * vector can have negative components meaning the origin lies outside of the grid.
385  * However, the given point needs to be both a grid and subgrid point.
386  *
387  * Periodic boundaries are taken care of by wrapping the subgrid around the grid.
388  * Thus, for periodic dimensions the number of subgrid points need to be less than
389  * the number of points in a period to prevent problems of wrapping around.
390  *
391  * \param[in]     grid            The grid.
392  * \param[in]     subgridOrigin   Vector locating the subgrid origin relative to the grid origin.
393  * \param[in]     subgridNpoints  The number of subgrid points in each dimension.
394  * \param[in]     point           BiasGrid point to get subgrid index for.
395  * \param[in,out] subgridIndex    Subgrid multidimensional index.
396  */
397 void gridToSubgridIndex(const BiasGrid& grid,
398                         const awh_ivec  subgridOrigin,
399                         const awh_ivec  subgridNpoints,
400                         int             point,
401                         awh_ivec        subgridIndex)
402 {
403     /* Get the subgrid index of the given grid point, for each dimension. */
404     for (int d = 0; d < grid.numDimensions(); d++)
405     {
406         /* The multidimensional grid point index relative to the subgrid origin. */
407         subgridIndex[d] = indexWithinPeriod(grid.point(point).index[d] - subgridOrigin[d],
408                                             grid.axis(d).numPointsInPeriod());
409
410         /* The given point should be in the subgrid. */
411         GMX_RELEASE_ASSERT((subgridIndex[d] >= 0) && (subgridIndex[d] < subgridNpoints[d]),
412                            "Attempted to convert an AWH grid point index not in subgrid to out of "
413                            "bounds subgrid index");
414     }
415 }
416
417 /*! \brief
418  * Transform a multidimensional subgrid index to a grid point index.
419  *
420  * If the given subgrid point is not a grid point the transformation will not be successful
421  * and the grid point index will not be set. Periodic boundaries are taken care of by
422  * wrapping the subgrid around the grid.
423  *
424  * \param[in]     grid           The grid.
425  * \param[in]     subgridOrigin  Vector locating the subgrid origin relative to the grid origin.
426  * \param[in]     subgridIndex   Subgrid multidimensional index to get grid point index for.
427  * \param[in,out] gridIndex      BiasGrid point index.
428  * \returns true if the transformation was successful.
429  */
430 bool subgridToGridIndex(const BiasGrid& grid, const awh_ivec subgridOrigin, const awh_ivec subgridIndex, int* gridIndex)
431 {
432     awh_ivec globalIndexDim;
433
434     /* Check and apply boundary conditions for each dimension */
435     for (int d = 0; d < grid.numDimensions(); d++)
436     {
437         /* Transform to global multidimensional indexing by adding the origin */
438         globalIndexDim[d] = subgridOrigin[d] + subgridIndex[d];
439
440         /* The local grid is allowed to stick out on the edges of the global grid. Here the boundary conditions are applied.*/
441         if (globalIndexDim[d] < 0 || globalIndexDim[d] > grid.axis(d).numPoints() - 1)
442         {
443             /* Try to wrap around if periodic. Otherwise, the transformation failed so return. */
444             if (!grid.axis(d).isPeriodic())
445             {
446                 return false;
447             }
448
449             /* The grid might not contain a whole period. Can only wrap around if this gap is not too large. */
450             int gap = grid.axis(d).numPointsInPeriod() - grid.axis(d).numPoints();
451
452             int bridge;
453             int numWrapped;
454             if (globalIndexDim[d] < 0)
455             {
456                 bridge     = -globalIndexDim[d];
457                 numWrapped = bridge - gap;
458                 if (numWrapped > 0)
459                 {
460                     globalIndexDim[d] = grid.axis(d).numPoints() - numWrapped;
461                 }
462             }
463             else
464             {
465                 bridge     = globalIndexDim[d] - (grid.axis(d).numPoints() - 1);
466                 numWrapped = bridge - gap;
467                 if (numWrapped > 0)
468                 {
469                     globalIndexDim[d] = numWrapped - 1;
470                 }
471             }
472
473             if (numWrapped <= 0)
474             {
475                 return false;
476             }
477         }
478     }
479
480     /* Translate from multidimensional to linear indexing and set the return value */
481     (*gridIndex) = multiDimGridIndexToLinear(grid, globalIndexDim);
482
483     return true;
484 }
485
486 } // namespace
487
488 bool advancePointInSubgrid(const BiasGrid& grid,
489                            const awh_ivec  subgridOrigin,
490                            const awh_ivec  subgridNumPoints,
491                            int*            gridPointIndex)
492 {
493     /* Initialize the subgrid index to the subgrid origin. */
494     awh_ivec subgridIndex = { 0 };
495
496     /* Get the subgrid index of the given grid point index. */
497     if (*gridPointIndex >= 0)
498     {
499         gridToSubgridIndex(grid, subgridOrigin, subgridNumPoints, *gridPointIndex, subgridIndex);
500     }
501     else
502     {
503         /* If no grid point is given we start at the subgrid origin (which subgridIndex is initialized to).
504            If this is a valid grid point then we're done, otherwise keep looking below. */
505         /* TODO: separate into a separate function (?) */
506         if (subgridToGridIndex(grid, subgridOrigin, subgridIndex, gridPointIndex))
507         {
508             return true;
509         }
510     }
511
512     /* Traverse the subgrid and look for the first point that is also in the grid. */
513     while (stepInMultiDimArray(grid.numDimensions(), subgridNumPoints, subgridIndex))
514     {
515         /* If this is a valid grid point, the grid point index is updated.*/
516         if (subgridToGridIndex(grid, subgridOrigin, subgridIndex, gridPointIndex))
517         {
518             return true;
519         }
520     }
521
522     return false;
523 }
524
525 /*! \brief
526  * Returns the point distance between from value x to value x0 along the given axis.
527  *
528  * Note that the returned distance may be negative or larger than the
529  * number of points in the axis. For a periodic axis, the distance is chosen
530  * to be in [0, period), i.e. always positive but not the shortest one.
531  *
532  * \param[in]  axis   BiasGrid axis.
533  * \param[in]  x      From value.
534  * \param[in]  x0     To value.
535  * \returns (x - x0) in number of points.
536  */
537 static int pointDistanceAlongAxis(const GridAxis& axis, double x, double x0)
538 {
539     int distance = 0;
540
541     if (axis.spacing() > 0)
542     {
543         /* Get the real-valued distance. For a periodic axis, the shortest one. */
544         double period = axis.period();
545         double dx     = getDeviationPeriodic(x, x0, period);
546
547         /* Transform the distance into a point distance by rounding. */
548         distance = gmx::roundToInt(dx / axis.spacing());
549
550         /* If periodic, shift the point distance to be in [0, period) */
551         distance = indexWithinPeriod(distance, axis.numPointsInPeriod());
552     }
553
554     return distance;
555 }
556
557 /*! \brief
558  * Query if a value is in range of the grid.
559  *
560  * \param[in] value   Value to check.
561  * \param[in] axis    The grid axes.
562  * \returns true if the value is in the grid.
563  */
564 static bool valueIsInGrid(const awh_dvec value, ArrayRef<const GridAxis> axis)
565 {
566     /* For each dimension get the one-dimensional index and check if it is in range. */
567     for (size_t d = 0; d < axis.size(); d++)
568     {
569         /* The index is computed as the point distance from the origin. */
570         int index = pointDistanceAlongAxis(axis[d], value[d], axis[d].origin());
571
572         if (!(index >= 0 && index < axis[d].numPoints()))
573         {
574             return false;
575         }
576     }
577
578     return true;
579 }
580
581 bool BiasGrid::covers(const awh_dvec value) const
582 {
583     return valueIsInGrid(value, axis());
584 }
585
586 std::optional<int> BiasGrid::lambdaAxisIndex() const
587 {
588     for (size_t i = 0; i < axis_.size(); i++)
589     {
590         if (axis_[i].isFepLambdaAxis())
591         {
592             return i;
593         }
594     }
595     return {};
596 }
597
598 int BiasGrid::numFepLambdaStates() const
599 {
600     for (size_t i = 0; i < axis_.size(); i++)
601     {
602         if (axis_[i].isFepLambdaAxis())
603         {
604             return axis_[i].numPoints();
605         }
606     }
607     return 0;
608 }
609
610 int GridAxis::nearestIndex(double value) const
611 {
612     /* Get the point distance to the origin. This may by an out of index range for the axis. */
613     int index = pointDistanceAlongAxis(*this, value, origin_);
614
615     if (index < 0 || index >= numPoints_)
616     {
617         if (isPeriodic())
618         {
619             GMX_RELEASE_ASSERT(index >= 0 && index < numPointsInPeriod_,
620                                "Index not in periodic interval 0 for AWH periodic axis");
621             int endDistance    = (index - (numPoints_ - 1));
622             int originDistance = (numPointsInPeriod_ - index);
623             index              = originDistance < endDistance ? 0 : numPoints_ - 1;
624         }
625         else
626         {
627             index = (index < 0) ? 0 : (numPoints_ - 1);
628         }
629     }
630
631     return index;
632 }
633
634 /*! \brief
635  * Map a value to the nearest point in the grid.
636  *
637  * \param[in] value  Value.
638  * \param[in] axis   The grid axes.
639  * \returns the point index nearest to the value.
640  */
641 static int getNearestIndexInGrid(const awh_dvec value, ArrayRef<const GridAxis> axis)
642 {
643     awh_ivec indexMulti;
644
645     /* If the index is out of range, modify it so that it is in range by choosing the nearest point on the edge. */
646     for (size_t d = 0; d < axis.size(); d++)
647     {
648         indexMulti[d] = axis[d].nearestIndex(value[d]);
649     }
650
651     return multiDimGridIndexToLinear(axis, indexMulti);
652 }
653
654 int BiasGrid::nearestIndex(const awh_dvec value) const
655 {
656     return getNearestIndexInGrid(value, axis());
657 }
658
659 namespace
660 {
661
662 /*! \brief
663  * Find and set the neighbors of a grid point.
664  *
665  * The search space for neighbors is a subgrid with size set by a scope cutoff.
666  * In general not all point within scope will be valid grid points.
667  *
668  * \param[in]     pointIndex           BiasGrid point index.
669  * \param[in]     grid                 The grid.
670  * \param[in,out] neighborIndexArray   Array to fill with neighbor indices.
671  */
672 void setNeighborsOfGridPoint(int pointIndex, const BiasGrid& grid, std::vector<int>* neighborIndexArray)
673 {
674     const int c_maxNeighborsAlongAxis =
675             1 + 2 * static_cast<int>(BiasGrid::c_numPointsPerSigma * BiasGrid::c_scopeCutoff);
676
677     awh_ivec numCandidates = { 0 };
678     awh_ivec subgridOrigin = { 0 };
679     for (int d = 0; d < grid.numDimensions(); d++)
680     {
681         if (grid.axis(d).isFepLambdaAxis())
682         {
683             /* Use all points along an axis linked to FEP */
684             numCandidates[d] = grid.axis(d).numPoints();
685             subgridOrigin[d] = 0;
686         }
687         else
688         {
689             /* The number of candidate points along this dimension is given by the scope cutoff. */
690             numCandidates[d] = std::min(c_maxNeighborsAlongAxis, grid.axis(d).numPoints());
691
692             /* The origin of the subgrid to search */
693             int centerIndex  = grid.point(pointIndex).index[d];
694             subgridOrigin[d] = centerIndex - numCandidates[d] / 2;
695         }
696     }
697
698     /* Find and set the neighbors */
699     int  neighborIndex = -1;
700     bool aPointExists  = true;
701
702     /* Keep looking for grid points while traversing the subgrid. */
703     while (aPointExists)
704     {
705         /* The point index is updated if a grid point was found. */
706         aPointExists = advancePointInSubgrid(grid, subgridOrigin, numCandidates, &neighborIndex);
707
708         if (aPointExists)
709         {
710             neighborIndexArray->push_back(neighborIndex);
711         }
712     }
713 }
714
715 } // namespace
716
717 void BiasGrid::initPoints()
718 {
719     awh_ivec numPointsDimWork = { 0 };
720     awh_ivec indexWork        = { 0 };
721
722     for (size_t d = 0; d < axis_.size(); d++)
723     {
724         /* Temporarily gather the number of points in each dimension in one array */
725         numPointsDimWork[d] = axis_[d].numPoints();
726     }
727
728     for (auto& point : point_)
729     {
730         for (size_t d = 0; d < axis_.size(); d++)
731         {
732             if (axis_[d].isFepLambdaAxis())
733             {
734                 point.coordValue[d] = indexWork[d];
735             }
736             else
737             {
738                 point.coordValue[d] = axis_[d].origin() + indexWork[d] * axis_[d].spacing();
739             }
740
741             if (axis_[d].period() > 0)
742             {
743                 /* Do we always want the values to be centered around 0 ? */
744                 point.coordValue[d] =
745                         centerPeriodicValueAroundZero(point.coordValue[d], axis_[d].period());
746             }
747
748             point.index[d] = indexWork[d];
749         }
750
751         stepInMultiDimArray(axis_.size(), numPointsDimWork, indexWork);
752     }
753 }
754
755 GridAxis::GridAxis(double origin, double end, double period, double pointDensity) :
756     origin_(origin), period_(period), isFepLambdaAxis_(false)
757 {
758     length_ = getIntervalLengthPeriodic(origin_, end, period_);
759
760     /* Automatically determine number of points based on the user given endpoints
761        and the expected fluctuations in the umbrella. */
762     if (length_ == 0)
763     {
764         numPoints_ = 1;
765     }
766     else if (pointDensity == 0)
767     {
768         numPoints_ = 2;
769     }
770     else
771     {
772         /* An extra point is added here to account for the endpoints. The
773            minimum number of points for a non-zero interval is 2. */
774         numPoints_ = 1 + static_cast<int>(std::ceil(length_ * pointDensity));
775     }
776
777     /* Set point spacing based on the number of points */
778     if (isPeriodic())
779     {
780         /* Set the grid spacing so that a period is matched exactly by an integer number of points.
781            The number of points in a period is equal to the number of grid spacings in a period
782            since the endpoints are connected.  */
783         numPointsInPeriod_ =
784                 length_ > 0 ? static_cast<int>(std::ceil(period / length_ * (numPoints_ - 1))) : 1;
785         spacing_ = period_ / numPointsInPeriod_;
786
787         /* Modify the number of grid axis points to be compatible with the period dependent spacing. */
788         numPoints_ = std::min(static_cast<int>(round(length_ / spacing_)) + 1, numPointsInPeriod_);
789     }
790     else
791     {
792         numPointsInPeriod_ = 0;
793         spacing_           = numPoints_ > 1 ? length_ / (numPoints_ - 1) : 0;
794     }
795 }
796
797 GridAxis::GridAxis(double origin, double end, double period, int numPoints, bool isFepLambdaAxis) :
798     origin_(origin), period_(period), numPoints_(numPoints), isFepLambdaAxis_(isFepLambdaAxis)
799 {
800     if (isFepLambdaAxis)
801     {
802         length_            = end - origin_;
803         spacing_           = 1;
804         numPointsInPeriod_ = numPoints;
805     }
806     else
807     {
808         length_            = getIntervalLengthPeriodic(origin_, end, period_);
809         spacing_           = numPoints_ > 1 ? length_ / (numPoints_ - 1) : period_;
810         numPointsInPeriod_ = static_cast<int>(std::round(period_ / spacing_));
811     }
812 }
813
814 BiasGrid::BiasGrid(ArrayRef<const DimParams> dimParams, ArrayRef<const AwhDimParams> awhDimParams)
815 {
816     GMX_RELEASE_ASSERT(dimParams.size() == awhDimParams.size(), "Dimensions needs to be equal");
817     /* Define the discretization along each dimension */
818     awh_dvec period;
819     int      numPoints = 1;
820     for (int d = 0; d < gmx::ssize(awhDimParams); d++)
821     {
822         double origin = dimParams[d].scaleUserInputToInternal(awhDimParams[d].origin());
823         double end    = dimParams[d].scaleUserInputToInternal(awhDimParams[d].end());
824         if (awhDimParams[d].coordinateProvider() == AwhCoordinateProviderType::Pull)
825         {
826             period[d] = dimParams[d].scaleUserInputToInternal(awhDimParams[d].period());
827             static_assert(
828                     c_numPointsPerSigma >= 1.0,
829                     "The number of points per sigma should be at least 1.0 to get a uniformly "
830                     "covering the reaction using Gaussians");
831             double pointDensity = std::sqrt(dimParams[d].pullDimParams().betak) * c_numPointsPerSigma;
832             axis_.emplace_back(origin, end, period[d], pointDensity);
833         }
834         else
835         {
836             axis_.emplace_back(origin, end, 0, dimParams[d].fepDimParams().numFepLambdaStates, true);
837         }
838         numPoints *= axis_[d].numPoints();
839     }
840
841     point_.resize(numPoints);
842
843     /* Set their values */
844     initPoints();
845
846     /* Keep a neighbor list for each point.
847      * Note: could also generate neighbor list only when needed
848      * instead of storing them for each point.
849      */
850     for (size_t m = 0; m < point_.size(); m++)
851     {
852         std::vector<int>* neighbor = &point_[m].neighbor;
853
854         setNeighborsOfGridPoint(m, *this, neighbor);
855     }
856 }
857
858 void mapGridToDataGrid(std::vector<int>*    gridpointToDatapoint,
859                        const double* const* data,
860                        int                  numDataPoints,
861                        const std::string&   dataFilename,
862                        const BiasGrid&      grid,
863                        const std::string&   correctFormatMessage)
864 {
865     /* Transform the data into a grid in order to map each grid point to a data point
866        using the grid functions. */
867
868     /* Count the number of points for each dimension. Each dimension
869        has its own stride. */
870     int               stride           = 1;
871     int               numPointsCounted = 0;
872     std::vector<int>  numPoints(grid.numDimensions());
873     std::vector<bool> isFepLambdaAxis(grid.numDimensions());
874     for (int d = grid.numDimensions() - 1; d >= 0; d--)
875     {
876         int    numPointsInDim = 0;
877         int    pointIndex     = 0;
878         double firstValue     = data[d][pointIndex];
879         do
880         {
881             numPointsInDim++;
882             pointIndex += stride;
883         } while (pointIndex < numDataPoints
884                  && !gmx_within_tol(firstValue, data[d][pointIndex], GMX_REAL_EPS));
885
886         /* The stride in dimension dimension d - 1 equals the number of points
887            dimension d. */
888         stride = numPointsInDim;
889
890         numPointsCounted = (numPointsCounted == 0) ? numPointsInDim : numPointsCounted * numPointsInDim;
891
892         numPoints[d]       = numPointsInDim;
893         isFepLambdaAxis[d] = grid.axis(d).isFepLambdaAxis();
894     }
895
896     if (numPointsCounted != numDataPoints)
897     {
898         std::string mesg = gmx::formatString(
899                 "Could not extract data properly from %s. Wrong data format?"
900                 "\n\n%s",
901                 dataFilename.c_str(),
902                 correctFormatMessage.c_str());
903         GMX_THROW(InvalidInputError(mesg));
904     }
905
906     std::vector<GridAxis> axis_;
907     axis_.reserve(grid.numDimensions());
908     /* The data grid has the data that was read and the properties of the AWH grid */
909     for (int d = 0; d < grid.numDimensions(); d++)
910     {
911         if (isFepLambdaAxis[d])
912         {
913             axis_.emplace_back(data[d][0], data[d][numDataPoints - 1], 0, numPoints[d], true);
914         }
915         else
916         {
917             axis_.emplace_back(
918                     data[d][0], data[d][numDataPoints - 1], grid.axis(d).period(), numPoints[d], false);
919         }
920     }
921
922     /* Map each grid point to a data point. No interpolation, just pick the nearest one.
923      * It is assumed that the given data is uniformly spaced for each dimension.
924      */
925     for (size_t m = 0; m < grid.numPoints(); m++)
926     {
927         /* We only define what we need for the datagrid since it's not needed here which is a bit ugly */
928
929         if (!valueIsInGrid(grid.point(m).coordValue, axis_))
930         {
931             std::string mesg = gmx::formatString(
932                     "%s does not contain data for all coordinate values. "
933                     "Make sure your input data covers the whole sampling domain "
934                     "and is correctly formatted. \n\n%s",
935                     dataFilename.c_str(),
936                     correctFormatMessage.c_str());
937             GMX_THROW(InvalidInputError(mesg));
938         }
939         (*gridpointToDatapoint)[m] = getNearestIndexInGrid(grid.point(m).coordValue, axis_);
940     }
941 }
942
943 } // namespace gmx