#include #include "caffe2/core/db.h" #include "caffe2/core/logging.h" #include "caffe2/utils/proto_utils.h" namespace caffe2 { namespace db { class ProtoDBCursor : public Cursor { public: explicit ProtoDBCursor(const TensorProtos* proto) : proto_(proto), iter_(0) {} // NOLINTNEXTLINE(modernize-use-equals-default) ~ProtoDBCursor() override {} void Seek(const string& /*str*/) override { CAFFE_THROW("ProtoDB is not designed to support seeking."); } void SeekToFirst() override { iter_ = 0; } void Next() override { ++iter_; } string key() override { return proto_->protos(iter_).name(); } string value() override { return SerializeAsString_EnforceCheck( proto_->protos(iter_), "ProtoDBCursor"); } bool Valid() override { return iter_ < proto_->protos_size(); } private: const TensorProtos* proto_; int iter_; }; class ProtoDBTransaction : public Transaction { public: explicit ProtoDBTransaction(TensorProtos* proto) : proto_(proto), existing_names_() { for (const auto& tensor : proto_->protos()) { existing_names_.insert(tensor.name()); } } ~ProtoDBTransaction() override { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) Commit(); } void Put(const string& key, string&& value) override { if (existing_names_.count(key)) { CAFFE_THROW("An item with key ", key, " already exists."); } auto* tensor = proto_->add_protos(); CAFFE_ENFORCE( tensor->ParseFromString(value), "Cannot parse content from the value string."); CAFFE_ENFORCE( tensor->name() == key, "Passed in key ", key, " does not equal to the tensor name ", tensor->name()); } // Commit does nothing. The protocol buffer will be written at destruction // of ProtoDB. void Commit() override {} private: TensorProtos* proto_; std::unordered_set existing_names_; C10_DISABLE_COPY_AND_ASSIGN(ProtoDBTransaction); }; class ProtoDB : public DB { public: ProtoDB(const string& source, Mode mode) : DB(source, mode), proto_(), source_(source) { if (mode == READ || mode == WRITE) { // Read the current protobuffer. CAFFE_ENFORCE( ReadProtoFromFile(source, &proto_), "Cannot read protobuffer."); } LOG(INFO) << "Opened protodb " << source; } ~ProtoDB() override { // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) Close(); } void Close() override { if (mode_ == NEW || mode_ == WRITE) { WriteProtoToBinaryFile(proto_, source_); } } unique_ptr NewCursor() override { return make_unique(&proto_); } unique_ptr NewTransaction() override { return make_unique(&proto_); } private: TensorProtos proto_; string source_; }; REGISTER_CAFFE2_DB(ProtoDB, ProtoDB); // For lazy-minded, one can also call with lower-case name. REGISTER_CAFFE2_DB(protodb, ProtoDB); } // namespace db } // namespace caffe2