/*************************************************************************
 * Copyright (c) 2015-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * See LICENSE.txt for license information
 ************************************************************************/

#include "comm.h"
#include "transport.h"
#include "group.h"

/*
 * ncclMemAlloc: allocate `size` bytes of device memory for the calling thread's
 * current device and store the resulting address in `*ptr`.
 *
 * When built with ROCM_VERSION >= 70000 and ncclCuMemEnable() reports true, the
 * buffer is built with the HIP virtual memory management calls (hipMemCreate,
 * hipMemAddressReserve, hipMemMap, hipMemSetAccess). The size is rounded up to
 * the recommended allocation granularity, and read/write access is granted to
 * the current device plus every device for which hipDeviceCanAccessPeer
 * reports peer access. If a step after the physical allocation fails, whatever
 * was mapped/reserved is torn down again, the allocation handle is released and
 * (once an address range had been reserved) `*ptr` is reset to NULL.
 *
 * In every other case (older ROCm, cuMem disabled, NULL `ptr`, zero `size`)
 * the request is forwarded unchanged to hipMalloc.
 *
 * Returns ncclSuccess on success, otherwise the error produced by the
 * CUDACHECK / CUCHECK family of macros (defined outside this file).
 */
NCCL_API(ncclResult_t, ncclMemAlloc, void **ptr, size_t size);
ncclResult_t  ncclMemAlloc_impl(void **ptr, size_t size) {
  NVTX3_FUNC_RANGE_IN(nccl_domain);
  ncclResult_t ret = ncclSuccess;

#if ROCM_VERSION >= 70000
  size_t memGran = 0;
  hipDevice_t currentDev;
  hipMemAllocationProp memprop = {};
  hipMemAccessDesc accessDesc = {};
  hipMemGenericAllocationHandle_t handle = (hipMemGenericAllocationHandle_t)-1;
  int cudaDev;
  int flag;
  int dcnt;
  // Progress markers consulted by the cleanup code when a late step fails.
  bool vaReserved = false;
  bool vaMapped = false;

  // Invalid arguments are handed to hipMalloc so that the runtime reports the error.
  if (ptr == NULL || size == 0) goto fallback;

  // if (rocmLibraryInit() != ncclSuccess) goto fallback;
  // The result of rocmLibraryInit() is currently not checked (see the disabled line above).
  rocmLibraryInit();

  CUDACHECK(hipGetDevice(&cudaDev));
  CUCHECK(hipDeviceGet(&currentDev, cudaDev));

  if (ncclCuMemEnable()) {
    size_t handleSize = size;
    int requestedHandleTypes = hipMemHandleTypePosixFileDescriptor;
    // Query device to see if FABRIC handle support is available.
    // The query is best effort: if it fails, flag stays 0 and FABRIC is not requested.
    flag = 0;
    (void) CUPFN(hipDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, currentDev));
    if (flag) requestedHandleTypes |= CU_MEM_HANDLE_TYPE_FABRIC;
    memprop.type = hipMemAllocationTypePinned;
    memprop.location.type = hipMemLocationTypeDevice;
    memprop.requestedHandleTypes = (hipMemAllocationHandleType) requestedHandleTypes;
    memprop.location.id = currentDev;
    // Query device to see if RDMA support is available.
    // The query itself is disabled, so flag stays 0 and gpuDirectRDMACapable is never set here.
    flag = 0;
    // CUCHECK(hipDeviceGetAttribute(&flag, CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, currentDev));
    if (flag) memprop.allocFlags.gpuDirectRDMACapable = 1;
    CUCHECK(hipMemGetAllocationGranularity(&memGran, &memprop, hipMemAllocationGranularityRecommended));
    CUDACHECK(hipGetDeviceCount(&dcnt));
    // Physical allocations must be a multiple of the granularity.
    ALIGN_SIZE(handleSize, memGran);

    if (requestedHandleTypes & CU_MEM_HANDLE_TYPE_FABRIC) {
      /* First try hipMemCreate() with FABRIC handle support and then remove if it fails */
      hipError_t err = CUPFN(hipMemCreate(&handle, handleSize, &memprop, 0));
      if (err == hipErrorNotSupported) {
        requestedHandleTypes &= ~CU_MEM_HANDLE_TYPE_FABRIC;
        memprop.requestedHandleTypes = (hipMemAllocationHandleType) requestedHandleTypes;
        /* Allocate the physical memory on the device */
        CUCHECK(hipMemCreate(&handle, handleSize, &memprop, 0));
      } else if (err != hipSuccess) {
        // Catch and report any error from above (the call is repeated under CUCHECK)
        CUCHECK(hipMemCreate(&handle, handleSize, &memprop, 0));
      }
    } else {
      /* Allocate the physical memory on the device */
      CUCHECK(hipMemCreate(&handle, handleSize, &memprop, 0));
    }

    // From here on `handle` owns a physical allocation: every failure has to go
    // through cumem_fail instead of returning directly, otherwise it would leak.

    /* Reserve a virtual address range */
    CUCHECKGOTO(hipMemAddressReserve((hipDeviceptr_t*)ptr, handleSize, memGran, 0, 0), ret, cumem_fail);
    vaReserved = true;
    /* Map the virtual address range to the physical allocation */
    CUCHECKGOTO(hipMemMap((hipDeviceptr_t)*ptr, handleSize, 0, handle, 0), ret, cumem_fail);
    vaMapped = true;
    /* Now allow RW access to the newly mapped memory */
    for (int i = 0; i < dcnt; ++i) {
      int p2p = 0;
      if (i == cudaDev || ((hipDeviceCanAccessPeer(&p2p, i, cudaDev) == hipSuccess) && p2p)) {
        accessDesc.location.type = hipMemLocationTypeDevice;
        accessDesc.location.id = i;
        accessDesc.flags = hipMemAccessFlagsProtReadWrite;
        CUCHECKGOTO(hipMemSetAccess((hipDeviceptr_t)*ptr, handleSize, &accessDesc, 1), ret, cumem_fail);
      }
      if (0 == p2p && i != cudaDev) INFO(NCCL_ALLOC, "P2P not supported between GPU%d and GPU%d", cudaDev, i);
    }
    goto exit;

  cumem_fail:
    // Undo the completed part of the reserve/map sequence. Errors from the
    // cleanup calls are ignored on purpose: `ret` already carries the failure.
    if (vaMapped) (void) hipMemUnmap((hipDeviceptr_t)*ptr, handleSize);
    if (vaReserved) {
      (void) hipMemAddressFree((hipDeviceptr_t)*ptr, handleSize);
      // Do not hand a dangling address back to the caller.
      *ptr = NULL;
    }
    (void) hipMemRelease(handle);
    goto exit;
  }

fallback:
#endif
  // Coverity is right to complain that we may pass a NULL ptr to hipMalloc.  That's deliberate though:
  // we want CUDA to return an error to the caller.
  // coverity[var_deref_model]
  CUDACHECKGOTO(hipMalloc(ptr, size), ret, fail);

exit:
  return ret;
fail:
  goto exit;
}

/*
 * ncclMemFree: release device memory referenced by `ptr`.
 *
 * The calling thread's current device is recorded on entry and restored
 * before returning. With ROCM_VERSION >= 70000 and a non-NULL `ptr`, the
 * device ordinal of the pointer is queried and made current first; when
 * ncclCuMemEnable() reports true the buffer is then released through
 * ncclCuMemFree(). In all remaining cases the pointer is passed to hipFree().
 *
 * Returns ncclSuccess, or the error set by the *CHECKGOTO macros / returned by
 * CUDACHECK (macros defined outside this file).
 */
NCCL_API(ncclResult_t, ncclMemFree, void *ptr);
ncclResult_t  ncclMemFree_impl(void *ptr) {
  NVTX3_FUNC_RANGE_IN(nccl_domain);
  ncclResult_t ret = ncclSuccess;
  int callerDev;

  // Remember which device the caller had selected; it is put back at `exit`.
  CUDACHECK(hipGetDevice(&callerDev));
#if ROCM_VERSION >= 70000
  if (ptr != NULL) {
    hipDevice_t ownerDev = 0;

    // if (rocmLibraryInit() != ncclSuccess) goto fallback;
    rocmLibraryInit();

    // Switch to the device ordinal reported for the pointer before freeing it.
    CUCHECKGOTO(hipPointerGetAttribute((void*)&ownerDev, HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, (hipDeviceptr_t)ptr), ret, fail);
    CUDACHECKGOTO(hipSetDevice((int)ownerDev), ret, fail);
    if (ncclCuMemEnable()) {
      NCCLCHECKGOTO(ncclCuMemFree(ptr), ret, fail);
      goto exit;
    }
  }
#endif
  // Plain path: NULL pointers, cuMem disabled, or ROCm older than 7.0.
  CUDACHECKGOTO(hipFree(ptr), ret, fail);

exit:
  CUDACHECK(hipSetDevice(callerDev));
  return ret;
fail:
  goto exit;
}

/*
 * Allocate `size` bytes of symmetric memory for `comm` and return the mapped
 * address through `symPtr`.
 *
 * This is a collective function and should be called by all ranks in the communicator.
 *
 * comm      - communicator; its symAllocHead is aligned and then advanced by
 *             the granularity-rounded allocation size on success.
 * size      - requested size in bytes (rounded up to the recommended
 *             allocation granularity of comm->cudaDev).
 * alignment - must have at most one bit set; it is additionally rounded up to
 *             NCCL_REC_PAGE_SIZE before use.
 * symPtr    - receives the address produced by ncclIpcSymmetricMap(), or NULL
 *             on failure.
 *
 * Returns ncclSuccess on success, ncclInvalidArgument when `alignment` is not a
 * power of 2, or the error set by the failing CUCHECKGOTO / NCCLCHECKGOTO call.
 */
ncclResult_t ncclCommSymmetricAllocInternal(struct ncclComm* comm, size_t size, size_t alignment, void** symPtr) {
  ncclResult_t ret = ncclSuccess;
  void* regSymAddr = NULL;
  size_t allocSize = size;
  size_t granularity;
  hipDevice_t cuDev;
  hipMemAllocationProp memprop = {};
  hipMemGenericAllocationHandle_t memHandle;

  // alignment must be power of 2 as an input: more than one set bit is rejected.
  // (alignment & (alignment - 1)) clears the lowest set bit, so it is non-zero
  // exactly when two or more bits are set. A value of 0 passes the test, as before.
  // NOTE(review): an alignment of 0 stays 0 after the rounding below if ALIGN_SIZE
  // rounds up by division - confirm callers never pass 0.
  if ((alignment & (alignment - 1)) != 0) {
    WARN("rank %d alignment %ld is not power of 2", comm->rank, alignment);
    // Report the bad argument instead of returning ncclSuccess with a NULL pointer.
    ret = ncclInvalidArgument;
    goto fail;
  }
  // temporarily align the alignment to NCCL_REC_PAGE_SIZE
  ALIGN_SIZE(alignment, NCCL_REC_PAGE_SIZE);

  CUCHECKGOTO(hipDeviceGet(&cuDev, comm->cudaDev), ret, fail);
  memprop.type = hipMemAllocationTypePinned;
  memprop.location.type = hipMemLocationTypeDevice;
  memprop.requestedHandleType = ncclCuMemHandleType;
  memprop.location.id = cuDev;
  CUCHECKGOTO(hipMemGetAllocationGranularity(&granularity, &memprop, hipMemAllocationGranularityRecommended), ret, fail);
  ALIGN_SIZE(allocSize, granularity);

  // NOTE(review): memHandle is not released if one of the mapping steps below
  // fails - confirm whether the map helpers take ownership of the handle.
  CUCHECKGOTO(hipMemCreate(&memHandle, allocSize, &memprop, 0), ret, fail);
  // Place the new buffer at the next suitably aligned offset of the symmetric space.
  ALIGN_SIZE(comm->symAllocHead, alignment);
  NCCLCHECKGOTO(ncclIpcSymmetricMap(comm, comm->symAllocHead, allocSize, memHandle, &regSymAddr), ret, fail);
  NCCLCHECKGOTO(ncclNvlsSymmetricMap(comm, comm->symAllocHead, allocSize, regSymAddr), ret, fail);
  NCCLCHECKGOTO(bootstrapIntraNodeBarrier(comm->bootstrap, comm->localRankToRank, comm->localRank, comm->localRanks, comm->localRankToRank[0]), ret, fail);
  comm->symAllocHead += allocSize;
  *symPtr = regSymAddr;

exit:
  return ret;
fail:
  *symPtr = NULL;
  goto exit;
}

/*
 * Release the symmetric allocation that `symPtr` refers to for `comm`.
 *
 * Nothing is released unless ncclCuMemEnable() reports true. The calling
 * thread's current device is queried on entry and restored before returning;
 * comm->cudaDev is made current while the allocation is torn down.
 *
 * Returns ncclSuccess, or the error set by the failing *CHECKGOTO call /
 * returned by CUDACHECK while restoring the device (macros defined elsewhere).
 */
ncclResult_t ncclCommSymmetricFreeInternal(struct ncclComm* comm, void* symPtr) {
  ncclResult_t ret = ncclSuccess;
  hipMemGenericAllocationHandle_t memHandle;
  size_t mappedSize = 0;
  // Starts out as comm->cudaDev so `exit` has a defined device to restore even
  // when the hipGetDevice() query below fails.
  int prevDev = comm->cudaDev;

  CUDACHECKGOTO(hipGetDevice(&prevDev), ret, fail);
  if (!ncclCuMemEnable()) goto exit;

  CUDACHECKGOTO(hipSetDevice(comm->cudaDev), ret, fail);

  // Look up the allocation handle behind symPtr and release it once right away.
  // Presumably this drops the reference added by the retain call itself - confirm
  // against the HIP documentation.
  CUCHECKGOTO(hipMemRetainAllocationHandle(&memHandle, symPtr), ret, fail);
  CUCHECKGOTO(hipMemRelease(memHandle), ret, fail);

  // The free helpers need the size of the range that contains symPtr.
  CUCHECKGOTO(hipMemGetAddressRange(NULL, &mappedSize, (hipDeviceptr_t)symPtr), ret, fail);
  NCCLCHECKGOTO(ncclNvlsSymmetricFree(comm, mappedSize, symPtr), ret, fail);
  NCCLCHECKGOTO(ncclIpcSymmetricFree(comm, mappedSize, symPtr), ret, fail);

  // Second release of the same handle, issued after the symmetric mappings are freed.
  CUCHECKGOTO(hipMemRelease(memHandle), ret, fail);

exit:
  CUDACHECK(hipSetDevice(prevDev));
  return ret;
fail:
  goto exit;
}
