class MockThreadAffinityAccess : public IThreadAffinityAccess
{
- public:
- MockThreadAffinityAccess();
- ~MockThreadAffinityAccess() override;
+public:
+ MockThreadAffinityAccess();
+ ~MockThreadAffinityAccess() override;
- void setSupported(bool supported) { supported_ = supported; }
+ void setSupported(bool supported) { supported_ = supported; }
- bool isThreadAffinitySupported() const override { return supported_; }
- MOCK_METHOD1(setCurrentThreadAffinityToCore, bool(int core));
+ bool isThreadAffinitySupported() const override { return supported_; }
+ MOCK_METHOD1(setCurrentThreadAffinityToCore, bool(int core));
- private:
- bool supported_;
+private:
+ bool supported_;
};
class ThreadAffinityTestHelper
{
- public:
- ThreadAffinityTestHelper();
- ~ThreadAffinityTestHelper();
-
- void setAffinitySupported(bool supported)
- {
- affinityAccess_.setSupported(supported);
- }
- void setAffinityOption(ThreadAffinity affinityOption)
- {
- hwOpt_.threadAffinity = affinityOption;
- }
- void setOffsetAndStride(int offset, int stride)
- {
- hwOpt_.core_pinning_offset = offset;
- hwOpt_.core_pinning_stride = stride;
- }
-
- void setPhysicalNodeId(int nodeId)
- {
- physicalNodeId_ = nodeId;
- }
-
- void setLogicalProcessorCount(int logicalProcessorCount);
-
- void setTotNumThreadsIsAuto(bool isAuto)
- {
- hwOpt_.totNumThreadsIsAuto = isAuto;
- }
-
- void expectAffinitySet(int core)
- {
- EXPECT_CALL(affinityAccess_, setCurrentThreadAffinityToCore(core));
- }
- void expectAffinitySet(std::initializer_list<int> cores)
- {
- for (int core : cores)
- {
- expectAffinitySet(core);
- }
- }
- void expectAffinitySetThatFails(int core)
- {
- using ::testing::Return;
+public:
+ ThreadAffinityTestHelper();
+ ~ThreadAffinityTestHelper();
+
+ void setAffinitySupported(bool supported) { affinityAccess_.setSupported(supported); }
+ void setAffinityOption(ThreadAffinity affinityOption)
+ {
+ hwOpt_.threadAffinity = affinityOption;
+ }
+ void setOffsetAndStride(int offset, int stride)
+ {
+ hwOpt_.core_pinning_offset = offset;
+ hwOpt_.core_pinning_stride = stride;
+ }
+
+ void setPhysicalNodeId(int nodeId) { physicalNodeId_ = nodeId; }
+
+ void setLogicalProcessorCount(int logicalProcessorCount);
+
+ void setTotNumThreadsIsAuto(bool isAuto) { hwOpt_.totNumThreadsIsAuto = isAuto; }
+
+ void expectAffinitySet(int core)
+ {
+ EXPECT_CALL(affinityAccess_, setCurrentThreadAffinityToCore(core));
+ }
+ void expectAffinitySet(std::initializer_list<int> cores)
+ {
+ for (int core : cores)
+ {
+ expectAffinitySet(core);
+ }
+ }
+ void expectAffinitySetThatFails(int core)
+ {
+ using ::testing::Return;
#ifndef __clang_analyzer__
- EXPECT_CALL(affinityAccess_, setCurrentThreadAffinityToCore(core))
- .WillOnce(Return(false));
+ EXPECT_CALL(affinityAccess_, setCurrentThreadAffinityToCore(core)).WillOnce(Return(false));
#else
- GMX_UNUSED_VALUE(core);
+ GMX_UNUSED_VALUE(core);
#endif
- }
-
- void expectWarningMatchingRegex(const char *re)
- {
- expectWarningMatchingRegexIf(re, true);
- }
- void expectWarningMatchingRegexIf(const char *re, bool condition)
- {
- expectLogMessageMatchingRegexIf(MDLogger::LogLevel::Warning, re, condition);
- }
- void expectInfoMatchingRegex(const char *re)
- {
- expectInfoMatchingRegexIf(re, true);
- }
- void expectInfoMatchingRegexIf(const char *re, bool condition)
- {
- expectLogMessageMatchingRegexIf(MDLogger::LogLevel::Info, re, condition);
- }
- void expectGenericFailureMessage()
- {
- expectGenericFailureMessageIf(true);
- }
- void expectGenericFailureMessageIf(bool condition)
- {
- expectWarningMatchingRegexIf("NOTE: Thread affinity was not set.", condition);
- }
- void expectPinningMessage(bool userSpecifiedStride, int stride)
- {
- std::string pattern = formatString("Pinning threads .* %s.* stride of %d",
- userSpecifiedStride ? "user" : "auto",
- stride);
- expectInfoMatchingRegex(pattern.c_str());
- }
- void expectLogMessageMatchingRegexIf(MDLogger::LogLevel level,
- const char *re, bool condition)
- {
- if (condition)
- {
- logHelper_.expectEntryMatchingRegex(level, re);
- }
- }
-
- void setAffinity(int numThreadsOnThisRank)
- {
- if (hwTop_ == nullptr)
- {
- setLogicalProcessorCount(1);
- }
- gmx::PhysicalNodeCommunicator comm(MPI_COMM_WORLD, physicalNodeId_);
- int numThreadsOnThisNode, indexWithinNodeOfFirstThreadOnThisRank;
- analyzeThreadsOnThisNode(comm,
- numThreadsOnThisRank,
- &numThreadsOnThisNode,
- &indexWithinNodeOfFirstThreadOnThisRank);
- gmx_set_thread_affinity(logHelper_.logger(), cr_, &hwOpt_, *hwTop_,
- numThreadsOnThisRank, numThreadsOnThisNode,
- indexWithinNodeOfFirstThreadOnThisRank, &affinityAccess_);
- }
-
- private:
- t_commrec *cr_;
- gmx_hw_opt_t hwOpt_;
- std::unique_ptr<HardwareTopology> hwTop_;
- MockThreadAffinityAccess affinityAccess_;
- LoggerTestHelper logHelper_;
- int physicalNodeId_;
+ }
+
+ void expectWarningMatchingRegex(const char* re) { expectWarningMatchingRegexIf(re, true); }
+ void expectWarningMatchingRegexIf(const char* re, bool condition)
+ {
+ expectLogMessageMatchingRegexIf(MDLogger::LogLevel::Warning, re, condition);
+ }
+ void expectInfoMatchingRegex(const char* re) { expectInfoMatchingRegexIf(re, true); }
+ void expectInfoMatchingRegexIf(const char* re, bool condition)
+ {
+ expectLogMessageMatchingRegexIf(MDLogger::LogLevel::Info, re, condition);
+ }
+ void expectGenericFailureMessage() { expectGenericFailureMessageIf(true); }
+ void expectGenericFailureMessageIf(bool condition)
+ {
+ expectWarningMatchingRegexIf("NOTE: Thread affinity was not set.", condition);
+ }
+ void expectPinningMessage(bool userSpecifiedStride, int stride)
+ {
+ std::string pattern = formatString("Pinning threads .* %s.* stride of %d",
+ userSpecifiedStride ? "user" : "auto", stride);
+ expectInfoMatchingRegex(pattern.c_str());
+ }
+ void expectLogMessageMatchingRegexIf(MDLogger::LogLevel level, const char* re, bool condition)
+ {
+ if (condition)
+ {
+ logHelper_.expectEntryMatchingRegex(level, re);
+ }
+ }
+
+ void setAffinity(int numThreadsOnThisRank)
+ {
+ if (hwTop_ == nullptr)
+ {
+ setLogicalProcessorCount(1);
+ }
+ gmx::PhysicalNodeCommunicator comm(MPI_COMM_WORLD, physicalNodeId_);
+ int numThreadsOnThisNode, indexWithinNodeOfFirstThreadOnThisRank;
+ analyzeThreadsOnThisNode(comm, numThreadsOnThisRank, &numThreadsOnThisNode,
+ &indexWithinNodeOfFirstThreadOnThisRank);
+ gmx_set_thread_affinity(logHelper_.logger(), cr_, &hwOpt_, *hwTop_, numThreadsOnThisRank,
+ numThreadsOnThisNode, indexWithinNodeOfFirstThreadOnThisRank,
+ &affinityAccess_);
+ }
+
+private:
+ t_commrec* cr_;
+ gmx_hw_opt_t hwOpt_;
+ std::unique_ptr<HardwareTopology> hwTop_;
+ MockThreadAffinityAccess affinityAccess_;
+ LoggerTestHelper logHelper_;
+ int physicalNodeId_;
};
} // namespace test