Add mdspan basic elementwise math
authorKevin Boyd <kevin.boyd@uconn.edu>
Tue, 18 Jun 2019 01:27:30 +0000 (21:27 -0400)
committerPaul Bauer <paul.bauer.q@gmail.com>
Tue, 2 Jul 2019 07:14:18 +0000 (09:14 +0200)
BasicMatrix3x3 provides a replacement for the c-style tensors, but
has yet to replicate the tensor operations in math/vec.h. This provides
a subset of that functionality for msdpan in general, which can be used
for MultiDimArray trivially

Refs #2976

Change-Id: I17b77df032dbbfde0ff87108215edcec07fef6c4

src/gromacs/mdspan/extensions.h
src/gromacs/mdspan/tests/extensions.cpp

index efa564429a53ef7571bca9d4d7d23ef98e4c902d..20163f495413aa296bea214e1f2b562714675f04 100644 (file)
@@ -44,6 +44,9 @@
 #ifndef GMX_MDSPAN_EXTENSIONS_H_
 #define GMX_MDSPAN_EXTENSIONS_H_
 
+#include <algorithm>
+#include <functional>
+
 #include "gromacs/mdspan/mdspan.h"
 
 namespace gmx
@@ -80,6 +83,46 @@ end(const BasicMdspan &basicMdspan)
 //! Convenience type for often-used three dimensional extents
 using dynamicExtents3D = extents<dynamic_extent, dynamic_extent, dynamic_extent>;
 
+//! Elementwise addition
+template <class BasicMdspan>
+constexpr BasicMdspan addElementwise(const BasicMdspan &span1, const BasicMdspan &span2)
+{
+    BasicMdspan result(span1);
+    std::transform(begin(span1), end(span1), begin(span2),
+                   begin(result), std::plus<typename BasicMdspan::element_type>());
+    return result;
+}
+
+//! Elementwise subtraction - left minus right
+template <class BasicMdspan>
+constexpr BasicMdspan subtractElementwise(const BasicMdspan &span1, const BasicMdspan &span2)
+{
+    BasicMdspan result(span1);
+    std::transform(begin(span1), end(span1), begin(span2),
+                   begin(result), std::minus<typename BasicMdspan::element_type>());
+    return result;
+}
+
+//! Elementwise multiplication
+template <class BasicMdspan>
+constexpr BasicMdspan multiplyElementwise(const BasicMdspan &span1, const BasicMdspan &span2)
+{
+    BasicMdspan result(span1);
+    std::transform(begin(span1), end(span1), begin(span2),
+                   begin(result), std::multiplies<typename BasicMdspan::element_type>());
+    return result;
+}
+
+//! Elementwise division - left / right
+template <class BasicMdspan>
+constexpr BasicMdspan divideElementwise(const BasicMdspan &span1, const BasicMdspan &span2)
+{
+    BasicMdspan result(span1);
+    std::transform(begin(span1), end(span1), begin(span2),
+                   begin(result), std::divides<typename BasicMdspan::element_type>());
+    return result;
+}
+
 }      // namespace gmx
 
 #endif // GMX_MDSPAN_EXTENSIONS_H_
index 8b68b114e00982f34dbc7d403348a30ab24a64e6..009ea92dde2539280ed9c37b79cdd0999d3435d2 100644 (file)
@@ -126,4 +126,71 @@ TEST(MdSpanExtension, SlicingEqualsView3D)
     EXPECT_EQ(span[1][1][1], span(1, 1, 1));
 }
 
+TEST(MdSpanExtension, additionWorks)
+{
+    std::array<int, 2 * 2 * 2>           arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}};
+    std::array<int, 2 * 2 * 2>           arr2 = {{1, 1, 1, 1, 1, 1, 1, 1}};
+    basic_mdspan<int, extents<2, 2, 2> > span1 {
+        arr1.data()
+    };
+    basic_mdspan<int, extents<2, 2, 2> > span2 {
+        arr2.data()
+    };
+
+    auto              result = addElementwise(span1, span2);
+    EXPECT_EQ(result[0][0][1], -2);
+    EXPECT_EQ(result[0][1][0], -1);
+    EXPECT_EQ(result[1][0][0],  1);
+}
+
+TEST(MdSpanExtension, subtractionWorks)
+{
+    std::array<int, 2 * 2 * 2>           arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}};
+    std::array<int, 2 * 2 * 2>           arr2 = {{1, 1, 1, 1, 1, 1, 1, 1}};
+    basic_mdspan<int, extents<2, 2, 2> > span1 {
+        arr1.data()
+    };
+    basic_mdspan<int, extents<2, 2, 2> > span2 {
+        arr2.data()
+    };
+
+    auto              result = subtractElementwise(span1, span2);
+    EXPECT_EQ(result[0][0][1], -4);
+    EXPECT_EQ(result[0][1][0], -3);
+    EXPECT_EQ(result[1][0][0], -1);
+}
+
+TEST(MdSpanExtension, multiplicationWorks)
+{
+    std::array<int, 2 * 2 * 2>           arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}};
+    std::array<int, 2 * 2 * 2>           arr2 = {{2, 2, 2, 2, 2, 2, 2, }};
+    basic_mdspan<int, extents<2, 2, 2> > span1 {
+        arr1.data()
+    };
+    basic_mdspan<int, extents<2, 2, 2> > span2 {
+        arr2.data()
+    };
+
+    auto              result = multiplyElementwise(span1, span2);
+    EXPECT_EQ(result[0][0][1], -6);
+    EXPECT_EQ(result[0][1][0], -4);
+    EXPECT_EQ(result[1][0][0],  0);
+}
+
+TEST(MdSpanExtension, divisionWorks)
+{
+    std::array<float, 2 * 2 * 2>           arr1 = {{-4, -3, -2, -1, 0, 1, 2, 3}};
+    std::array<float, 2 * 2 * 2>           arr2 = {{2, 2, 2, 2, 2, 2, 2, 2, }};
+    basic_mdspan<float, extents<2, 2, 2> > span1 {
+        arr1.data()
+    };
+    basic_mdspan<float, extents<2, 2, 2> > span2 {
+        arr2.data()
+    };
+
+    auto              result = divideElementwise(span1, span2);
+    EXPECT_EQ(result[0][0][1], -1.5);
+    EXPECT_EQ(result[0][1][0], -1);
+    EXPECT_EQ(result[1][0][0],    0);
+}
 } // namespace gmx