Add implicit conversion of mdspan-containing types to mdspan.
authorKevin Boyd <kevin.boyd@uconn.edu>
Fri, 5 Jul 2019 00:30:34 +0000 (20:30 -0400)
committerPaul Bauer <paul.bauer.q@gmail.com>
Wed, 10 Jul 2019 07:53:25 +0000 (09:53 +0200)
Analagous to arrayRef function args taking vectors, arrays as input

Change-Id: I701f6373592cdf4d30649041ffe8dee76423433b

src/gromacs/math/tests/multidimarray.cpp
src/gromacs/mdspan/mdspan.h

index f41b8645db2bdf3244537d1c71d60ec7762dae12..1a97ede00dfdd4cbd2381aa9b6d49bb32043903f 100644 (file)
@@ -314,8 +314,43 @@ TEST_F(MultiDimArrayTest, constViewConstEnd)
     EXPECT_EQ(*x, testNumber_);
 }
 
-} // namespace
+TEST(MultiDimArrayToMdSpanTest, convertsToMdSpan)
+{
+    MultiDimArray < std::array<int, 4>, extents < 2, 2>> arr = {{0, 1, 2, 3}};
+    basic_mdspan < int, extents < 2, 2>> span(arr);
+
+    // test copy correctness
+    EXPECT_EQ(span(1, 1), 3);
+
+    // test that span operates on same data as array
+    span(0, 1) = -5;
+    EXPECT_EQ(arr(0, 1), -5);
+}
+
+TEST(MultiDimArrayToMdSpanTest, constArrayToMdSpan)
+{
+    const MultiDimArray < std::array<int, 4>, extents < 2, 2>> constArr = {{0, 1, 2, 3}};
+    basic_mdspan < const int, extents < 2, 2>> span(constArr);
+    EXPECT_EQ(span(0, 1), 1);
+}
 
+TEST(MultiDimArrayToMdSpanTest, nonConstArrayToConstMdSpan)
+{
+    MultiDimArray < std::array<int, 4>, extents < 2, 2>> arr = {{0, 1, 2, 3}};
+    basic_mdspan < const int, extents < 2, 2>> span(arr);
+    EXPECT_EQ(span(0, 1), 1);
+}
+
+TEST(MultiDimArrayToMdSpanTest, implicitConversionToMdSpan)
+{
+    auto testFunc = [](basic_mdspan < const int, extents < 2, 2>> a){
+            return a(0, 0);
+        };
+    MultiDimArray < std::array<int, 4>, extents < 2, 2>> arr = {{0, 1, 2, 3}};
+    EXPECT_EQ(testFunc(arr), 0);
+}
+
+} // namespace
 } // namespace test
 
 } // namespace gmx
index 49fbbf6a59e38e3d19792d55ad602ff04bc2f7a7..88bd55f6c9d77750c84b5ab68099b851a58e10a3 100644 (file)
@@ -196,8 +196,37 @@ class basic_mdspan
          */
         constexpr basic_mdspan( pointer ptr, const mapping_type &m, const accessor_type &a ) noexcept
             : acc_(a), map_( m ), ptr_(ptr) {}
-
-        /*! \brief Brace operator to access multidimenisonal array element.
+        /*! \brief Construct mdspan from multidimensional arrays implemented with mdspan
+         *
+         * Requires the container to have a view_type describing the mdspan, which is
+         * accessible through an asView() call
+         *
+         *  This allows functions to declare mdspans as arguments, but take e.g. multidimensional
+         *  arrays implicitly during the function call
+         * \tparam U container type
+         * \param[in] other mdspan-implementing container
+         */
+        template<typename U,
+                 typename = typename std::enable_if<
+                         std::is_same<typename std::remove_reference<U>::type::view_type::element_type,
+                                      ElementType>::value>::type>
+        constexpr basic_mdspan(U &&other) : basic_mdspan(other.asView()) {}
+        /*! \brief Construct mdspan of const Elements from multidimensional arrays implemented with mdspan
+         *
+         * Requires the container to have a const_view_type describing the mdspan, which is
+         * accessible through an asConstView() call
+         *
+         *  This allows functions to declare mdspans as arguments, but take e.g. multidimensional
+         *  arrays implicitly during the function call
+         * \tparam U container type
+         * \param[in] other mdspan-implementing container
+         */
+        template<typename U,
+                 typename = typename std::enable_if<
+                         std::is_same<typename std::remove_reference<U>::type::const_view_type::element_type,
+                                      ElementType>::value>::type>
+        constexpr basic_mdspan(const U &other) : basic_mdspan(other.asConstView()) {}
+        /*! \brief Brace operator to access multidimensional array element.
          * \param[in] indices The multidimensional indices of the object.
          * Requires rank() == sizeof...(IndexType). Slicing is implemented via sub_span.
          * \returns reference to element at indices.