Replace EnumOption with EnumerationArrayOption
[alexxy/gromacs.git] / src / gromacs / trajectoryanalysis / modules / pairdist.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2014,2015,2016,2018,2019,2020, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  * \brief
37  * Implements gmx::analysismodules::PairDistance.
38  *
39  * \author Teemu Murtola <teemu.murtola@gmail.com>
40  * \ingroup module_trajectoryanalysis
41  */
42 #include "gmxpre.h"
43
44 #include "pairdist.h"
45
46 #include <cmath>
47
48 #include <algorithm>
49 #include <limits>
50 #include <string>
51 #include <vector>
52
53 #include "gromacs/analysisdata/analysisdata.h"
54 #include "gromacs/analysisdata/modules/plot.h"
55 #include "gromacs/options/basicoptions.h"
56 #include "gromacs/options/filenameoption.h"
57 #include "gromacs/options/ioptionscontainer.h"
58 #include "gromacs/selection/nbsearch.h"
59 #include "gromacs/selection/selection.h"
60 #include "gromacs/selection/selectionoption.h"
61 #include "gromacs/trajectory/trajectoryframe.h"
62 #include "gromacs/trajectoryanalysis/analysissettings.h"
63 #include "gromacs/trajectoryanalysis/topologyinformation.h"
64 #include "gromacs/utility/arrayref.h"
65 #include "gromacs/utility/exceptions.h"
66 #include "gromacs/utility/stringutil.h"
67
68 namespace gmx
69 {
70
71 namespace analysismodules
72 {
73
74 namespace
75 {
76
77 //! \addtogroup module_trajectoryanalysis
78 //! \{
79
80 //! Enum value to store the selected value for `-type`.
81 enum class DistanceType : int
82 {
83     Min,
84     Max,
85     Count
86 };
87
88 //! Enum value to store the selected value for `-refgrouping`/`-selgrouping`.
89 enum class GroupType : int
90 {
91     All,
92     Residue,
93     Molecule,
94     None,
95     Count
96 };
97
98 //! Strings corresponding to DistanceType.
99 const EnumerationArray<DistanceType, const char*> c_distanceTypeNames = { { "min", "max" } };
100 //! Strings corresponding to GroupType.
101 const EnumerationArray<GroupType, const char*> c_groupTypeNames = { { "all", "res", "mol",
102                                                                       "none" } };
103
104 /*! \brief
105  * Implements `gmx pairdist` trajectory analysis module.
106  */
107 class PairDistance : public TrajectoryAnalysisModule
108 {
109 public:
110     PairDistance();
111
112     void initOptions(IOptionsContainer* options, TrajectoryAnalysisSettings* settings) override;
113     void initAnalysis(const TrajectoryAnalysisSettings& settings, const TopologyInformation& top) override;
114
115     TrajectoryAnalysisModuleDataPointer startFrames(const AnalysisDataParallelOptions& opt,
116                                                     const SelectionCollection& selections) override;
117     void                                analyzeFrame(int frnr, const t_trxframe& fr, t_pbc* pbc, TrajectoryAnalysisModuleData* pdata) override;
118
119     void finishAnalysis(int nframes) override;
120     void writeOutput() override;
121
122 private:
123     /*! \brief
124      * Computed distances as a function of time.
125      *
126      * There is one data set for each selection in `sel_`.
127      * Within each data set, there is one column for each distance to be
128      * computed, as explained in the `-h` text.
129      */
130     AnalysisData distances_;
131
132     /*! \brief
133      * Reference selection to compute distances to.
134      *
135      * mappedId() identifies the group (of type `refGroupType_`) into which
136      * each position belogs.
137      */
138     Selection refSel_;
139     /*! \brief
140      * Selections to compute distances from.
141      *
142      * mappedId() identifies the group (of type `selGroupType_`) into which
143      * each position belogs.
144      */
145     SelectionList sel_;
146
147     std::string fnDist_;
148
149     double       cutoff_;
150     DistanceType distanceType_;
151     GroupType    refGroupType_;
152     GroupType    selGroupType_;
153
154     //! Number of groups in `refSel_`.
155     int refGroupCount_;
156     //! Maximum number of pairs of groups for one selection.
157     int maxGroupCount_;
158     //! Initial squared distance for distance accumulation.
159     real initialDist2_;
160     //! Cutoff squared for use in the actual calculation.
161     real cutoff2_;
162
163     //! Neighborhood search object for the pair search.
164     AnalysisNeighborhood nb_;
165
166     // Copy and assign disallowed by base.
167 };
168
169 PairDistance::PairDistance() :
170     cutoff_(0.0),
171     distanceType_(DistanceType::Min),
172     refGroupType_(GroupType::All),
173     selGroupType_(GroupType::All),
174     refGroupCount_(0),
175     maxGroupCount_(0),
176     initialDist2_(0.0),
177     cutoff2_(0.0)
178 {
179     registerAnalysisDataset(&distances_, "dist");
180 }
181
182
183 void PairDistance::initOptions(IOptionsContainer* options, TrajectoryAnalysisSettings* settings)
184 {
185     static const char* const desc[] = {
186         "[THISMODULE] calculates pairwise distances between one reference",
187         "selection (given with [TT]-ref[tt]) and one or more other selections",
188         "(given with [TT]-sel[tt]).  It can calculate either the minimum",
189         "distance (the default), or the maximum distance (with",
190         "[TT]-type max[tt]).  Distances to each selection provided with",
191         "[TT]-sel[tt] are computed independently.[PAR]",
192         "By default, the global minimum/maximum distance is computed.",
193         "To compute more distances (e.g., minimum distances to each residue",
194         "in [TT]-ref[tt]), use [TT]-refgrouping[tt] and/or [TT]-selgrouping[tt]",
195         "to specify how the positions within each selection should be",
196         "grouped.[PAR]",
197         "Computed distances are written to the file specified with [TT]-o[tt].",
198         "If there are N groups in [TT]-ref[tt] and M groups in the first",
199         "selection in [TT]-sel[tt], then the output contains N*M columns",
200         "for the first selection. The columns contain distances like this:",
201         "r1-s1, r2-s1, ..., r1-s2, r2-s2, ..., where rn is the n'th group",
202         "in [TT]-ref[tt] and sn is the n'th group in the other selection.",
203         "The distances for the second selection comes as separate columns",
204         "after the first selection, and so on.  If some selections are",
205         "dynamic, only the selected positions are used in the computation",
206         "but the same number of columns is always written out.  If there",
207         "are no positions contributing to some group pair, then the cutoff",
208         "value is written (see below).[PAR]",
209         "[TT]-cutoff[tt] sets a cutoff for the computed distances.",
210         "If the result would contain a distance over the cutoff, the cutoff",
211         "value is written to the output file instead. By default, no cutoff",
212         "is used, but if you are not interested in values beyond a cutoff,",
213         "or if you know that the minimum distance is smaller than a cutoff,",
214         "you should set this option to allow the tool to use grid-based",
215         "searching and be significantly faster.[PAR]",
216         "If you want to compute distances between fixed pairs,",
217         "[gmx-distance] may be a more suitable tool."
218     };
219
220     settings->setHelpText(desc);
221
222     options->addOption(FileNameOption("o")
223                                .filetype(eftPlot)
224                                .outputFile()
225                                .required()
226                                .store(&fnDist_)
227                                .defaultBasename("dist")
228                                .description("Distances as function of time"));
229
230     options->addOption(
231             DoubleOption("cutoff").store(&cutoff_).description("Maximum distance to consider"));
232     options->addOption(EnumOption<DistanceType>("type")
233                                .store(&distanceType_)
234                                .enumValue(c_distanceTypeNames)
235                                .description("Type of distances to calculate"));
236     options->addOption(
237             EnumOption<GroupType>("refgrouping")
238                     .store(&refGroupType_)
239                     .enumValue(c_groupTypeNames)
240                     .description("Grouping of -ref positions to compute the min/max over"));
241     options->addOption(
242             EnumOption<GroupType>("selgrouping")
243                     .store(&selGroupType_)
244                     .enumValue(c_groupTypeNames)
245                     .description("Grouping of -sel positions to compute the min/max over"));
246
247     options->addOption(SelectionOption("ref").store(&refSel_).required().description(
248             "Reference positions to calculate distances from"));
249     options->addOption(SelectionOption("sel").storeVector(&sel_).required().multiValue().description(
250             "Positions to calculate distances for"));
251 }
252
253 //! Helper function to initialize the grouping for a selection.
254 int initSelectionGroups(Selection* sel, const gmx_mtop_t* top, GroupType type)
255 {
256     e_index_t indexType = INDEX_UNKNOWN;
257     switch (type)
258     {
259         case GroupType::All: indexType = INDEX_ALL; break;
260         case GroupType::Residue: indexType = INDEX_RES; break;
261         case GroupType::Molecule: indexType = INDEX_MOL; break;
262         case GroupType::None: indexType = INDEX_ATOM; break;
263         case GroupType::Count: GMX_THROW(InternalError("Invalid GroupType"));
264     }
265     return sel->initOriginalIdsToGroup(top, indexType);
266 }
267
268
269 void PairDistance::initAnalysis(const TrajectoryAnalysisSettings& settings, const TopologyInformation& top)
270 {
271     refGroupCount_ = initSelectionGroups(&refSel_, top.mtop(), refGroupType_);
272
273     maxGroupCount_ = 0;
274     distances_.setDataSetCount(sel_.size());
275     for (size_t i = 0; i < sel_.size(); ++i)
276     {
277         const int selGroupCount = initSelectionGroups(&sel_[i], top.mtop(), selGroupType_);
278         const int columnCount   = refGroupCount_ * selGroupCount;
279         maxGroupCount_          = std::max(maxGroupCount_, columnCount);
280         distances_.setColumnCount(i, columnCount);
281     }
282
283     if (!fnDist_.empty())
284     {
285         AnalysisDataPlotModulePointer plotm(new AnalysisDataPlotModule(settings.plotSettings()));
286         plotm->setFileName(fnDist_);
287         if (distanceType_ == DistanceType::Max)
288         {
289             plotm->setTitle("Maximum distance");
290         }
291         else
292         {
293             plotm->setTitle("Minimum distance");
294         }
295         // TODO: Figure out and add a descriptive subtitle and/or a longer
296         // title and/or better legends based on the grouping and the reference
297         // selection.
298         plotm->setXAxisIsTime();
299         plotm->setYLabel("Distance (nm)");
300         for (size_t g = 0; g < sel_.size(); ++g)
301         {
302             plotm->appendLegend(sel_[g].name());
303         }
304         distances_.addModule(plotm);
305     }
306
307     nb_.setCutoff(cutoff_);
308     if (cutoff_ <= 0.0)
309     {
310         cutoff_       = 0.0;
311         initialDist2_ = std::numeric_limits<real>::max();
312     }
313     else
314     {
315         initialDist2_ = cutoff_ * cutoff_;
316     }
317     if (distanceType_ == DistanceType::Max)
318     {
319         initialDist2_ = 0.0;
320     }
321     cutoff2_ = cutoff_ * cutoff_;
322 }
323
324 /*! \brief
325  * Temporary memory for use within a single-frame calculation.
326  */
327 class PairDistanceModuleData : public TrajectoryAnalysisModuleData
328 {
329 public:
330     /*! \brief
331      * Reserves memory for the frame-local data.
332      */
333     PairDistanceModuleData(TrajectoryAnalysisModule*          module,
334                            const AnalysisDataParallelOptions& opt,
335                            const SelectionCollection&         selections,
336                            int                                refGroupCount,
337                            const Selection&                   refSel,
338                            int                                maxGroupCount) :
339         TrajectoryAnalysisModuleData(module, opt, selections)
340     {
341         distArray_.resize(maxGroupCount);
342         countArray_.resize(maxGroupCount);
343         refCountArray_.resize(refGroupCount);
344         if (!refSel.isDynamic())
345         {
346             initRefCountArray(refSel);
347         }
348     }
349
350     void finish() override { finishDataHandles(); }
351
352     /*! \brief
353      * Computes the number of positions in each group in \p refSel
354      * and stores them into `refCountArray_`.
355      */
356     void initRefCountArray(const Selection& refSel)
357     {
358         std::fill(refCountArray_.begin(), refCountArray_.end(), 0);
359         int refPos = 0;
360         while (refPos < refSel.posCount())
361         {
362             const int refIndex = refSel.position(refPos).mappedId();
363             const int startPos = refPos;
364             ++refPos;
365             while (refPos < refSel.posCount() && refSel.position(refPos).mappedId() == refIndex)
366             {
367                 ++refPos;
368             }
369             refCountArray_[refIndex] = refPos - startPos;
370         }
371     }
372
373     /*! \brief
374      * Squared distance between each group
375      *
376      * One entry for each group pair for the current selection.
377      * Enough memory is allocated to fit the largest calculation selection.
378      * This is needed to support neighborhood searching, which may not
379      * return the pairs in order: for each group pair, we need to search
380      * through all the position pairs and update this array to find the
381      * minimum/maximum distance between them.
382      */
383     std::vector<real> distArray_;
384     /*! \brief
385      * Number of pairs within the cutoff that have contributed to the value
386      * in `distArray_`.
387      *
388      * This is needed to identify whether there were any pairs inside the
389      * cutoff and whether there were additional pairs outside the cutoff
390      * that were not covered by the neihborhood search.
391      */
392     std::vector<int> countArray_;
393     /*! \brief
394      * Number of positions within each reference group.
395      *
396      * This is used to more efficiently compute the total number of pairs
397      * (for comparison with `countArray_`), as otherwise these numbers
398      * would need to be recomputed for each selection.
399      */
400     std::vector<int> refCountArray_;
401 };
402
403 TrajectoryAnalysisModuleDataPointer PairDistance::startFrames(const AnalysisDataParallelOptions& opt,
404                                                               const SelectionCollection& selections)
405 {
406     return TrajectoryAnalysisModuleDataPointer(new PairDistanceModuleData(
407             this, opt, selections, refGroupCount_, refSel_, maxGroupCount_));
408 }
409
410 void PairDistance::analyzeFrame(int frnr, const t_trxframe& fr, t_pbc* pbc, TrajectoryAnalysisModuleData* pdata)
411 {
412     AnalysisDataHandle      dh         = pdata->dataHandle(distances_);
413     const Selection&        refSel     = pdata->parallelSelection(refSel_);
414     const SelectionList&    sel        = pdata->parallelSelections(sel_);
415     PairDistanceModuleData& frameData  = *static_cast<PairDistanceModuleData*>(pdata);
416     std::vector<real>&      distArray  = frameData.distArray_;
417     std::vector<int>&       countArray = frameData.countArray_;
418
419     if (cutoff_ > 0.0 && refSel.isDynamic())
420     {
421         // Count the number of reference positions in each group, so that
422         // this does not need to be computed again for each selection.
423         // This is needed only if it is possible that the neighborhood search
424         // does not cover all the pairs, hence the cutoff > 0.0 check.
425         // If refSel is static, then the array contents are static as well,
426         // and it has been initialized in the constructor of the data object.
427         frameData.initRefCountArray(refSel);
428     }
429     const std::vector<int>& refCountArray = frameData.refCountArray_;
430
431     AnalysisNeighborhoodSearch nbsearch = nb_.initSearch(pbc, refSel);
432     dh.startFrame(frnr, fr.time);
433     for (size_t g = 0; g < sel.size(); ++g)
434     {
435         const int columnCount = distances_.columnCount(g);
436         std::fill(distArray.begin(), distArray.begin() + columnCount, initialDist2_);
437         std::fill(countArray.begin(), countArray.begin() + columnCount, 0);
438
439         // Accumulate the number of position pairs within the cutoff and the
440         // min/max distance for each group pair.
441         AnalysisNeighborhoodPairSearch pairSearch = nbsearch.startPairSearch(sel[g]);
442         AnalysisNeighborhoodPair       pair;
443         while (pairSearch.findNextPair(&pair))
444         {
445             const SelectionPosition& refPos   = refSel.position(pair.refIndex());
446             const SelectionPosition& selPos   = sel[g].position(pair.testIndex());
447             const int                refIndex = refPos.mappedId();
448             const int                selIndex = selPos.mappedId();
449             const int                index    = selIndex * refGroupCount_ + refIndex;
450             const real               r2       = pair.distance2();
451             if (distanceType_ == DistanceType::Min)
452             {
453                 if (distArray[index] > r2)
454                 {
455                     distArray[index] = r2;
456                 }
457             }
458             else
459             {
460                 if (distArray[index] < r2)
461                 {
462                     distArray[index] = r2;
463                 }
464             }
465             ++countArray[index];
466         }
467
468         // If it is possible that positions outside the cutoff (or lack of
469         // them) affects the result, then we need to check whether there were
470         // any.  This is necessary for two cases:
471         //  - With max distances, if there are pairs outside the cutoff, then
472         //    the computed distance should be equal to the cutoff instead of
473         //    the largest distance that was found above.
474         //  - With either distance type, if all pairs are outside the cutoff,
475         //    then countArray must be updated so that the presence flag
476         //    in the output data reflects the dynamic selection status, not
477         //    whether something was inside the cutoff or not.
478         if (cutoff_ > 0.0)
479         {
480             int selPos = 0;
481             // Loop over groups in this selection (at start, selPos is always
482             // the first position in the next group).
483             while (selPos < sel[g].posCount())
484             {
485                 // Count the number of positions in this group.
486                 const int selIndex = sel[g].position(selPos).mappedId();
487                 const int startPos = selPos;
488                 ++selPos;
489                 while (selPos < sel[g].posCount() && sel[g].position(selPos).mappedId() == selIndex)
490                 {
491                     ++selPos;
492                 }
493                 const int count = selPos - startPos;
494                 // Check all group pairs that contain this group.
495                 for (int i = 0; i < refGroupCount_; ++i)
496                 {
497                     const int index      = selIndex * refGroupCount_ + i;
498                     const int totalCount = refCountArray[i] * count;
499                     // If there were positions outside the cutoff,
500                     // update the distance if necessary and the count.
501                     if (countArray[index] < totalCount)
502                     {
503                         if (distanceType_ == DistanceType::Max)
504                         {
505                             distArray[index] = cutoff2_;
506                         }
507                         countArray[index] = totalCount;
508                     }
509                 }
510             }
511         }
512
513         // Write the computed distances to the output data.
514         dh.selectDataSet(g);
515         for (int i = 0; i < columnCount; ++i)
516         {
517             if (countArray[i] > 0)
518             {
519                 dh.setPoint(i, std::sqrt(distArray[i]));
520             }
521             else
522             {
523                 // If there are no contributing positions, write out the cutoff
524                 // value.
525                 dh.setPoint(i, cutoff_, false);
526             }
527         }
528     }
529     dh.finishFrame();
530 }
531
532 void PairDistance::finishAnalysis(int /*nframes*/) {}
533
534 void PairDistance::writeOutput() {}
535
536 //! \}
537
538 } // namespace
539
540 const char PairDistanceInfo::name[] = "pairdist";
541 const char PairDistanceInfo::shortDescription[] =
542         "Calculate pairwise distances between groups of positions";
543
544 TrajectoryAnalysisModulePointer PairDistanceInfo::create()
545 {
546     return TrajectoryAnalysisModulePointer(new PairDistance);
547 }
548
549 } // namespace analysismodules
550
551 } // namespace gmx