// hpvm_dclib.hpp
// Forked from xiaoboh2/hpvm-deepcopy.
#include <argp.h>
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <functional>
#include <iostream>
#include <new>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include "./pinnedVector/pinnedVector/PinnedVector.h"

// Overview: an object graph such as
//   [obj {ptr1, ptr2}, *ptr1 {ptr3}, ...]
// is flattened into one buffer
//   [obj {ptr1, ptr2 stored as buffer-relative addresses}, *ptr1, *ptr2, ...].
// Library:
//   memory map (hpvm_allocation_record): tracks each allocation's type/size
//   helper functions: manage the buffer
// Codegen:
//   enumerates members and calls into the library
// User code:
//   tells the library/codegen which object to serialize

// Raw byte type used for the serialized buffers below.
using byte = char;

// In a separate file ----------------------
// Allocation table: maps each pointer returned by hpvm_new / hpvm_new_arr to
// the type_info of T and the total size in bytes of that allocation.
// NOTE(review): this is a `static` definition in a header, so every
// translation unit that includes the file gets its own private table;
// allocations recorded in one TU are invisible to lookups in another. The
// "separate file" marker suggests it was meant to live in a single .cpp —
// confirm before relying on cross-TU use.
static std::unordered_map<void*, std::pair<const std::type_info*, size_t>>
    hpvm_allocation_record;

// Lib header provides ---------------------
// Redeclaration of the table above. Because the `static` declaration is
// already visible, this `extern` declaration keeps its internal linkage.
extern std::unordered_map<void*, std::pair<const std::type_info*, size_t>>
    hpvm_allocation_record;

// True if `ptr` is exactly a pointer recorded by hpvm_new / hpvm_new_arr and
// not yet released by hpvm_delete / hpvm_delete_arr. Pointers into the middle
// of an allocation are not found.
// `inline` because this function is defined in a header. Without it, every
// translation unit that includes the header emits its own external
// definition, and linking two of them fails with a multiple-definition error.
inline bool is_allocated_memory(void* ptr) {
  return hpvm_allocation_record.count(ptr) != 0;
}
// Size in bytes recorded for the allocation that starts at `ptr`:
// sizeof(T) for hpvm_new, num_obj * sizeof(T) for hpvm_new_arr.
// Returns 0 if `ptr` is not a recorded allocation start.
// `inline` because this function is defined in a header (avoids
// multiple-definition link errors when several translation units include it).
inline size_t allocation_size(void* ptr) {
  // A single lookup instead of count() followed by operator[].
  auto it = hpvm_allocation_record.find(ptr);
  return it == hpvm_allocation_record.end() ? 0 : it->second.second;
}

// Allocate one T with `new`, constructing it from `args`. The allocation is
// recorded in hpvm_allocation_record (type T, sizeof(T) bytes) so that
// snapshot_pointer can find and copy it. Release with hpvm_delete.
template <class T, class... Args>
T* hpvm_new(Args&&... args) {
  // Forward the arguments unchanged: no needless copies, and move-only
  // arguments (e.g. std::unique_ptr) can be passed to T's constructor.
  T* obj = new T(std::forward<Args>(args)...);
  hpvm_allocation_record[obj] = {&typeid(T), sizeof(T)};
  return obj;
}

// Allocate an array of `num_obj` default-initialized T with new[], and record
// it in hpvm_allocation_record as type T spanning num_obj * sizeof(T) bytes.
// snapshot_pointer then copies every element. Release with hpvm_delete_arr.
template <class T>
T* hpvm_new_arr(size_t num_obj) {
  T* const arr = new T[num_obj];
  const size_t total_bytes = num_obj * sizeof(T);
  auto& entry = hpvm_allocation_record[arr];
  entry.first = &typeid(T);
  entry.second = total_bytes;
  return arr;
}

// Drop `obj` from hpvm_allocation_record, then destroy and free it with
// `delete`. Intended for single objects from hpvm_new; arrays from
// hpvm_new_arr must go through hpvm_delete_arr instead. Passing nullptr is
// harmless. The Args pack is unused; callers normally let T be deduced.
template <class T, class... Args>
void hpvm_delete(T* obj) {
  auto entry = hpvm_allocation_record.find(obj);
  if (entry != hpvm_allocation_record.end()) hpvm_allocation_record.erase(entry);
  delete obj;
}
// Array counterpart of hpvm_delete: drop `obj` from hpvm_allocation_record,
// then destroy and free the array with `delete[]`. Intended for pointers from
// hpvm_new_arr. Passing nullptr is harmless. The Args pack is unused.
template <class T, class... Args>
void hpvm_delete_arr(T* obj) {
  auto entry = hpvm_allocation_record.find(obj);
  if (entry != hpvm_allocation_record.end()) hpvm_allocation_record.erase(entry);
  delete[] obj;
}

// Hpvm Buf layout (one contiguous allocation):
//   [header | pointer_recover_list (size_t[pointer_recover_list_size]) |
//    obj_buf (obj_size bytes)]
// Each pointer_recover_list entry is the byte offset, from the start of
// obj_buf, of a pointer slot whose value must be rebased when the buffer moves.
template <typename T>
class HpvmBufHeader {
 public:
  size_t pointer_recover_list_size;  // number of entries in the offset list
  size_t obj_size;                   // size of obj_buf in bytes
  byte buf[];  // flexible array member (compiler extension): list, then objects

  // Start of the serialized object bytes, right after the offset list.
  void* get_obj_start() {
    return buf + pointer_recover_list_size * sizeof(size_t);
  }
  // Start of the pointer-offset list.
  size_t* get_pointer_list_start() { return (size_t*)buf; }
  // Address of the pointer slot located `offset` bytes into obj_buf.
  size_t* get_pointer_at(size_t offset) {
    size_t* sub_ptr = (size_t*)((char*)get_obj_start() + offset);
    return sub_ptr;
  }
  // Rewrite every registered pointer slot as (pointer - base).
  void to_relative_pointer(void* base) {
    shift_registered_pointers((size_t)base, /*subtract=*/true);
  }
  // Inverse of to_relative_pointer: rewrite every slot as (offset + base).
  void to_absolute_pointer(void* base) {
    shift_registered_pointers((size_t)base, /*subtract=*/false);
  }
  // Total bytes of header + offset list + object data.
  size_t total_size() {
    return sizeof(*this) + pointer_recover_list_size * sizeof(size_t) +
           obj_size;
  }

 private:
  // Subtract `base` from (or add it to) every slot named in the offset list.
  // The slots hold pointer objects, so they are read and written with memcpy
  // rather than through a size_t* lvalue (avoids strict-aliasing UB). The
  // arithmetic is unsigned and wraps modulo 2^N, so either direction is
  // well-defined without negating a pointer value.
  void shift_registered_pointers(size_t base, bool subtract) {
    for (size_t i = 0; i < pointer_recover_list_size; i++) {
      char* slot = (char*)get_pointer_at(get_pointer_list_start()[i]);
      size_t value;
      memcpy(&value, slot, sizeof(value));
#ifndef USE_HPVM
      std::cout << (subtract ? "To relative: " : "To absolute: ")
                << (void*)value;
#endif
      value = subtract ? value - base : value + base;
      memcpy(slot, &value, sizeof(value));
#ifndef USE_HPVM
      std::cout << " -> " << (void*)value << "\n";
#endif
    }
  }
};
// Host-side builder for a snapshot of an object graph rooted at a T.
//
// allocate() appends raw storage to obj_data. register_pointer_in_buf()
// records, as an offset from obj_data.data(), a slot in obj_data that holds a
// pointer to be relocated. formulate_device_buffer() packs both into one
// calloc'ed HpvmBufHeader<T> (`buf`), with those pointers rewritten relative
// to obj_data.data(). recover_host_accessible_obj() turns them back into
// absolute pointers into `buf` itself.
//
// NOTE(review): pointers returned by allocate() are only safe to keep across
// later allocate() calls if PinnedVector never relocates its storage on
// resize. That is presumably why it is used here; confirm.
// NOTE(review): `buf` is never freed when an HpvmBuf is destroyed. The
// implicit copy constructor copies the raw pointer, so copies share one
// buffer, and formulate_device_buffer() on one copy frees it under the other.
template <typename T>
class HpvmBuf {
 public:
  HpvmBufHeader<T>* buf = nullptr;       // packed buffer; nullptr until formulated
  std::vector<size_t> rel_pointer_list;  // offsets of pointer slots in obj_data
  PinnedVector<byte> obj_data;           // serialized object bytes

  HpvmBuf() {
    rel_pointer_list.reserve(1);
    obj_data.reserve(1);
  }

  // Pack rel_pointer_list and obj_data into a new zero-filled buffer laid out
  // as [header | offsets | object bytes]. The registered pointers are
  // converted to offsets from obj_data.data(); the buffer is stored in `buf`
  // and returned. Any buffer from a previous call is freed first.
  // Throws std::bad_alloc if the buffer cannot be allocated.
  HpvmBufHeader<T>* formulate_device_buffer() {
    if (buf) free(buf);
    buf = nullptr;  // never leave `buf` dangling, even if calloc fails below
    size_t buf_size = sizeof(*buf) + rel_pointer_list.size() * sizeof(size_t) +
                      obj_data.size();
    auto* packed = (HpvmBufHeader<T>*)calloc(buf_size, 1);
    if (!packed) throw std::bad_alloc();
    buf = packed;
    buf->pointer_recover_list_size = rel_pointer_list.size();
    buf->obj_size = obj_data.size();

    std::cout << "Formulating buffer \n";
    memcpy(buf->get_pointer_list_start(), rel_pointer_list.data(),
           buf->pointer_recover_list_size * sizeof(size_t));
    memcpy(buf->get_obj_start(), obj_data.data(), buf->obj_size);
    std::cout << "We will have pointers at offsets: ";
    for (auto i : rel_pointer_list) std::cout << (void*)i << ",";
    std::cout << "\n";

    // The copied pointer slots still hold addresses inside obj_data; store
    // them as offsets from obj_data.data() instead.
    buf->to_relative_pointer(obj_data.data());
    return buf;
  }

  // Rebase the packed buffer's relative pointers onto its own object area and
  // return the root object. Requires a prior formulate_device_buffer().
  // Not idempotent: a second call adds the base again and corrupts the
  // pointers.
  T* recover_host_accessible_obj() {
    assert(buf && "formulate_device_buffer() must be called first");
    buf->to_absolute_pointer(buf->get_obj_start());
    return (T*)buf->get_obj_start();
  }

  // Append `size` bytes to obj_data and return their address cast to
  // AllocType*. Returns nullptr when size == 0.
  template <class AllocType = void>
  AllocType* allocate(size_t size) {
    if (size == 0) return nullptr;
    size_t old_size = obj_data.size();
    obj_data.resize(obj_data.size() + size);
    std::cout << "Allocate up to:" << obj_data.size() << " bytes" << std::endl;
    return (AllocType*)(obj_data.data() + old_size);
  }

  // Record the pointer slot at `ptr_addr_in_buf` for relocation, stored as
  // its offset from obj_data.data(). The slot must lie inside obj_data.
  void register_pointer_in_buf(void* ptr_addr_in_buf) {
    size_t offset = (size_t)(ptr_addr_in_buf) - (size_t)obj_data.data();
    // An address below obj_data.data() wraps to a huge offset and fails too.
    assert(offset < obj_data.size() &&
           obj_data.size() - offset >= sizeof(size_t) &&
           "pointer slot must lie inside obj_data");
    rel_pointer_list.push_back(offset);
  }
};

// Main user entry point. Copies the object graph rooted at `original_obj`
// into a fresh HpvmBuf via hpvm_snapshot_internal, then packs it with
// formulate_device_buffer(), leaving the packed form in the returned HpvmBuf's
// `buf` member.
template <class T>
HpvmBuf<T> hpvm_snapshot(T* original_obj) {
  std::cout << "Snapshot: " << static_cast<const void*>(original_obj)
            << std::endl;
  HpvmBuf<T> result;
  hpvm_snapshot_internal(result, original_obj);
  std::cout << "Snapshot Done\n";
  result.formulate_device_buffer();
  return result;
}

// Reserve sizeof(T) bytes at the end of `buf` and fill them with a bitwise
// copy of *src; returns the copy's address inside the buffer. No constructor
// runs, and pointer members of the copy still hold the original addresses.
template <class T, class Buf>
T* snapshot_trivial(Buf& buf, T* src) {
  T* const copy = buf.template allocate<T>(sizeof(T));
  memcpy(static_cast<void*>(copy), src, sizeof(T));
  return copy;
}

// Customization point for deep copy. This primary template does nothing, so
// an object is snapshotted as a plain bitwise copy unless the user provides a
// specialization/overload for their type (e.g. one that calls
// snapshot_pointer on each pointer member of *dst_obj).
template <class Buf, class T>
extern void hpvm_snapshot_custom(Buf& buf, T* dst_obj, T* original_obj) {
  // Default: nothing to deep-copy.
  (void)buf;
  (void)dst_obj;
  (void)original_obj;
}

// Snapshot function used internally: bitwise-copy *original_obj to the end of
// `buf`, then let hpvm_snapshot_custom deep-copy whatever it points to.
// Returns the copy that lives inside the buffer.
template <class Buf, class T>
extern T* hpvm_snapshot_internal(Buf& buf, T* original_obj) {
  T* copy_in_buf = snapshot_trivial(buf, original_obj);
  hpvm_snapshot_custom(buf, copy_in_buf, original_obj);
  return copy_in_buf;
}

// Snapshot an (array of) pointer(s).
// If `src_ptr` is the start of a recorded allocation, all
// allocation_size(src_ptr) / sizeof(T) elements are bitwise-copied into one
// contiguous run at the end of `buf`, and each is then deep-copied via
// hpvm_snapshot_custom. `dst_ptr` is pointed at the run and its slot is
// registered for relocation. Otherwise (unrecorded pointer, nullptr, or an
// allocation smaller than one T) `dst_ptr` is set to nullptr and nothing is
// registered.
// With HpvmBuf, `dst_ptr` must itself live in the buffer's obj_data, because
// register_pointer_in_buf stores its offset from obj_data.data().
template <class T, class Buf>
void snapshot_pointer(Buf& buf, T*& dst_ptr, T* src_ptr) {
  // size_t rather than int: narrowing the element count could overflow or go
  // negative for large allocations.
  const size_t num_allocation = allocation_size(src_ptr) / sizeof(T);
  if (num_allocation == 0) {
    dst_ptr = nullptr;
    return;
  }

  std::cout << "Snapshot pointer to buffer of " << num_allocation
            << " elems\n";
  // Ensure array allocation is contiguous: copy every element first.
  T* dst = nullptr;
  for (size_t i = 0; i < num_allocation; i++) {
    T* alloc = snapshot_trivial(buf, src_ptr + i);
    if (!dst) dst = alloc;
  }

  // Deep-copy only after the contiguous run exists, so nested allocations
  // land after it instead of interleaving with the array elements.
  for (size_t i = 0; i < num_allocation; i++)
    hpvm_snapshot_custom(buf, dst + i, src_ptr + i);

  dst_ptr = dst;
  buf.register_pointer_in_buf(&dst_ptr);
}
// Library functions end ----------------------------------------------------