From: Kevin Boyd Date: Tue, 18 Jun 2019 01:27:30 +0000 (-0400) Subject: Add mdspan basic elementwise math X-Git-Url: http://biod.pnpi.spb.ru/gitweb/?a=commitdiff_plain;h=f3c2e710808372b44e22d01bdafbc8d07ce42797;p=alexxy%2Fgromacs.git Add mdspan basic elementwise math BasicMatrix3x3 provides a replacement for the c-style tensors, but has yet to replicate the tensor operations in math/vec.h. This provides a subset of that functionality for msdpan in general, which can be used for MultiDimArray trivially Refs #2976 Change-Id: I17b77df032dbbfde0ff87108215edcec07fef6c4 --- diff --git a/src/gromacs/mdspan/extensions.h b/src/gromacs/mdspan/extensions.h index efa564429a..20163f4954 100644 --- a/src/gromacs/mdspan/extensions.h +++ b/src/gromacs/mdspan/extensions.h @@ -44,6 +44,9 @@ #ifndef GMX_MDSPAN_EXTENSIONS_H_ #define GMX_MDSPAN_EXTENSIONS_H_ +#include +#include + #include "gromacs/mdspan/mdspan.h" namespace gmx @@ -80,6 +83,46 @@ end(const BasicMdspan &basicMdspan) //! Convenience type for often-used three dimensional extents using dynamicExtents3D = extents; +//! Elementwise addition +template +constexpr BasicMdspan addElementwise(const BasicMdspan &span1, const BasicMdspan &span2) +{ + BasicMdspan result(span1); + std::transform(begin(span1), end(span1), begin(span2), + begin(result), std::plus()); + return result; +} + +//! Elementwise subtraction - left minus right +template +constexpr BasicMdspan subtractElementwise(const BasicMdspan &span1, const BasicMdspan &span2) +{ + BasicMdspan result(span1); + std::transform(begin(span1), end(span1), begin(span2), + begin(result), std::minus()); + return result; +} + +//! Elementwise multiplication +template +constexpr BasicMdspan multiplyElementwise(const BasicMdspan &span1, const BasicMdspan &span2) +{ + BasicMdspan result(span1); + std::transform(begin(span1), end(span1), begin(span2), + begin(result), std::multiplies()); + return result; +} + +//! Elementwise division - left / right +template +constexpr BasicMdspan divideElementwise(const BasicMdspan &span1, const BasicMdspan &span2) +{ + BasicMdspan result(span1); + std::transform(begin(span1), end(span1), begin(span2), + begin(result), std::divides()); + return result; +} + } // namespace gmx #endif // GMX_MDSPAN_EXTENSIONS_H_ diff --git a/src/gromacs/mdspan/tests/extensions.cpp b/src/gromacs/mdspan/tests/extensions.cpp index 8b68b114e0..009ea92dde 100644 --- a/src/gromacs/mdspan/tests/extensions.cpp +++ b/src/gromacs/mdspan/tests/extensions.cpp @@ -126,4 +126,71 @@ TEST(MdSpanExtension, SlicingEqualsView3D) EXPECT_EQ(span[1][1][1], span(1, 1, 1)); } +TEST(MdSpanExtension, additionWorks) +{ + std::array arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}}; + std::array arr2 = {{1, 1, 1, 1, 1, 1, 1, 1}}; + basic_mdspan > span1 { + arr1.data() + }; + basic_mdspan > span2 { + arr2.data() + }; + + auto result = addElementwise(span1, span2); + EXPECT_EQ(result[0][0][1], -2); + EXPECT_EQ(result[0][1][0], -1); + EXPECT_EQ(result[1][0][0], 1); +} + +TEST(MdSpanExtension, subtractionWorks) +{ + std::array arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}}; + std::array arr2 = {{1, 1, 1, 1, 1, 1, 1, 1}}; + basic_mdspan > span1 { + arr1.data() + }; + basic_mdspan > span2 { + arr2.data() + }; + + auto result = subtractElementwise(span1, span2); + EXPECT_EQ(result[0][0][1], -4); + EXPECT_EQ(result[0][1][0], -3); + EXPECT_EQ(result[1][0][0], -1); +} + +TEST(MdSpanExtension, multiplicationWorks) +{ + std::array arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}}; + std::array arr2 = {{2, 2, 2, 2, 2, 2, 2, }}; + basic_mdspan > span1 { + arr1.data() + }; + basic_mdspan > span2 { + arr2.data() + }; + + auto result = multiplyElementwise(span1, span2); + EXPECT_EQ(result[0][0][1], -6); + EXPECT_EQ(result[0][1][0], -4); + EXPECT_EQ(result[1][0][0], 0); +} + +TEST(MdSpanExtension, divisionWorks) +{ + std::array arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}}; + std::array arr2 = {{2, 2, 2, 2, 2, 2, 2, 2, }}; + basic_mdspan > span1 { + arr1.data() + }; + basic_mdspan > span2 { + arr2.data() + }; + + auto result = divideElementwise(span1, span2); + EXPECT_EQ(result[0][0][1], -1.5); + EXPECT_EQ(result[0][1][0], -1); + EXPECT_EQ(result[1][0][0], 0); +} } // namespace gmx