Generalize constraints on MPI rank counts for tests
[alexxy/gromacs.git] / src / testutils / include / testutils / mpitest.h
index 37d543937cbfb64c74fcf46f14728e3dc59edda8..04b041b2b8220dc0b712bd6e1b39e95c86e83dd9 100644 (file)
 #include "config.h"
 
 #include <functional>
+#include <string>
 #include <type_traits>
 
-#include "gromacs/utility/basenetwork.h"
-
 namespace gmx
 {
 namespace test
@@ -64,6 +63,7 @@ namespace test
  * \ingroup module_testutils
  */
 int getNumberOfTestMpiRanks();
+
 //! \cond internal
 /*! \brief
  * Helper function for GMX_MPI_TEST().
@@ -73,58 +73,114 @@ int getNumberOfTestMpiRanks();
 bool threadMpiTestRunner(std::function<void()> testBody);
 //! \endcond
 
-/*! \brief
- * Declares that this test is an MPI-enabled unit test.
+/*! \brief Implementation of MPI test runner for thread-MPI
  *
- * \param[in] expectedRankCount Expected number of ranks for this test.
- *     The test will fail if run with unsupported number of ranks.
+ * See documentation GMX_MPI_TEST */
+#if GMX_THREAD_MPI
+#    define GMX_MPI_TEST_INNER                                                                  \
+        do                                                                                      \
+        {                                                                                       \
+            using MyTestClass = std::remove_reference_t<decltype(*this)>;                       \
+            if (!::gmx::test::threadMpiTestRunner([this]() { this->MyTestClass::TestBody(); })) \
+            {                                                                                   \
+                return;                                                                         \
+            }                                                                                   \
+        } while (0)
+#else
+#    define GMX_MPI_TEST_INNER
+#endif
+
+/*! \brief Declares that this test is an MPI-enabled unit test and
+ * expresses the conditions under which it can run.
  *
  * To write unit tests that run under MPI, you need to do a few things:
- *  - Put GMX_MPI_TEST() as the first statement in your test body and
- *    specify the number of ranks this test expects.
+ *  - Put GMX_MPI_TEST(RankRequirement) as the first statement in your
+ *    test body and either declare or use a suitable class as its
+ *    argument to express what requirements exist on the number of MPI
+ *    ranks for this test.
  *  - Declare your unit test in CMake with gmx_add_mpi_unit_test().
- *    Note that all tests in the binary should fulfill the conditions above,
- *    and work with the same number of ranks.
- * TODO: Figure out a mechanism for mixing tests with different rank counts in
- * the same binary (possibly, also MPI and non-MPI tests).
+ *    Note that all tests in the binary should fulfill the conditions above.
  *
  * When you do the above, the following will happen:
  *  - The test will get compiled only if thread-MPI or real MPI is enabled.
- *  - The test will get executed on the number of ranks specified.
- *    If you are using real MPI, the whole test binary is run under MPI and
- *    test execution across the processes is synchronized (GMX_MPI_TEST()
- *    actually has no effect in this case, the synchronization is handled at a
- *    higher level).
- *    If you are using thread-MPI, GMX_MPI_TEST() is required and it
- *    initializes thread-MPI with the specified number of threads and runs the
- *    rest of the test on each of the threads.
+ *  - The test will get executed only when the specified condition on
+ *    the the number of ranks is satisfied.
+ *  - If you are using real MPI, the whole test binary is run under
+ *    MPI and test execution across the processes is synchronized
+ *    (GMX_MPI_TEST() actually has no effect in this case, the
+ *    synchronization is handled at a higher level).
+ *  - If you are using thread-MPI, GMX_MPI_TEST() is required and it
+ *    initializes thread-MPI with the specified number of threads and
+ *    runs the rest of the test on each of the threads.
+ *
+ * \param[in] RankRequirement Class that expresses the necessary
+ *     conditions on the number of MPI ranks for the test to continue.
+ *     If run with unsupported number of ranks, the remainder of the
+ *     test body is skipped, and the GTEST_SKIP() mechanism used to
+ *     report the reason why the number of MPI ranks is unsuitable.
+ *
+ * The RankRequirement class must have two static members; a static
+ * method \c bool conditionSatisfied(const int) that can be passed the
+ * number of ranks present at run time and return whether the test can
+ * run with that number of ranks, and a static const string \c
+ * s_skipReason describing the reason why the test cannot be run, when
+ * that is the case.
  *
  * You need to be extra careful for variables in the test fixture, if you use
  * one: when run under thread-MPI, these will be shared across all the ranks,
  * while under real MPI, these are naturally different for each process.
  * Local variables in the test body are private to each rank in both cases.
  *
- * Currently, it is not possible to specify the number of ranks as one, because
- * that will lead to problems with (at least) thread-MPI, but such tests can be
- * written as serial tests anyways.
+ * Currently, it is not possible to require the use of a single MPI
+ * rank, because that will lead to problems with (at least)
+ * thread-MPI, but such tests can be written as serial tests anyway.
  *
  * \ingroup module_testutils
  */
-#if GMX_THREAD_MPI
-#    define GMX_MPI_TEST(expectedRankCount)                                                     \
-        do                                                                                      \
-        {                                                                                       \
-            ASSERT_EQ(expectedRankCount, ::gmx::test::getNumberOfTestMpiRanks());               \
-            using MyTestClass = std::remove_reference_t<decltype(*this)>;                       \
-            if (!::gmx::test::threadMpiTestRunner([this]() { this->MyTestClass::TestBody(); })) \
-            {                                                                                   \
-                return;                                                                         \
-            }                                                                                   \
-        } while (0)
-#else
-#    define GMX_MPI_TEST(expectedRankCount) \
-        ASSERT_EQ(expectedRankCount, ::gmx::test::getNumberOfTestMpiRanks())
-#endif
+#define GMX_MPI_TEST(RankRequirement)                                                         \
+    const int numRanks = ::gmx::test::getNumberOfTestMpiRanks();                              \
+    if (!RankRequirement::conditionSatisfied(numRanks))                                       \
+    {                                                                                         \
+        GTEST_SKIP() << std::string("Test skipped because ") + RankRequirement::s_skipReason; \
+        return;                                                                               \
+    }                                                                                         \
+    GMX_MPI_TEST_INNER;
+
+//! Helper for GMX_MPI_TEST to permit any rank count
+class AllowAnyRankCount
+{
+public:
+    /*! \brief Function called by GMX_MPI_CONDITIONAL_TEST to see
+     * whether the test conditions are satisifed */
+    static bool conditionSatisfied(const int /* numRanks */) { return true; }
+    //! Reason to echo when skipping the test
+    inline static const char* s_skipReason = "UNUSED - any rank count satisfies";
+};
+
+//! Helper for GMX_MPI_TEST to permit only a specific rank count
+template<int requiredNumRanks>
+class RequireRankCount
+{
+public:
+    //! Function to require a specific number of ranks
+    static bool conditionSatisfied(const int numRanks) { return numRanks == requiredNumRanks; }
+    //! Text to echo when skipping a test that does not satisfy the requirement
+    inline static const std::string s_skipReason =
+            std::to_string(requiredNumRanks) + " ranks are required";
+};
+
+//! Helper for GMX_MPI_TEST to permit only a specific rank count
+template<int minimumNumRanks>
+class RequireMinimumRankCount
+{
+public:
+    //! Function to require at least the minimum number of ranks
+    static bool conditionSatisfied(const int numRanks) { return numRanks >= minimumNumRanks; }
+    //! Text to echo when skipping a test that does not satisfy the requirement
+    inline static const std::string s_skipReason =
+            std::to_string(minimumNumRanks) + " or more ranks are required";
+};
+
 
 } // namespace test
 } // namespace gmx