Skip to content

Support ibv_reg_dmabuf_mr for buffer allocated by cuMemMalloc #513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions src/ib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,53 @@ IbMr::IbMr(ibv_pd* pd, void* buff, std::size_t size) : buff(buff) {
}
uintptr_t addr = reinterpret_cast<uintptr_t>(buff) & -pageSize;
std::size_t pages = (size + (reinterpret_cast<uintptr_t>(buff) - addr) + pageSize - 1) / pageSize;
this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast<void*>(addr), pages * pageSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);

CUdeviceptr dptr = reinterpret_cast<CUdeviceptr>(buff);
bool cuMemAlloc = mscclpp::isCuMemMapAllocated((void*)dptr);
CUdevice dev;
int dmaBufSupported = 0;
#if !defined(__HIP_PLATFORM_AMD__)
MSCCLPP_CUTHROW(cuCtxGetDevice(&dev));
MSCCLPP_CUTHROW(cuDeviceGetAttribute(&dmaBufSupported, CU_DEVICE_ATTRIBUTE_DMA_BUF_SUPPORTED, dev));
#endif // !defined(__HIP_PLATFORM_AMD__)
if (cuMemAlloc && dmaBufSupported) {
#if !defined(__HIP_PLATFORM_AMD__)
CUdeviceptr base;
size_t actualSize;
MSCCLPP_CUTHROW(cuMemGetAddressRange(&base, &actualSize, dptr));

size_t offset = static_cast<size_t>(dptr - base);

// Align offset down to the nearest page
size_t alignedOffset = (offset / pageSize) * pageSize;

// Align userBufferSize up to the nearest page
size_t alignedUserBufferSize = ((size + pageSize - 1) / pageSize) * pageSize;

// Ensure aligned range fits within the original allocation
if (alignedOffset + alignedUserBufferSize > actualSize) {
std::stringstream err;
err << "aligned range excceeds original allocation (errno " << errno << ")";
throw mscclpp::IbError(err.str(), errno);
}

int fd;
MSCCLPP_CUTHROW(cuMemGetHandleForAddressRange(&fd, base + alignedOffset, alignedUserBufferSize,
CU_MEM_RANGE_HANDLE_TYPE_DMA_BUF_FD, 0));

// Compute offset of dptr within the DMA-BUF
size_t offsetInDmaBuf = offset - alignedOffset;

this->mr = IBVerbs::ibv_reg_dmabuf_mr(pd, offsetInDmaBuf, size, (uint64_t)dptr, fd,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
#endif // !defined(__HIP_PLATFORM_AMD__)
} else {
this->mr = IBVerbs::ibv_reg_mr2(pd, reinterpret_cast<void*>(addr), pages * pageSize,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ |
IBV_ACCESS_RELAXED_ORDERING | IBV_ACCESS_REMOTE_ATOMIC);
}

if (this->mr == nullptr) {
std::stringstream err;
err << "ibv_reg_mr failed (errno " << errno << ")";
Expand Down
19 changes: 17 additions & 2 deletions src/include/ibverbs_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ struct IBVerbs {
ibv_create_qp_lib = (ibv_create_qp_t)dlsym(handle, "ibv_create_qp");
ibv_destroy_cq_lib = (ibv_destroy_cq_t)dlsym(handle, "ibv_destroy_cq");
ibv_reg_mr_lib = (ibv_reg_mr_t)dlsym(handle, "ibv_reg_mr");
ibv_reg_dmabuf_mr_lib = (ibv_reg_dmabuf_mr_t)dlsym(handle, "ibv_reg_dmabuf_mr");
ibv_dereg_mr_lib = (ibv_dereg_mr_t)dlsym(handle, "ibv_dereg_mr");
ibv_query_gid_lib = (ibv_query_gid_t)dlsym(handle, "ibv_query_gid");
ibv_modify_qp_lib = (ibv_modify_qp_t)dlsym(handle, "ibv_modify_qp");
Expand All @@ -44,7 +45,8 @@ struct IBVerbs {
if (!ibv_get_device_list_lib || !ibv_free_device_list_lib || !ibv_alloc_pd_lib || !ibv_dealloc_pd_lib ||
!ibv_open_device_lib || !ibv_close_device_lib || !ibv_query_device_lib || !ibv_create_cq_lib ||
!ibv_create_qp_lib || !ibv_destroy_cq_lib || !ibv_reg_mr_lib || !ibv_dereg_mr_lib || !ibv_query_gid_lib ||
!ibv_reg_mr_iova2_lib || !ibv_modify_qp_lib || !ibv_destroy_qp_lib || !ibv_query_port_lib) {
!ibv_reg_mr_iova2_lib || !ibv_modify_qp_lib || !ibv_destroy_qp_lib || !ibv_query_port_lib ||
!ibv_reg_dmabuf_mr_lib) {
throw mscclpp::IbError("Failed to load one or more function in the ibibverbs library: " + std::string(dlerror()),
errno);
dlclose(handle);
Expand Down Expand Up @@ -151,6 +153,16 @@ struct IBVerbs {
return nullptr;
}

// Static method to register a dma-buf based memory region
static struct ibv_mr* ibv_reg_dmabuf_mr(struct ibv_pd* pd, uint64_t offset, size_t length, uint64_t iova, int fd,
int access) {
if (!initialized) initialize();
if (ibv_reg_dmabuf_mr_lib) {
return ibv_reg_dmabuf_mr_lib(pd, offset, length, iova, fd, access);
}
return nullptr;
}

// Static method to deregister a memory region
static int ibv_dereg_mr(struct ibv_mr* mr) {
if (!initialized) initialize();
Expand Down Expand Up @@ -239,6 +251,8 @@ struct IBVerbs {
typedef int (*ibv_destroy_cq_t)(struct ibv_cq*);
typedef int (*ibv_destroy_qp_t)(struct ibv_qp*);
typedef struct ibv_mr* (*ibv_reg_mr_t)(struct ibv_pd*, void*, size_t, int);
typedef struct ibv_mr* (*ibv_reg_dmabuf_mr_t)(struct ibv_pd*, uint64_t offset, size_t length, uint64_t iova, int fd,
int access);
typedef int (*ibv_dereg_mr_t)(struct ibv_mr*);
typedef int (*ibv_query_gid_t)(struct ibv_context*, uint8_t, int, union ibv_gid*);
typedef int (*ibv_modify_qp_t)(struct ibv_qp*, struct ibv_qp_attr*, int);
Expand All @@ -257,6 +271,7 @@ struct IBVerbs {
static inline ibv_create_qp_t ibv_create_qp_lib = nullptr;
static inline ibv_destroy_cq_t ibv_destroy_cq_lib = nullptr;
static inline ibv_reg_mr_t ibv_reg_mr_lib = nullptr;
static inline ibv_reg_dmabuf_mr_t ibv_reg_dmabuf_mr_lib = nullptr;
static inline ibv_dereg_mr_t ibv_dereg_mr_lib = nullptr;
static inline ibv_query_gid_t ibv_query_gid_lib = nullptr;
static inline ibv_modify_qp_t ibv_modify_qp_lib = nullptr;
Expand All @@ -269,4 +284,4 @@ struct IBVerbs {

} // namespace mscclpp

#endif // MSCCLPP_IBVERBS_WRAPPER_HPP_
#endif // MSCCLPP_IBVERBS_WRAPPER_HPP_
Loading