Simpify CheckpointData
authorRoland Schulz <roland.schulz@intel.com>
Wed, 9 Sep 2020 20:06:04 +0000 (13:06 -0700)
committerPaul Bauer <paul.bauer.q@gmail.com>
Thu, 10 Sep 2020 15:06:44 +0000 (15:06 +0000)
src/gromacs/mdtypes/checkpointdata.h

index 9bc226506b2cb7c203cb6ee3c682dca38b849191..b82da7cf1e72cdceebac689d78c7848ec316594d 100644 (file)
@@ -134,8 +134,16 @@ struct IsSerializableEnum<T, false>
  * objects, which interact with the checkpoint reading from / writing to
  * file.
  */
+
 template<CheckpointDataOperation operation>
-class CheckpointData
+class CheckpointData;
+
+// Shortcuts
+using ReadCheckpointData  = CheckpointData<CheckpointDataOperation::Read>;
+using WriteCheckpointData = CheckpointData<CheckpointDataOperation::Write>;
+
+template<>
+class CheckpointData<CheckpointDataOperation::Read>
 {
 public:
     /*! \brief Read or write a single value from / to checkpoint
@@ -151,20 +159,10 @@ public:
      * \param value       The value to [read|write]
      */
     //! {
-    // Read
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableType<T>::value, void>
-    scalar(const std::string& key, T* value) const;
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableEnum<T>::value, void>
-    enumScalar(const std::string& key, T* value) const;
-    // Write
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableType<T>::value, void>
-    scalar(const std::string& key, const T* value);
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableEnum<T>::value, void>
-    enumScalar(const std::string& key, const T* value);
+    template<typename T>
+    std::enable_if_t<IsSerializableType<T>::value, void> scalar(const std::string& key, T* value) const;
+    template<typename T>
+    std::enable_if_t<IsSerializableEnum<T>::value, void> enumScalar(const std::string& key, T* value) const;
     //! }
 
     /*! \brief Read or write an ArrayRef from / to checkpoint
@@ -180,21 +178,11 @@ public:
      */
     //! {
     // Read ArrayRef of scalar
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableType<T>::value, void>
-    arrayRef(const std::string& key, ArrayRef<T> values) const;
-    // Write ArrayRef of scalar
-    template<typename T, CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableType<T>::value, void>
-    arrayRef(const std::string& key, ArrayRef<const T> values);
+    template<typename T>
+    std::enable_if_t<IsSerializableType<T>::value, void> arrayRef(const std::string& key,
+                                                                  ArrayRef<T>        values) const;
     // Read ArrayRef of RVec
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read, void> arrayRef(const std::string& key,
-                                                                         ArrayRef<RVec> values) const;
-    // Write ArrayRef of RVec
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write, void> arrayRef(const std::string& key,
-                                                                          ArrayRef<const RVec> values);
+    void arrayRef(const std::string& key, ArrayRef<RVec> values) const;
     //! }
 
     /*! \brief Read or write a tensor from / to checkpoint
@@ -203,16 +191,7 @@ public:
      * \param key         The key to [read|write] the tensor [from|to]
      * \param values      The tensor to [read|write]
      */
-    //! {
-    // Read
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read, void> tensor(const std::string& key,
-                                                                       ::tensor values) const;
-    // Write
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write, void> tensor(const std::string& key,
-                                                                        const ::tensor     values);
-    //! }
+    void tensor(const std::string& key, ::tensor values) const;
 
     /*! \brief Return a subset of the current CheckpointData
      *
@@ -221,35 +200,59 @@ public:
      * \return            A CheckpointData object representing a subset of the current object
      */
     //!{
-    // Read
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Read, CheckpointData>
-    subCheckpointData(const std::string& key) const;
-    // Write
-    template<CheckpointDataOperation op = operation>
-    std::enable_if_t<op == CheckpointDataOperation::Write, CheckpointData>
-    subCheckpointData(const std::string& key);
+    CheckpointData subCheckpointData(const std::string& key) const;
     //!}
 
 private:
     //! KV tree read from checkpoint
     const KeyValueTreeObject* inputTree_ = nullptr;
-    //! Builder for the tree to be written to checkpoint
-    std::optional<KeyValueTreeObjectBuilder> outputTreeBuilder_ = std::nullopt;
 
     //! Construct an input checkpoint data object
     explicit CheckpointData(const KeyValueTreeObject& inputTree);
+
+    // Only holders should build
+    friend class ReadCheckpointDataHolder;
+};
+
+template<>
+class CheckpointData<CheckpointDataOperation::Write>
+{
+public:
+    //! \copydoc CheckpointData<CheckpointDataOperation::Read>::scalar
+    //! {
+    template<typename T>
+    std::enable_if_t<IsSerializableType<T>::value, void> scalar(const std::string& key, const T* value);
+    template<typename T>
+    std::enable_if_t<IsSerializableEnum<T>::value, void> enumScalar(const std::string& key, const T* value);
+    //! }
+
+    //! \copydoc CheckpointData<CheckpointDataOperation::Read>::arrayRef
+    //! {
+    // Write ArrayRef of scalar
+    template<typename T>
+    std::enable_if_t<IsSerializableType<T>::value, void> arrayRef(const std::string& key,
+                                                                  ArrayRef<const T>  values);
+    // Write ArrayRef of RVec
+    void arrayRef(const std::string& key, ArrayRef<const RVec> values);
+    //! }
+
+    //! \copydoc CheckpointData<CheckpointDataOperation::Read>::tensor
+    void tensor(const std::string& key, const ::tensor values);
+
+    //! \copydoc CheckpointData<CheckpointDataOperation::Read>::subCheckpointData
+    CheckpointData subCheckpointData(const std::string& key);
+
+private:
+    //! Builder for the tree to be written to checkpoint
+    std::optional<KeyValueTreeObjectBuilder> outputTreeBuilder_ = std::nullopt;
+
     //! Construct an output checkpoint data object
     explicit CheckpointData(KeyValueTreeObjectBuilder&& outputTreeBuilder);
 
     // Only holders should build
-    friend class ReadCheckpointDataHolder;
     friend class WriteCheckpointDataHolder;
 };
 
-// Shortcuts
-using ReadCheckpointData  = CheckpointData<CheckpointDataOperation::Read>;
-using WriteCheckpointData = CheckpointData<CheckpointDataOperation::Write>;
 
 /*! \libinternal
  * \brief Holder for read checkpoint data
@@ -318,19 +321,17 @@ private:
 // Function definitions - here to avoid template-related linker problems
 // doxygen doesn't like these...
 //! \cond
-template<>
-template<typename T, CheckpointDataOperation op>
-std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableType<T>::value, void>
-ReadCheckpointData::scalar(const std::string& key, T* value) const
+template<typename T>
+std::enable_if_t<IsSerializableType<T>::value, void> ReadCheckpointData::scalar(const std::string& key,
+                                                                                T* value) const
 {
     GMX_RELEASE_ASSERT(inputTree_, "No input checkpoint data available.");
     *value = (*inputTree_)[key].cast<T>();
 }
 
-template<>
-template<typename T, CheckpointDataOperation op>
-std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableEnum<T>::value, void>
-ReadCheckpointData::enumScalar(const std::string& key, T* value) const
+template<typename T>
+std::enable_if_t<IsSerializableEnum<T>::value, void> ReadCheckpointData::enumScalar(const std::string& key,
+                                                                                    T* value) const
 {
     GMX_RELEASE_ASSERT(inputTree_, "No input checkpoint data available.");
     std::underlying_type_t<T> castValue;
@@ -338,18 +339,16 @@ ReadCheckpointData::enumScalar(const std::string& key, T* value) const
     *value    = static_cast<T>(castValue);
 }
 
-template<>
-template<typename T, CheckpointDataOperation op>
-inline std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableType<T>::value, void>
+template<typename T>
+inline std::enable_if_t<IsSerializableType<T>::value, void>
 WriteCheckpointData::scalar(const std::string& key, const T* value)
 {
     GMX_RELEASE_ASSERT(outputTreeBuilder_, "No output checkpoint data available.");
     outputTreeBuilder_->addValue(key, *value);
 }
 
-template<>
-template<typename T, CheckpointDataOperation op>
-inline std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableEnum<T>::value, void>
+template<typename T>
+inline std::enable_if_t<IsSerializableEnum<T>::value, void>
 WriteCheckpointData::enumScalar(const std::string& key, const T* value)
 {
     GMX_RELEASE_ASSERT(outputTreeBuilder_, "No output checkpoint data available.");
@@ -357,9 +356,8 @@ WriteCheckpointData::enumScalar(const std::string& key, const T* value)
     outputTreeBuilder_->addValue(key, castValue);
 }
 
-template<>
-template<typename T, CheckpointDataOperation op>
-inline std::enable_if_t<op == CheckpointDataOperation::Read && IsSerializableType<T>::value, void>
+template<typename T>
+inline std::enable_if_t<IsSerializableType<T>::value, void>
 ReadCheckpointData::arrayRef(const std::string& key, ArrayRef<T> values) const
 {
     GMX_RELEASE_ASSERT(inputTree_, "No input checkpoint data available.");
@@ -375,9 +373,8 @@ ReadCheckpointData::arrayRef(const std::string& key, ArrayRef<T> values) const
     }
 }
 
-template<>
-template<typename T, CheckpointDataOperation op>
-inline std::enable_if_t<op == CheckpointDataOperation::Write && IsSerializableType<T>::value, void>
+template<typename T>
+inline std::enable_if_t<IsSerializableType<T>::value, void>
 WriteCheckpointData::arrayRef(const std::string& key, ArrayRef<const T> values)
 {
     GMX_RELEASE_ASSERT(outputTreeBuilder_, "No output checkpoint data available.");
@@ -388,8 +385,6 @@ WriteCheckpointData::arrayRef(const std::string& key, ArrayRef<const T> values)
     }
 }
 
-template<>
-template<>
 inline void ReadCheckpointData::arrayRef(const std::string& key, ArrayRef<RVec> values) const
 {
     GMX_RELEASE_ASSERT(values.size() >= (*inputTree_)[key].asArray().values().size(),
@@ -406,8 +401,6 @@ inline void ReadCheckpointData::arrayRef(const std::string& key, ArrayRef<RVec>
     }
 }
 
-template<>
-template<>
 inline void WriteCheckpointData::arrayRef(const std::string& key, ArrayRef<const RVec> values)
 {
     auto builder = outputTreeBuilder_->addObjectArray(key);
@@ -418,8 +411,6 @@ inline void WriteCheckpointData::arrayRef(const std::string& key, ArrayRef<const
     }
 }
 
-template<>
-template<>
 inline void ReadCheckpointData::tensor(const std::string& key, ::tensor values) const
 {
     auto array     = (*inputTree_)[key].asArray().values();
@@ -434,8 +425,6 @@ inline void ReadCheckpointData::tensor(const std::string& key, ::tensor values)
     values[ZZ][ZZ] = array[8].cast<real>();
 }
 
-template<>
-template<>
 inline void WriteCheckpointData::tensor(const std::string& key, const ::tensor values)
 {
     auto builder = outputTreeBuilder_->addUniformArray<real>(key);
@@ -450,34 +439,24 @@ inline void WriteCheckpointData::tensor(const std::string& key, const ::tensor v
     builder.addValue(values[ZZ][ZZ]);
 }
 
-template<>
-template<>
 inline ReadCheckpointData ReadCheckpointData::subCheckpointData(const std::string& key) const
 {
     return CheckpointData((*inputTree_)[key].asObject());
 }
 
-template<>
-template<>
 inline WriteCheckpointData WriteCheckpointData::subCheckpointData(const std::string& key)
 {
     return CheckpointData(outputTreeBuilder_->addObject(key));
 }
 
-template<CheckpointDataOperation operation>
-CheckpointData<operation>::CheckpointData(const KeyValueTreeObject& inputTree) :
+inline ReadCheckpointData::CheckpointData(const KeyValueTreeObject& inputTree) :
     inputTree_(&inputTree)
 {
-    static_assert(operation == CheckpointDataOperation::Read,
-                  "This constructor can only be called for a read CheckpointData");
 }
 
-template<CheckpointDataOperation operation>
-CheckpointData<operation>::CheckpointData(KeyValueTreeObjectBuilder&& outputTreeBuilder) :
+inline WriteCheckpointData::CheckpointData(KeyValueTreeObjectBuilder&& outputTreeBuilder) :
     outputTreeBuilder_(outputTreeBuilder)
 {
-    static_assert(operation == CheckpointDataOperation::Write,
-                  "This constructor can only be called for a write CheckpointData");
 }
 //! \endcond