diff --git a/cext/cuda_loader.h b/cext/cuda_loader.h index 2f54a78..dddfe49 100644 --- a/cext/cuda_loader.h +++ b/cext/cuda_loader.h @@ -33,6 +33,7 @@ X(cuEventDestroy, 2000) \ X(cuEventQuery, 2000) \ X(cuEventRecord, 2000) \ + X(cuEventSynchronize, 2000) \ X(cuKernelGetFunction, 12000) \ X(cuMemAlloc, 3020) \ X(cuMemAllocHost, 3020) \ diff --git a/cext/tile_kernel.cpp b/cext/tile_kernel.cpp index c9578fe..5f48fd0 100644 --- a/cext/tile_kernel.cpp +++ b/cext/tile_kernel.cpp @@ -367,6 +367,7 @@ struct LaunchHelper { Vec list_args; size_t total_list_data_size_words; Vec constants; + Vec deferred_dlpack; CUcontext cuda_context; LaunchHelper* next_free; }; @@ -404,6 +405,136 @@ static LaunchHelperPtr launch_helper_get() { } } +namespace { struct DeferredDlpackRelease { + const DriverApi* driver; + CUevent event; + Vec tensors; + DeferredDlpackRelease* next; +}; } + +static PyThread_type_lock g_deferred_dlpack_lock; +static DeferredDlpackRelease* g_deferred_dlpack_head; +static DeferredDlpackRelease* g_deferred_dlpack_tail; +static bool g_deferred_dlpack_worker_running; + +static void release_dlpack_tensors(Vec& tensors) { + for (DLManagedTensor* tensor : tensors) { + if (tensor->deleter) + tensor->deleter(tensor); + } + tensors.clear(); +} + +static void deferred_dlpack_worker(void*) { + for (;;) { + PyThread_acquire_lock(g_deferred_dlpack_lock, WAIT_LOCK); + DeferredDlpackRelease* release = g_deferred_dlpack_head; + if (release) { + g_deferred_dlpack_head = release->next; + if (!g_deferred_dlpack_head) + g_deferred_dlpack_tail = nullptr; + } else { + g_deferred_dlpack_worker_running = false; + PyThread_release_lock(g_deferred_dlpack_lock); + return; + } + PyThread_release_lock(g_deferred_dlpack_lock); + + CUresult res = release->driver->cuEventSynchronize(release->event); + CHECK(res == CUDA_SUCCESS); + + if (Py_IsInitialized()) { + PyGILState_STATE gil = PyGILState_Ensure(); + release_dlpack_tensors(release->tensors); + PyGILState_Release(gil); + } + + res = release->driver->cuEventDestroy(release->event); + CHECK(res == CUDA_SUCCESS); + delete release; + } +} + +static Status defer_dlpack_release(const DriverApi* driver, CUstream stream, + Vec& tensors) { + if (tensors.empty()) return OK; + + DeferredDlpackRelease* release = new DeferredDlpackRelease{ + .driver = driver, + .event = nullptr, + .tensors = std::move(tensors), + .next = nullptr, + }; + + CUresult res = driver->cuEventCreate(&release->event, CU_EVENT_DISABLE_TIMING); + if (res == CUDA_SUCCESS) + res = driver->cuEventRecord(release->event, stream); + + if (res != CUDA_SUCCESS) { + if (release->event) { + CUresult destroy_res = driver->cuEventDestroy(release->event); + CHECK(destroy_res == CUDA_SUCCESS); + } + + CUresult sync_res = driver->cuStreamSynchronize(stream); + release_dlpack_tensors(release->tensors); + delete release; + + if (sync_res != CUDA_SUCCESS) { + return raise(PyExc_RuntimeError, + "Failed to synchronize stream while releasing DLPack tensor: %s", + get_cuda_error(driver, sync_res)); + } + return OK; + } + + PyThread_acquire_lock(g_deferred_dlpack_lock, WAIT_LOCK); + if (g_deferred_dlpack_tail) + g_deferred_dlpack_tail->next = release; + else + g_deferred_dlpack_head = release; + g_deferred_dlpack_tail = release; + + if (!g_deferred_dlpack_worker_running) { + unsigned long thread_id = PyThread_start_new_thread(deferred_dlpack_worker, nullptr); + if (thread_id == static_cast(-1)) { + g_deferred_dlpack_head = nullptr; + g_deferred_dlpack_tail = nullptr; + PyThread_release_lock(g_deferred_dlpack_lock); + + res = driver->cuEventDestroy(release->event); + CHECK(res == CUDA_SUCCESS); + + CUresult sync_res = driver->cuStreamSynchronize(stream); + release_dlpack_tensors(release->tensors); + delete release; + + if (sync_res != CUDA_SUCCESS) { + return raise(PyExc_RuntimeError, + "Failed to synchronize stream while releasing DLPack tensor: %s", + get_cuda_error(driver, sync_res)); + } + return OK; + } + g_deferred_dlpack_worker_running = true; + } + PyThread_release_lock(g_deferred_dlpack_lock); + return OK; +} + +namespace { struct DeferredDlpackGuard { + Vec* tensors; + + ~DeferredDlpackGuard() { + if (tensors) + release_dlpack_tensors(*tensors); + } + + void dismiss() { + tensors = nullptr; + } +}; } + enum class ParameterKind { Array, Boolean, @@ -846,7 +977,7 @@ static PyPtr parse_array_constraint(ConstantCursor& cursor) { static Result arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper&) { PyPtr dict = steal(PyObject_GetAttr(pyobj, g___cuda_array_interface___pyunicode)); if (!PyDict_Check(dict.get())) { PyErr_SetString(PyExc_TypeError, @@ -924,7 +1055,7 @@ static Result arrayrepr_cuda_array_iface(PyObject* pyobj, unsigned in } static Result arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper& helper) { void* ptr = PyCapsule_GetPointer(dlpack_capsule, "dltensor"); if (!ptr) return ErrorRaised; DLManagedTensor* tensor = static_cast(ptr); @@ -971,16 +1102,7 @@ static Result arrayrepr_dlpack_common(PyObject* dlpack_capsule, unsig }; PyCapsule_SetName(dlpack_capsule, "used_dltensor"); - - // We assume that __dlpack__ returns a view of the tensor, - // so we release the capsule immediately. This should be OK for using with PyTorch - // since it always returns a view. - // - // This is technically an incorrect implementation. To do it correctly, we would - // need to implement a mechanism similar to the one found in Torch's CUDACachingAllocator: - // instead of calling the deleter immediately, we would push a cudaEvent to the stream - // after we launch the kernel, and only call the deleter once the event is ready. - tensor->deleter(tensor); + helper.deferred_dlpack.push_back(tensor); return ret; } @@ -996,7 +1118,7 @@ static Result dtype_from_torch_dtype(PyObject* torch_dtype) { } static Result arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper&) { PyPtr data_ptr = steal(PyObject_CallMethod(tensor, "data_ptr", nullptr)); if (!data_ptr) return ErrorRaised; @@ -1076,21 +1198,21 @@ static Result arrayrepr_torch_tensor_pymethod(PyObject* tensor, unsig } static Result arrayrepr_torch_tensor_dlpack(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper& helper) { PyPtr dlpack_capsule = steal(PyObject_CallFunctionObjArgs( g_torch_to_dlpack_func, pyobj, nullptr)); if (!dlpack_capsule) { SavedException exc = save_raised_exception(); LOG_PYTHON_ERROR("debug", exc, "Fail to convert to dlpack, use fallback path"); - return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena); + return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, arena, helper); } - return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena); + return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena, helper); } static Result arrayrepr_dlpack(PyObject* pyobj, unsigned index_bitwidth, - Arena& arena) { + Arena& arena, LaunchHelper& helper) { PyPtr dlpack_method = steal(PyObject_GetAttr(pyobj, g___dlpack___pyunicode)); if (!dlpack_method) return ErrorRaised; @@ -1109,17 +1231,17 @@ static Result arrayrepr_dlpack(PyObject* pyobj, unsigned index_bitwid dlpack_method.get(), empty_args.get(), kwargs.get())); if (!dlpack_capsule) return ErrorRaised; - return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena); + return arrayrepr_dlpack_common(dlpack_capsule.get(), index_bitwidth, arena, helper); } -typedef Result (*ArrayReprFunc)(PyObject*, unsigned, Arena&); +typedef Result (*ArrayReprFunc)(PyObject*, unsigned, Arena&, LaunchHelper&); template static Status extract_array(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth, LaunchHelper& helper) { - Result ar = F(pyobj, index_bitwidth, helper.arena); + Result ar = F(pyobj, index_bitwidth, helper.arena, helper); if (!ar.is_ok()) return ErrorRaised; size_t num_words = 1 + 2 * ar->arrty.ndim; @@ -1249,14 +1371,21 @@ static PyPtr parse_pyfloat_constraint(ConstantCursor& cursor, bool is_constant) } static Result get_array_repr(PythonArgKind kind, PyObject* pyobj, - unsigned index_bitwidth, Arena& arena) { + unsigned index_bitwidth, bool stream_is_capturing, + LaunchHelper& helper) { switch (kind) { case PythonArgKind::TorchTensorDlpack: - return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, arena); + if (stream_is_capturing) + return arrayrepr_torch_tensor_pymethod(pyobj, index_bitwidth, helper.arena, helper); + return arrayrepr_torch_tensor_dlpack(pyobj, index_bitwidth, helper.arena, helper); case PythonArgKind::DlpackArray: - return arrayrepr_dlpack(pyobj, index_bitwidth, arena); + if (stream_is_capturing) { + return raise(PyExc_RuntimeError, + "DLPack array argument in CUDAGraph isn't supported yet"); + } + return arrayrepr_dlpack(pyobj, index_bitwidth, helper.arena, helper); case PythonArgKind::CudaArray: - return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, arena); + return arrayrepr_cuda_array_iface(pyobj, index_bitwidth, helper.arena, helper); default: return raise(PyExc_AssertionError, "Unexpected argument kind for array: %d", static_cast(kind)); @@ -1264,7 +1393,7 @@ static Result get_array_repr(PythonArgKind kind, PyObject* pyobj, } static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned index_bitwidth, - LaunchHelper& helper) { + bool stream_is_capturing, LaunchHelper& helper) { size_t len = PyList_GET_SIZE(pyobj); if (len > INT32_MAX) return raise(PyExc_TypeError, "List is too long"); @@ -1288,7 +1417,7 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned PyTypeObject* first_item_type = first_item->ob_type; Result first_repr_res = get_array_repr(first_arg_kind, first_item, index_bitwidth, - helper.arena); + stream_is_capturing, helper); if (!first_repr_res.is_ok()) return ErrorRaised; Word* item_pointers = helper.arena.alloc(len); @@ -1322,7 +1451,8 @@ static Status extract_py_list(const DriverApi* driver, PyObject* pyobj, unsigned kind = *res; } - Result repr_res = get_array_repr(kind, item, index_bitwidth, helper.arena); + Result repr_res = get_array_repr(kind, item, index_bitwidth, + stream_is_capturing, helper); if (!repr_res.is_ok()) return ErrorRaised; item_pointers[i].arena_ptr = repr_res->repr; @@ -1372,6 +1502,7 @@ static Status extract_cuda_args(const DriverApi* driver, const Vec& constant_arg_flags, const Vec& int64_index_flags, const Vec& int64_param_flags, + bool stream_is_capturing, LaunchHelper& helper) { CHECK(num_pyargs == arg_kinds.size()); helper.arena.clear(); @@ -1379,6 +1510,7 @@ static Status extract_cuda_args(const DriverApi* driver, helper.list_args.clear(); helper.total_list_data_size_words = 0; helper.constants.clear(); + helper.deferred_dlpack.clear(); for (size_t i = 0; i < num_pyargs; ++i) { PyObject* pyobj = pyargs[i]; bool is_constant = constant_arg_flags[i]; @@ -1388,11 +1520,21 @@ static Status extract_cuda_args(const DriverApi* driver, switch (arg_kinds[i]) { case PythonArgKind::TorchTensorDlpack: - if (!extract_array( - driver, pyobj, index_bitwidth, helper)) + if (stream_is_capturing) { + if (!extract_array( + driver, pyobj, index_bitwidth, helper)) { + return ErrorRaised; + } + } else if (!extract_array( + driver, pyobj, index_bitwidth, helper)) { return ErrorRaised; + } break; case PythonArgKind::DlpackArray: + if (stream_is_capturing) { + return raise(PyExc_RuntimeError, + "DLPack array argument in CUDAGraph isn't supported yet"); + } if (!extract_array(driver, pyobj, index_bitwidth, helper)) return ErrorRaised; break; @@ -1410,7 +1552,10 @@ static Status extract_cuda_args(const DriverApi* driver, if (!extract_py_bool(pyobj, is_constant, helper)) return ErrorRaised; break; case PythonArgKind::PyList: - if (!extract_py_list(driver, pyobj, index_bitwidth, helper)) return ErrorRaised; + if (!extract_py_list(driver, pyobj, index_bitwidth, + stream_is_capturing, helper)) { + return ErrorRaised; + } break; } } @@ -2040,6 +2185,17 @@ struct PreparedLaunch { unsigned dynamic_smem_bytes; }; +static bool needs_stream_capture_status(const Vec& arg_kinds) { + for (PythonArgKind kind : arg_kinds) { + if (kind == PythonArgKind::TorchTensorDlpack + || kind == PythonArgKind::DlpackArray + || kind == PythonArgKind::PyList) { + return true; + } + } + return false; +} + static Result get_stream_context(const DriverApi* driver, CUstream stream) { CUcontext ctx = nullptr; CUresult res = driver->cuStreamGetCtx(stream, &ctx); @@ -2053,6 +2209,16 @@ static Result get_stream_context(const DriverApi* driver, CUstream st return ctx; } +static Result get_stream_capture_status(const DriverApi* driver, CUstream stream) { + CUstreamCaptureStatus status = CU_STREAM_CAPTURE_STATUS_NONE; + CUresult res = driver->cuStreamIsCapturing(stream, &status); + if (res != CUDA_SUCCESS) { + return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s", + get_cuda_error(driver, res)); + } + return status != CU_STREAM_CAPTURE_STATUS_NONE; +} + static Result prepare_launch( const DriverApi* driver, PyObject* dispatcher_pyobj, @@ -2062,6 +2228,7 @@ static Result prepare_launch( StreamBufferTransaction& tx) { LaunchHelperPtr helper = launch_helper_get(); + DeferredDlpackGuard deferred_dlpack{&helper->deferred_dlpack}; Result stream_context = get_stream_context(driver, launch_stream); if (!stream_context.is_ok()) return ErrorRaised; @@ -2107,9 +2274,16 @@ static Result prepare_launch( PythonArgProfile{family_item->value, std::move(arg_kinds)}); } + bool stream_is_capturing = false; + if (needs_stream_capture_status(profile_item->value.arg_kinds)) { + Result capture_status = get_stream_capture_status(driver, launch_stream); + if (!capture_status.is_ok()) return ErrorRaised; + stream_is_capturing = *capture_status; + } + if (!extract_cuda_args(driver, pyargs, num_pyargs, profile_item->value.arg_kinds, dispatcher.constant_arg_flags, dispatcher.int64_index_flags, - dispatcher.int64_param_flags, *helper)) { + dispatcher.int64_param_flags, stream_is_capturing, *helper)) { return ErrorRaised; } @@ -2143,12 +2317,7 @@ static Result prepare_launch( // Handle list arguments if (!helper->list_args.empty()) { if (!tx) { - CUstreamCaptureStatus status; - CUresult res = driver->cuStreamIsCapturing(launch_stream, &status); - if (res != CUDA_SUCCESS) - return raise(PyExc_RuntimeError, "Failed to check stream capturing status: %s", - get_cuda_error(driver, res)); - if (status != CU_STREAM_CAPTURE_STATUS_NONE) + if (stream_is_capturing) return raise(PyExc_RuntimeError, "List argument in CUDAGraph isn't supported yet"); Result pool_res = get_stream_buffer_pool(driver, @@ -2194,6 +2363,7 @@ static Result prepare_launch( if (dyn_smem_size < 0 || dyn_smem_size > UINT_MAX) return raise(PyExc_RuntimeError, "Invalid dynamic shared memory size"); + deferred_dlpack.dismiss(); return PreparedLaunch{std::move(helper), kernel_item->value.cukernel.kernel, static_cast(dyn_smem_size)}; } @@ -2215,6 +2385,7 @@ static Status launch(const DriverApi* driver, Result prep = prepare_launch( driver, dispatcher_pyobj, launch_stream, pyargs, num_pyargs, tx); if (!prep.is_ok()) return ErrorRaised; + DeferredDlpackGuard deferred_dlpack{&prep->helper->deferred_dlpack}; ContextGuard ctx_guard(driver); if (!maybe_switch_context(driver, prep->helper->cuda_context, ctx_guard)) @@ -2247,6 +2418,9 @@ static Status launch(const DriverApi* driver, get_cuda_error(driver, res)); } + if (!defer_dlpack_release(driver, launch_stream, prep->helper->deferred_dlpack)) + return ErrorRaised; + deferred_dlpack.dismiss(); return OK; } @@ -2261,6 +2435,7 @@ static Result benchmark(const DriverApi* driver, Result prep = prepare_launch( driver, dispatcher_pyobj, launch_stream, pyargs, num_pyargs, tx); if (!prep.is_ok()) return ErrorRaised; + DeferredDlpackGuard deferred_dlpack{&prep->helper->deferred_dlpack}; CUcontext ctx = prep->helper->cuda_context; ContextGuard bench_ctx(driver); @@ -2925,6 +3100,9 @@ Status tile_kernel_init(PyObject* m) { INIT_STRING_CONSTANT(pdl); g_stream_buffer_pool_by_ctx_id = new StreamBufferPoolMap(); + g_deferred_dlpack_lock = PyThread_allocate_lock(); + if (!g_deferred_dlpack_lock) + return raise(PyExc_RuntimeError, "Failed to allocate DLPack cleanup lock"); try_get_torch_globals(); @@ -2971,4 +3149,3 @@ Status tile_kernel_init(PyObject* m) { return OK; } - diff --git a/test/test_cudagraph.py b/test/test_cudagraph.py index b39830d..0918186 100644 --- a/test/test_cudagraph.py +++ b/test/test_cudagraph.py @@ -7,6 +7,17 @@ import pytest +class DlpackProxy: + def __init__(self, tensor): + self.tensor = tensor + + def __dlpack__(self, stream=None): + return self.tensor.__dlpack__(stream=stream) + + def __dlpack_device__(self): + return self.tensor.__dlpack_device__() + + @ct.kernel def add_one(x): xi = ct.load(x, 0, ()) @@ -27,6 +38,21 @@ def test_simple(): assert x.item() == 10 +def test_proxy_dlpack(): + x = torch.zeros(1, device='cuda') + ct.launch(torch.cuda.current_stream(), (1,), add_one, (DlpackProxy(x),)) + torch.cuda.synchronize() + assert x.item() == 1 + + +def test_proxy_dlpack_cudagraph(): + x = torch.zeros(1, device='cuda') + graph = torch.cuda.CUDAGraph() + with pytest.raises(RuntimeError, match=r"DLPack array argument in CUDAGraph isn't supported yet"): + with torch.cuda.graph(graph): + ct.launch(torch.cuda.current_stream(), (1,), add_one, (DlpackProxy(x),)) + + @ct.kernel def matmul_accumulate(x, y, z): acc = ct.load(z, (0, 0), (16, 16))