From 1c1334080b5beba8a801f2fd799be73e1b7c3a23 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 14:55:06 -0700 Subject: [PATCH 1/7] rename python bindings for take and put --- .../source/integer_advanced_indexing.cpp | 28 +++++++++---------- .../source/integer_advanced_indexing.hpp | 28 +++++++++---------- dpnp/tensor/libtensor/source/tensor_ctors.cpp | 8 +++--- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index 9d363f8bbca..a097b4450e2 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -251,13 +251,13 @@ std::vector parse_py_ind(const sycl::queue &q, } std::pair - usm_ndarray_take(const dpnp::tensor::usm_ndarray &src, - const py::object &py_ind, - const dpnp::tensor::usm_ndarray &dst, - int axis_start, - std::uint8_t mode, - sycl::queue &exec_q, - const std::vector &depends) + py_take(const dpnp::tensor::usm_ndarray &src, + const py::object &py_ind, + const dpnp::tensor::usm_ndarray &dst, + int axis_start, + std::uint8_t mode, + sycl::queue &exec_q, + const std::vector &depends) { std::vector ind = parse_py_ind(exec_q, py_ind); @@ -521,13 +521,13 @@ std::pair } std::pair - usm_ndarray_put(const dpnp::tensor::usm_ndarray &dst, - const py::object &py_ind, - const dpnp::tensor::usm_ndarray &val, - int axis_start, - std::uint8_t mode, - sycl::queue &exec_q, - const std::vector &depends) + py_put(const dpnp::tensor::usm_ndarray &dst, + const py::object &py_ind, + const dpnp::tensor::usm_ndarray &val, + int axis_start, + std::uint8_t mode, + sycl::queue &exec_q, + const std::vector &depends) { std::vector ind = parse_py_ind(exec_q, py_ind); int k = ind.size(); diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp index 52d9542e501..cd1a3733f3f 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp @@ -49,22 +49,22 @@ namespace dpnp::tensor::py_internal { extern std::pair - usm_ndarray_take(const dpnp::tensor::usm_ndarray &, - const py::object &, - const dpnp::tensor::usm_ndarray &, - int, - std::uint8_t, - sycl::queue &, - const std::vector & = {}); + py_take(const dpnp::tensor::usm_ndarray &, + const py::object &, + const dpnp::tensor::usm_ndarray &, + int, + std::uint8_t, + sycl::queue &, + const std::vector & = {}); extern std::pair - usm_ndarray_put(const dpnp::tensor::usm_ndarray &, - const py::object &, - const dpnp::tensor::usm_ndarray &, - int, - std::uint8_t, - sycl::queue &, - const std::vector & = {}); + py_put(const dpnp::tensor::usm_ndarray &, + const py::object &, + const dpnp::tensor::usm_ndarray &, + int, + std::uint8_t, + sycl::queue &, + const std::vector & = {}); extern void init_advanced_indexing_dispatch_tables(void); diff --git a/dpnp/tensor/libtensor/source/tensor_ctors.cpp b/dpnp/tensor/libtensor/source/tensor_ctors.cpp index fff4c2174d6..b18080f4721 100644 --- a/dpnp/tensor/libtensor/source/tensor_ctors.cpp +++ b/dpnp/tensor/libtensor/source/tensor_ctors.cpp @@ -106,8 +106,8 @@ using dpnp::tensor::py_internal::usm_ndarray_full; using dpnp::tensor::py_internal::usm_ndarray_zeros; /* ============== Advanced Indexing ============= */ -using dpnp::tensor::py_internal::usm_ndarray_put; -using dpnp::tensor::py_internal::usm_ndarray_take; +using dpnp::tensor::py_internal::py_put; +using dpnp::tensor::py_internal::py_take; using dpnp::tensor::py_internal::py_extract; using dpnp::tensor::py_internal::py_mask_positions; @@ -329,7 +329,7 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("fill_value"), py::arg("dst"), py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_take", &usm_ndarray_take, + m.def("_take", &py_take, "Takes elements at usm_ndarray indices `ind` and axes starting " "at axis `axis_start` from array `src` and copies them " "into usm_ndarray `dst` synchronously." @@ -338,7 +338,7 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("mode"), py::arg("sycl_queue"), py::arg("depends") = py::list()); - m.def("_put", &usm_ndarray_put, + m.def("_put", &py_put, "Puts elements at usm_ndarray indices `ind` and axes starting " "at axis `axis_start` into array `dst` from " "usm_ndarray `val` synchronously." From 7d49a74443aa202316503eb9319e9cbc0b2ebe3e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 15:00:20 -0700 Subject: [PATCH 2/7] change indexing mode dispatching --- .../source/integer_advanced_indexing.cpp | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index a097b4450e2..6ac914bf7c6 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -57,10 +57,6 @@ #include "integer_advanced_indexing.hpp" -#define INDEXING_MODES 2 -#define WRAP_MODE 0 -#define CLIP_MODE 1 - namespace dpnp::tensor::py_internal { @@ -69,11 +65,15 @@ namespace td_ns = dpnp::tensor::type_dispatch; using dpnp::tensor::kernels::indexing::put_fn_ptr_t; using dpnp::tensor::kernels::indexing::take_fn_ptr_t; -static take_fn_ptr_t take_dispatch_table[INDEXING_MODES][td_ns::num_types] - [td_ns::num_types]; +static take_fn_ptr_t take_wrap_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static take_fn_ptr_t take_clip_dispatch_table[td_ns::num_types] + [td_ns::num_types]; + +static put_fn_ptr_t put_wrap_dispatch_table[td_ns::num_types][td_ns::num_types]; -static put_fn_ptr_t put_dispatch_table[INDEXING_MODES][td_ns::num_types] - [td_ns::num_types]; +static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types]; namespace py = pybind11; @@ -492,7 +492,8 @@ std::pair std::end(pack_deps)); all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); - auto fn = take_dispatch_table[mode][src_type_id][ind_type_id]; + auto fn = mode ? take_clip_dispatch_table[src_type_id][ind_type_id] + : take_wrap_dispatch_table[src_type_id][ind_type_id]; if (fn == nullptr) { sycl::event::wait(host_task_events); @@ -760,7 +761,8 @@ std::pair std::end(pack_deps)); all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends)); - auto fn = put_dispatch_table[mode][dst_type_id][ind_type_id]; + auto fn = mode ? put_clip_dispatch_table[dst_type_id][ind_type_id] + : put_wrap_dispatch_table[dst_type_id][ind_type_id]; if (fn == nullptr) { sycl::event::wait(host_task_events); @@ -795,20 +797,20 @@ void init_advanced_indexing_dispatch_tables(void) using dpnp::tensor::kernels::indexing::TakeClipFactory; DispatchTableBuilder dtb_takeclip; - dtb_takeclip.populate_dispatch_table(take_dispatch_table[CLIP_MODE]); + dtb_takeclip.populate_dispatch_table(take_clip_dispatch_table); using dpnp::tensor::kernels::indexing::TakeWrapFactory; DispatchTableBuilder dtb_takewrap; - dtb_takewrap.populate_dispatch_table(take_dispatch_table[WRAP_MODE]); + dtb_takewrap.populate_dispatch_table(take_wrap_dispatch_table); using dpnp::tensor::kernels::indexing::PutClipFactory; DispatchTableBuilder dtb_putclip; - dtb_putclip.populate_dispatch_table(put_dispatch_table[CLIP_MODE]); + dtb_putclip.populate_dispatch_table(put_clip_dispatch_table); using dpnp::tensor::kernels::indexing::PutWrapFactory; DispatchTableBuilder dtb_putwrap; - dtb_putwrap.populate_dispatch_table(put_dispatch_table[WRAP_MODE]); + dtb_putwrap.populate_dispatch_table(put_wrap_dispatch_table); } } // namespace dpnp::tensor::py_internal From 648d1abef1b07f87b82019159b2ff04a7f4c2965 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 16:42:51 -0700 Subject: [PATCH 3/7] initial refactor of indexing python bindings --- .../source/integer_advanced_indexing.cpp | 289 ++++++++++-------- 1 file changed, 156 insertions(+), 133 deletions(-) diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index 6ac914bf7c6..322e58e1e27 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -77,7 +77,147 @@ static put_fn_ptr_t put_clip_dispatch_table[td_ns::num_types][td_ns::num_types]; namespace py = pybind11; -using dpnp::utils::keep_args_alive; +namespace detail +{ + +void copy_axis_shape_strides(int axis_start, + int inp_nd, + int k, + int ind_nd, + const py::ssize_t *inp_shape, + const std::vector &inp_strides, + const py::ssize_t *arr_shape, + const std::vector &arr_strides, + py::ssize_t *host_along_sh_st) +{ + if (inp_nd > 0) { + std::copy(inp_shape + axis_start, inp_shape + axis_start + k, + host_along_sh_st); + std::copy(inp_strides.begin() + axis_start, + inp_strides.begin() + axis_start + k, host_along_sh_st + k); + } + + if (ind_nd > 0) { + std::copy(arr_shape + axis_start, arr_shape + axis_start + ind_nd, + host_along_sh_st + 2 * k); + std::copy(arr_strides.begin() + axis_start, + arr_strides.begin() + axis_start + ind_nd, + host_along_sh_st + 2 * k + ind_nd); + } +} + +void copy_orthog_shape_strides(int axis_start, + int inp_nd, + int k, + int ind_nd, + int orthog_sh_elems, + const py::ssize_t *inp_shape, + const std::vector &inp_strides, + const std::vector &arr_strides, + py::ssize_t *host_orthog_sh_st) +{ + int orthog_nd = inp_nd - k; + if (orthog_nd == 0) { + return; + } + + if (axis_start > 0) { + std::copy(inp_shape, inp_shape + axis_start, host_orthog_sh_st); + std::copy(inp_strides.begin(), inp_strides.begin() + axis_start, + host_orthog_sh_st + orthog_sh_elems); + std::copy(arr_strides.begin(), arr_strides.begin() + axis_start, + host_orthog_sh_st + 2 * orthog_sh_elems); + } + if (inp_nd > (axis_start + k)) { + std::copy(inp_shape + axis_start + k, inp_shape + inp_nd, + host_orthog_sh_st + axis_start); + std::copy(inp_strides.begin() + axis_start + k, inp_strides.end(), + host_orthog_sh_st + orthog_sh_elems + axis_start); + std::copy(arr_strides.begin() + axis_start + ind_nd, arr_strides.end(), + host_orthog_sh_st + 2 * orthog_sh_elems + axis_start); + } +} + +void validate_index_array(const dpnp::tensor::usm_ndarray &ind_, + const sycl::queue &exec_q, + int ind_nd, + int ind_type_id, + const py::ssize_t *ind_shape, + const td_ns::usm_ndarray_types &array_types, + const dpnp::tensor::overlap::MemoryOverlap &overlap, + const dpnp::tensor::usm_ndarray &other_array) +{ + if (!dpnp::utils::queues_are_compatible(exec_q, {ind_})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + + if (ind_.get_ndim() != ind_nd) { + throw py::value_error("Index dimensions are not the same"); + } + + if (ind_type_id != array_types.typenum_to_lookup_id(ind_.get_typenum())) { + throw py::type_error("Indices array data types are not all the same."); + } + + const py::ssize_t *ind_shape_ = ind_.get_shape_raw(); + for (int dim = 0; dim < ind_nd; ++dim) { + if (ind_shape[dim] != ind_shape_[dim]) { + throw py::value_error("Indices shapes are not all equal."); + } + } + + if (overlap(ind_, other_array)) { + throw py::value_error("Arrays index overlapping segments of memory"); + } +} + +void process_index_arrays(const std::vector &ind, + sycl::queue &exec_q, + int k, + int ind_nd, + int ind_sh_elems, + const py::ssize_t *ind_shape, + int ind_type_id, + const td_ns::usm_ndarray_types &array_types, + const dpnp::tensor::overlap::MemoryOverlap &overlap, + const dpnp::tensor::usm_ndarray &other_array, + std::vector &ind_ptrs, + std::vector &ind_offsets, + std::vector &ind_sh_sts) +{ + for (int i = 0; i < k; ++i) { + const dpnp::tensor::usm_ndarray &ind_ = ind[i]; + + if (i > 0) { + validate_index_array(ind_, exec_q, ind_nd, ind_type_id, ind_shape, + array_types, overlap, other_array); + } + else { + if (!dpnp::utils::queues_are_compatible(exec_q, {ind_})) { + throw py::value_error( + "Execution queue is not compatible with allocation queues"); + } + if (overlap(ind_, other_array)) { + throw py::value_error( + "Arrays index overlapping segments of memory"); + } + } + + char *ind_data = ind_.get_data(); + + if (ind_nd > 0) { + auto ind_strides = ind_.get_strides_vector(); + std::copy(ind_strides.begin(), ind_strides.end(), + ind_sh_sts.begin() + (i + 1) * ind_nd); + } + + ind_ptrs.push_back(ind_data); + ind_offsets.push_back(py::ssize_t(0)); + } +} + +} // namespace detail std::vector _populate_kernel_params(sycl::queue &exec_q, @@ -101,7 +241,6 @@ std::vector int orthog_sh_elems, int ind_sh_elems) { - using usm_host_allocator_T = dpnp::tensor::alloc_utils::usm_host_allocator; using ptrT = std::vector; @@ -144,47 +283,13 @@ std::vector host_ind_offsets_shp->data(), device_ind_offsets, host_ind_offsets_shp->size()); - int orthog_nd = inp_nd - k; - - if (orthog_nd > 0) { - if (axis_start > 0) { - std::copy(inp_shape, inp_shape + axis_start, - host_orthog_sh_st_shp->begin()); - std::copy(inp_strides.begin(), inp_strides.begin() + axis_start, - host_orthog_sh_st_shp->begin() + orthog_sh_elems); - std::copy(arr_strides.begin(), arr_strides.begin() + axis_start, - host_orthog_sh_st_shp->begin() + 2 * orthog_sh_elems); - } - if (inp_nd > (axis_start + k)) { - std::copy(inp_shape + axis_start + k, inp_shape + inp_nd, - host_orthog_sh_st_shp->begin() + axis_start); - std::copy(inp_strides.begin() + axis_start + k, inp_strides.end(), - host_orthog_sh_st_shp->begin() + orthog_sh_elems + - axis_start); - - std::copy(arr_strides.begin() + axis_start + ind_nd, - arr_strides.end(), - host_orthog_sh_st_shp->begin() + 2 * orthog_sh_elems + - axis_start); - } - } + detail::copy_orthog_shape_strides( + axis_start, inp_nd, k, ind_nd, orthog_sh_elems, inp_shape, inp_strides, + arr_strides, host_orthog_sh_st_shp->data()); - if (inp_nd > 0) { - std::copy(inp_shape + axis_start, inp_shape + axis_start + k, - host_along_sh_st_shp->begin()); - - std::copy(inp_strides.begin() + axis_start, - inp_strides.begin() + axis_start + k, - host_along_sh_st_shp->begin() + k); - } - - if (ind_nd > 0) { - std::copy(arr_shape + axis_start, arr_shape + axis_start + ind_nd, - host_along_sh_st_shp->begin() + 2 * k); - std::copy(arr_strides.begin() + axis_start, - arr_strides.begin() + axis_start + ind_nd, - host_along_sh_st_shp->begin() + 2 * k + ind_nd); - } + detail::copy_axis_shape_strides(axis_start, inp_nd, k, ind_nd, inp_shape, + inp_strides, arr_shape, arr_strides, + host_along_sh_st_shp->data()); const sycl::event &device_orthog_sh_st_copy_ev = exec_q.copy( host_orthog_sh_st_shp->data(), device_orthog_sh_st, @@ -382,52 +487,10 @@ std::pair if (ind_nd > 0) { std::copy(ind_shape, ind_shape + ind_nd, ind_sh_sts.begin()); } - for (int i = 0; i < k; ++i) { - dpnp::tensor::usm_ndarray ind_ = ind[i]; - - if (!dpnp::utils::queues_are_compatible(exec_q, {ind_})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - // ndim, type, and shape are checked against the first array - if (i > 0) { - if (!(ind_.get_ndim() == ind_nd)) { - throw py::value_error("Index dimensions are not the same"); - } - - if (!(ind_type_id == - array_types.typenum_to_lookup_id(ind_.get_typenum()))) { - throw py::type_error( - "Indices array data types are not all the same."); - } - - const py::ssize_t *ind_shape_ = ind_.get_shape_raw(); - for (int dim = 0; dim < ind_nd; ++dim) { - if (!(ind_shape[dim] == ind_shape_[dim])) { - throw py::value_error("Indices shapes are not all equal."); - } - } - } - - // check for overlap with destination - if (overlap(dst, ind_)) { - throw py::value_error( - "Arrays index overlapping segments of memory"); - } - - char *ind_data = ind_.get_data(); - - // strides are initialized to 0 for 0D indices, so skip here - if (ind_nd > 0) { - auto ind_strides = ind_.get_strides_vector(); - std::copy(ind_strides.begin(), ind_strides.end(), - ind_sh_sts.begin() + (i + 1) * ind_nd); - } - - ind_ptrs.push_back(ind_data); - ind_offsets.push_back(py::ssize_t(0)); - } + detail::process_index_arrays(ind, exec_q, k, ind_nd, ind_sh_elems, + ind_shape, ind_type_id, array_types, overlap, + dst, ind_ptrs, ind_offsets, ind_sh_sts); if (ind_nelems == 0) { return std::make_pair(sycl::event{}, sycl::event{}); @@ -515,6 +578,7 @@ std::pair packed_ind_ptrs_owner, packed_ind_offsets_owner); host_task_events.push_back(temporaries_cleanup_ev); + using dpnp::utils::keep_args_alive; sycl::event arg_cleanup_ev = keep_args_alive(exec_q, {src, py_ind, dst}, host_task_events); @@ -650,54 +714,12 @@ std::pair ind_offsets.reserve(k); std::vector ind_sh_sts((k + 1) * ind_sh_elems, py::ssize_t(0)); if (ind_nd > 0) { - std::copy(ind_shape, ind_shape + ind_sh_elems, ind_sh_sts.begin()); + std::copy(ind_shape, ind_shape + ind_nd, ind_sh_sts.begin()); } - for (int i = 0; i < k; ++i) { - dpnp::tensor::usm_ndarray ind_ = ind[i]; - if (!dpnp::utils::queues_are_compatible(exec_q, {ind_})) { - throw py::value_error( - "Execution queue is not compatible with allocation queues"); - } - - // ndim, type, and shape are checked against the first array - if (i > 0) { - if (!(ind_.get_ndim() == ind_nd)) { - throw py::value_error("Index dimensions are not the same"); - } - - if (!(ind_type_id == - array_types.typenum_to_lookup_id(ind_.get_typenum()))) { - throw py::type_error( - "Indices array data types are not all the same."); - } - - const py::ssize_t *ind_shape_ = ind_.get_shape_raw(); - for (int dim = 0; dim < ind_nd; ++dim) { - if (!(ind_shape[dim] == ind_shape_[dim])) { - throw py::value_error("Indices shapes are not all equal."); - } - } - } - - // check for overlap with destination - if (overlap(ind_, dst)) { - throw py::value_error( - "Arrays index overlapping segments of memory"); - } - - char *ind_data = ind_.get_data(); - - // strides are initialized to 0 for 0D indices, so skip here - if (ind_nd > 0) { - auto ind_strides = ind_.get_strides_vector(); - std::copy(ind_strides.begin(), ind_strides.end(), - ind_sh_sts.begin() + (i + 1) * ind_nd); - } - - ind_ptrs.push_back(ind_data); - ind_offsets.push_back(py::ssize_t(0)); - } + detail::process_index_arrays(ind, exec_q, k, ind_nd, ind_sh_elems, + ind_shape, ind_type_id, array_types, overlap, + dst, ind_ptrs, ind_offsets, ind_sh_sts); if (ind_nelems == 0) { return std::make_pair(sycl::event{}, sycl::event{}); @@ -784,6 +806,7 @@ std::pair packed_ind_ptrs_owner, packed_ind_offsets_owner); host_task_events.push_back(temporaries_cleanup_ev); + using dpnp::utils::keep_args_alive; sycl::event arg_cleanup_ev = keep_args_alive(exec_q, {dst, py_ind, val}, host_task_events); From 6425d9aa43d020e050ce573136aceb1b3d5be698 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 17:19:50 -0700 Subject: [PATCH 4/7] further refactor indexing kernel param initialization --- .../source/integer_advanced_indexing.cpp | 122 +++++++++++------- 1 file changed, 72 insertions(+), 50 deletions(-) diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index 322e58e1e27..52e520d0f37 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -80,6 +80,49 @@ namespace py = pybind11; namespace detail { +using host_ptr_allocator_t = + dpnp::tensor::alloc_utils::usm_host_allocator; +using host_ptr_vec_t = std::vector; +using host_ptr_shp_t = std::shared_ptr; + +using host_sz_allocator_t = + dpnp::tensor::alloc_utils::usm_host_allocator; +using host_sz_vec_t = std::vector; +using host_sz_shp_t = std::shared_ptr; + +template +sycl::event copy_to_device(sycl::queue &exec_q, + const T *host_data, + T *device_data, + std::size_t count) +{ + return exec_q.copy(host_data, device_data, count); +} + +host_ptr_shp_t allocate_and_copy_ptrs(sycl::queue &exec_q, + const std::vector &ptrs) +{ + host_ptr_allocator_t allocator(exec_q); + auto host_shp = std::make_shared(ptrs.size(), allocator); + std::copy(ptrs.begin(), ptrs.end(), host_shp->begin()); + return host_shp; +} + +host_sz_shp_t allocate_and_copy_sizes(sycl::queue &exec_q, + const std::vector &sizes) +{ + host_sz_allocator_t allocator(exec_q); + auto host_shp = std::make_shared(sizes.size(), allocator); + std::copy(sizes.begin(), sizes.end(), host_shp->begin()); + return host_shp; +} + +host_sz_shp_t allocate_host_buffer(sycl::queue &exec_q, std::size_t size) +{ + host_sz_allocator_t allocator(exec_q); + return std::make_shared(size, allocator); +} + void copy_axis_shape_strides(int axis_start, int inp_nd, int k, @@ -241,47 +284,16 @@ std::vector int orthog_sh_elems, int ind_sh_elems) { - using usm_host_allocator_T = - dpnp::tensor::alloc_utils::usm_host_allocator; - using ptrT = std::vector; - - usm_host_allocator_T ptr_allocator(exec_q); - std::shared_ptr host_ind_ptrs_shp = - std::make_shared(k, ptr_allocator); + auto host_ind_ptrs_shp = detail::allocate_and_copy_ptrs(exec_q, ind_ptrs); + auto host_ind_sh_st_shp = + detail::allocate_and_copy_sizes(exec_q, ind_sh_sts); + auto host_ind_offsets_shp = + detail::allocate_and_copy_sizes(exec_q, ind_offsets); - using usm_host_allocatorT = - dpnp::tensor::alloc_utils::usm_host_allocator; - using shT = std::vector; - - usm_host_allocatorT sz_allocator(exec_q); - std::shared_ptr host_ind_sh_st_shp = - std::make_shared(ind_sh_elems * (k + 1), sz_allocator); - - std::shared_ptr host_ind_offsets_shp = - std::make_shared(k, sz_allocator); - - std::shared_ptr host_orthog_sh_st_shp = - std::make_shared(3 * orthog_sh_elems, sz_allocator); - - std::shared_ptr host_along_sh_st_shp = - std::make_shared(2 * (k + ind_sh_elems), sz_allocator); - - std::copy(ind_sh_sts.begin(), ind_sh_sts.end(), - host_ind_sh_st_shp->begin()); - std::copy(ind_ptrs.begin(), ind_ptrs.end(), host_ind_ptrs_shp->begin()); - std::copy(ind_offsets.begin(), ind_offsets.end(), - host_ind_offsets_shp->begin()); - - const sycl::event &device_ind_ptrs_copy_ev = exec_q.copy( - host_ind_ptrs_shp->data(), device_ind_ptrs, host_ind_ptrs_shp->size()); - - const sycl::event &device_ind_sh_st_copy_ev = - exec_q.copy(host_ind_sh_st_shp->data(), device_ind_sh_st, - host_ind_sh_st_shp->size()); - - const sycl::event &device_ind_offsets_copy_ev = exec_q.copy( - host_ind_offsets_shp->data(), device_ind_offsets, - host_ind_offsets_shp->size()); + auto host_orthog_sh_st_shp = + detail::allocate_host_buffer(exec_q, 3 * orthog_sh_elems); + auto host_along_sh_st_shp = + detail::allocate_host_buffer(exec_q, 2 * (k + ind_sh_elems)); detail::copy_orthog_shape_strides( axis_start, inp_nd, k, ind_nd, orthog_sh_elems, inp_shape, inp_strides, @@ -291,15 +303,27 @@ std::vector inp_strides, arr_shape, arr_strides, host_along_sh_st_shp->data()); - const sycl::event &device_orthog_sh_st_copy_ev = exec_q.copy( - host_orthog_sh_st_shp->data(), device_orthog_sh_st, + const sycl::event device_ind_ptrs_copy_ev = + detail::copy_to_device(exec_q, host_ind_ptrs_shp->data(), + device_ind_ptrs, host_ind_ptrs_shp->size()); + + const sycl::event device_ind_sh_st_copy_ev = + detail::copy_to_device(exec_q, host_ind_sh_st_shp->data(), + device_ind_sh_st, host_ind_sh_st_shp->size()); + + const sycl::event device_ind_offsets_copy_ev = detail::copy_to_device( + exec_q, host_ind_offsets_shp->data(), device_ind_offsets, + host_ind_offsets_shp->size()); + + const sycl::event device_orthog_sh_st_copy_ev = detail::copy_to_device( + exec_q, host_orthog_sh_st_shp->data(), device_orthog_sh_st, host_orthog_sh_st_shp->size()); - const sycl::event &device_along_sh_st_copy_ev = exec_q.copy( - host_along_sh_st_shp->data(), device_along_sh_st, + const sycl::event device_along_sh_st_copy_ev = detail::copy_to_device( + exec_q, host_along_sh_st_shp->data(), device_along_sh_st, host_along_sh_st_shp->size()); - const sycl::event &shared_ptr_cleanup_ev = + const sycl::event shared_ptr_cleanup_ev = exec_q.submit([&](sycl::handler &cgh) { cgh.depends_on({device_along_sh_st_copy_ev, device_orthog_sh_st_copy_ev, @@ -314,11 +338,9 @@ std::vector }); host_task_events.push_back(shared_ptr_cleanup_ev); - std::vector sh_st_pack_deps{ - device_ind_ptrs_copy_ev, device_ind_sh_st_copy_ev, - device_ind_offsets_copy_ev, device_orthog_sh_st_copy_ev, - device_along_sh_st_copy_ev}; - return sh_st_pack_deps; + return {device_ind_ptrs_copy_ev, device_ind_sh_st_copy_ev, + device_ind_offsets_copy_ev, device_orthog_sh_st_copy_ev, + device_along_sh_st_copy_ev}; } /* Utility to parse python object py_ind into vector of `usm_ndarray`s */ From 149c9beeaaa92ddc95d66f435310a4aea8d305fa Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 17:28:33 -0700 Subject: [PATCH 5/7] add axis_end parameter to indexing functions simplifies logic --- dpnp/tensor/_copy_utils.py | 2 + dpnp/tensor/_indexing_functions.py | 18 ++++++++- dpnp/tensor/_searchsorted.py | 1 + dpnp/tensor/_set_functions.py | 2 + .../source/integer_advanced_indexing.cpp | 39 ++++++++++++++----- .../source/integer_advanced_indexing.hpp | 2 + dpnp/tensor/libtensor/source/tensor_ctors.cpp | 12 +++--- 7 files changed, 59 insertions(+), 17 deletions(-) diff --git a/dpnp/tensor/_copy_utils.py b/dpnp/tensor/_copy_utils.py index 3978e7345b1..39c4e9cf9fe 100644 --- a/dpnp/tensor/_copy_utils.py +++ b/dpnp/tensor/_copy_utils.py @@ -475,6 +475,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0): ind=inds, val=rhs, axis_start=p, + axis_end=p_end, mode=mode, sycl_queue=exec_q, depends=dep_ev, @@ -527,6 +528,7 @@ def _take_multi_index(ary, inds, p, mode=0): ind=inds, dst=res, axis_start=p, + axis_end=p_end, mode=mode, sycl_queue=exec_q, depends=dep_ev, diff --git a/dpnp/tensor/_indexing_functions.py b/dpnp/tensor/_indexing_functions.py index 9ea0a16bdd0..675e770e411 100644 --- a/dpnp/tensor/_indexing_functions.py +++ b/dpnp/tensor/_indexing_functions.py @@ -342,7 +342,14 @@ def put_vec_duplicates(vec, ind, vals): _manager = SequentialOrderManager[exec_q] deps_ev = _manager.submitted_events hev, put_ev = ti._put( - x, (indices,), rhs, axis, mode, sycl_queue=exec_q, depends=deps_ev + x, + (indices,), + rhs, + axis, + axis + 1, + mode, + sycl_queue=exec_q, + depends=deps_ev, ) _manager.add_event_pair(hev, put_ev) @@ -543,7 +550,14 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): _manager = SequentialOrderManager[exec_q] deps_ev = _manager.submitted_events hev, take_ev = ti._take( - x, (indices,), out, axis, mode, sycl_queue=exec_q, depends=deps_ev + x, + (indices,), + out, + axis, + axis + 1, + mode, + sycl_queue=exec_q, + depends=deps_ev, ) _manager.add_event_pair(hev, take_ev) diff --git a/dpnp/tensor/_searchsorted.py b/dpnp/tensor/_searchsorted.py index 6d3f8846012..5776e908ec6 100644 --- a/dpnp/tensor/_searchsorted.py +++ b/dpnp/tensor/_searchsorted.py @@ -160,6 +160,7 @@ def searchsorted( ind, res, axis, + axis + 1, wrap_out_of_bound_indices_mode, sycl_queue=q, depends=dep_evs, diff --git a/dpnp/tensor/_set_functions.py b/dpnp/tensor/_set_functions.py index 067de75c42c..e2e5851f131 100644 --- a/dpnp/tensor/_set_functions.py +++ b/dpnp/tensor/_set_functions.py @@ -383,6 +383,7 @@ def unique_inverse(x): ind=(sorting_ids,), dst=s, axis_start=0, + axis_end=1, mode=0, sycl_queue=exec_q, depends=[sort_ev], @@ -558,6 +559,7 @@ def unique_all(x: dpt.usm_ndarray) -> UniqueAllResult: ind=(sorting_ids,), dst=s, axis_start=0, + axis_end=1, mode=0, sycl_queue=exec_q, depends=[sort_ev], diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index 52e520d0f37..aa21222bc80 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -382,6 +382,7 @@ std::pair const py::object &py_ind, const dpnp::tensor::usm_ndarray &dst, int axis_start, + int axis_end, std::uint8_t mode, sycl::queue &exec_q, const std::vector &depends) @@ -395,7 +396,16 @@ std::pair } if (axis_start < 0) { - throw py::value_error("Axis cannot be negative."); + throw py::value_error("Axis start cannot be negative."); + } + + if (axis_end < axis_start) { + throw py::value_error( + "Axis end must be greater than or equal to axis start."); + } + + if (k != (axis_end - axis_start)) { + throw py::value_error("Number of indices must match axis range."); } if (mode != 0 && mode != 1) { @@ -412,9 +422,10 @@ std::pair auto sh_elems = std::max(src_nd, 1); - if (axis_start + k > sh_elems) { - throw py::value_error("Axes are out of range for array of dimension " + - std::to_string(src_nd)); + if (axis_end > sh_elems) { + throw py::value_error( + "Axis end is out of range for array of dimension " + + std::to_string(src_nd)); } if (src_nd == 0) { if (dst_nd != ind_nd) { @@ -612,6 +623,7 @@ std::pair const py::object &py_ind, const dpnp::tensor::usm_ndarray &val, int axis_start, + int axis_end, std::uint8_t mode, sycl::queue &exec_q, const std::vector &depends) @@ -620,12 +632,20 @@ std::pair int k = ind.size(); if (k == 0) { - // no indices to write to throw py::value_error("List of indices is empty."); } if (axis_start < 0) { - throw py::value_error("Axis cannot be negative."); + throw py::value_error("Axis start cannot be negative."); + } + + if (axis_end < axis_start) { + throw py::value_error( + "Axis end must be greater than or equal to axis start."); + } + + if (k != (axis_end - axis_start)) { + throw py::value_error("Number of indices must match axis range."); } if (mode != 0 && mode != 1) { @@ -642,9 +662,10 @@ std::pair auto sh_elems = std::max(dst_nd, 1); - if (axis_start + k > sh_elems) { - throw py::value_error("Axes are out of range for array of dimension " + - std::to_string(dst_nd)); + if (axis_end > sh_elems) { + throw py::value_error( + "Axis end is out of range for array of dimension " + + std::to_string(dst_nd)); } if (dst_nd == 0) { if (val_nd != ind_nd) { diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp index cd1a3733f3f..420a6ac277f 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.hpp @@ -53,6 +53,7 @@ extern std::pair const py::object &, const dpnp::tensor::usm_ndarray &, int, + int, std::uint8_t, sycl::queue &, const std::vector & = {}); @@ -62,6 +63,7 @@ extern std::pair const py::object &, const dpnp::tensor::usm_ndarray &, int, + int, std::uint8_t, sycl::queue &, const std::vector & = {}); diff --git a/dpnp/tensor/libtensor/source/tensor_ctors.cpp b/dpnp/tensor/libtensor/source/tensor_ctors.cpp index b18080f4721..f9cf43e2016 100644 --- a/dpnp/tensor/libtensor/source/tensor_ctors.cpp +++ b/dpnp/tensor/libtensor/source/tensor_ctors.cpp @@ -330,21 +330,21 @@ PYBIND11_MODULE(_tensor_impl, m) py::arg("depends") = py::list()); m.def("_take", &py_take, - "Takes elements at usm_ndarray indices `ind` and axes starting " - "at axis `axis_start` from array `src` and copies them " + "Takes elements at usm_ndarray indices `ind` from axes " + "[axis_start, axis_end) of array `src` and copies them " "into usm_ndarray `dst` synchronously." "Returns a tuple of events: (hev, ev)", py::arg("src"), py::arg("ind"), py::arg("dst"), py::arg("axis_start"), - py::arg("mode"), py::arg("sycl_queue"), + py::arg("axis_end"), py::arg("mode"), py::arg("sycl_queue"), py::arg("depends") = py::list()); m.def("_put", &py_put, - "Puts elements at usm_ndarray indices `ind` and axes starting " - "at axis `axis_start` into array `dst` from " + "Puts elements at usm_ndarray indices `ind` into axes " + "[axis_start, axis_end) of array `dst` from " "usm_ndarray `val` synchronously." "Returns a tuple of events: (hev, ev)", py::arg("dst"), py::arg("ind"), py::arg("val"), py::arg("axis_start"), - py::arg("mode"), py::arg("sycl_queue"), + py::arg("axis_end"), py::arg("mode"), py::arg("sycl_queue"), py::arg("depends") = py::list()); m.def("_eye", &usm_ndarray_eye, From f058c91646c45a66783ae90c41b2bb45e2000dc1 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 18:54:06 -0700 Subject: [PATCH 6/7] further refactoring of integer indexing --- .../source/integer_advanced_indexing.cpp | 335 ++++++++---------- 1 file changed, 140 insertions(+), 195 deletions(-) diff --git a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp index aa21222bc80..ea1f102347d 100644 --- a/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp +++ b/dpnp/tensor/libtensor/source/integer_advanced_indexing.cpp @@ -260,6 +260,84 @@ void process_index_arrays(const std::vector &ind, } } +void validate_axis_range(int axis_start, int axis_end, int k, int array_nd) +{ + if (axis_start < 0) { + throw py::value_error("Axis start cannot be negative."); + } + + if (axis_end < axis_start) { + throw py::value_error( + "Axis end must be greater than or equal to axis start."); + } + + if (k != (axis_end - axis_start)) { + throw py::value_error("Number of indices must match axis range."); + } + + int sh_elems = std::max(array_nd, 1); + if (axis_end > sh_elems) { + throw py::value_error( + "Axis end is out of range for array of dimension " + + std::to_string(array_nd)); + } +} + +void validate_output_shape(int inp_nd, int out_nd, int k, int ind_nd) +{ + int expected_out_nd = (inp_nd == 0) ? ind_nd : (inp_nd - k + ind_nd); + + if (out_nd != expected_out_nd) { + throw py::value_error( + "Output array has incorrect number of dimensions. " + "Expected " + + std::to_string(expected_out_nd) + ", got " + + std::to_string(out_nd)); + } +} + +std::size_t validate_and_compute_orthog_shape(int inp_nd, + int k, + int ind_nd, + int axis_start, + const py::ssize_t *inp_shape, + const py::ssize_t *out_shape) +{ + int orthog_nd = inp_nd - k; + std::size_t orthog_nelems = 1; + + for (int i = 0; i < orthog_nd; ++i) { + int inp_idx = (i < axis_start) ? i : i + k; + int out_idx = (i < axis_start) ? i : i + ind_nd; + + if (inp_shape[inp_idx] != out_shape[out_idx]) { + throw py::value_error("Orthogonal axes have mismatched shapes."); + } + orthog_nelems *= static_cast(inp_shape[inp_idx]); + } + + return orthog_nelems; +} + +std::size_t validate_and_compute_ind_nelems(int ind_nd, + int axis_start, + const py::ssize_t *ind_shape, + const py::ssize_t *out_shape) +{ + std::size_t ind_nelems = 1; + + for (int i = 0; i < ind_nd; ++i) { + ind_nelems *= static_cast(ind_shape[i]); + + if (ind_shape[i] != out_shape[axis_start + i]) { + throw py::value_error("Indices shape does not match output shape " + "along indexed axes."); + } + } + + return ind_nelems; +} + } // namespace detail std::vector @@ -390,82 +468,43 @@ std::pair std::vector ind = parse_py_ind(exec_q, py_ind); int k = ind.size(); - if (k == 0) { throw py::value_error("List of indices is empty."); } - if (axis_start < 0) { - throw py::value_error("Axis start cannot be negative."); - } - - if (axis_end < axis_start) { - throw py::value_error( - "Axis end must be greater than or equal to axis start."); - } - - if (k != (axis_end - axis_start)) { - throw py::value_error("Number of indices must match axis range."); - } - if (mode != 0 && mode != 1) { throw py::value_error("Mode must be 0 or 1."); } dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst); - const dpnp::tensor::usm_ndarray ind_rep = ind[0]; - int src_nd = src.get_ndim(); int dst_nd = dst.get_ndim(); - int ind_nd = ind_rep.get_ndim(); - auto sh_elems = std::max(src_nd, 1); + const dpnp::tensor::usm_ndarray &ind_rep = ind[0]; + int ind_nd = ind_rep.get_ndim(); + const py::ssize_t *ind_shape = ind_rep.get_shape_raw(); - if (axis_end > sh_elems) { - throw py::value_error( - "Axis end is out of range for array of dimension " + - std::to_string(src_nd)); - } - if (src_nd == 0) { - if (dst_nd != ind_nd) { - throw py::value_error( - "Destination is not of appropriate dimension for take kernel."); - } - } - else { - if (dst_nd != (src_nd - k + ind_nd)) { - throw py::value_error( - "Destination is not of appropriate dimension for take kernel."); - } - } + detail::validate_axis_range(axis_start, axis_end, k, src_nd); + detail::validate_output_shape(src_nd, dst_nd, k, ind_nd); const py::ssize_t *src_shape = src.get_shape_raw(); const py::ssize_t *dst_shape = dst.get_shape_raw(); - bool orthog_shapes_equal(true); - std::size_t orthog_nelems(1); - for (int i = 0; i < (src_nd - k); ++i) { - auto idx1 = (i < axis_start) ? i : i + k; - auto idx2 = (i < axis_start) ? i : i + ind_nd; + std::size_t orthog_nelems = detail::validate_and_compute_orthog_shape( + src_nd, k, ind_nd, axis_start, src_shape, dst_shape); - orthog_nelems *= static_cast(src_shape[idx1]); - orthog_shapes_equal = - orthog_shapes_equal && (src_shape[idx1] == dst_shape[idx2]); + if (orthog_nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); } - if (!orthog_shapes_equal) { - throw py::value_error( - "Axes of basic indices are not of matching shapes."); - } + std::size_t ind_nelems = detail::validate_and_compute_ind_nelems( + ind_nd, axis_start, ind_shape, dst_shape); - if (orthog_nelems == 0) { + if (ind_nelems == 0) { return std::make_pair(sycl::event{}, sycl::event{}); } - char *src_data = src.get_data(); - char *dst_data = dst.get_data(); - if (!dpnp::utils::queues_are_compatible(exec_q, {src, dst})) { throw py::value_error( "Execution queue is not compatible with allocation queues"); @@ -476,34 +515,15 @@ std::pair throw py::value_error("Array memory overlap."); } - py::ssize_t src_offset = py::ssize_t(0); - py::ssize_t dst_offset = py::ssize_t(0); - - int src_typenum = src.get_typenum(); - int dst_typenum = dst.get_typenum(); - auto array_types = td_ns::usm_ndarray_types(); - int src_type_id = array_types.typenum_to_lookup_id(src_typenum); - int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); + int src_type_id = array_types.typenum_to_lookup_id(src.get_typenum()); + int dst_type_id = array_types.typenum_to_lookup_id(dst.get_typenum()); if (src_type_id != dst_type_id) { throw py::type_error("Array data types are not the same."); } - const py::ssize_t *ind_shape = ind_rep.get_shape_raw(); - - int ind_typenum = ind_rep.get_typenum(); - int ind_type_id = array_types.typenum_to_lookup_id(ind_typenum); - - std::size_t ind_nelems(1); - for (int i = 0; i < ind_nd; ++i) { - ind_nelems *= static_cast(ind_shape[i]); - - if (!(ind_shape[i] == dst_shape[axis_start + i])) { - throw py::value_error( - "Indices shape does not match shape of axis in destination."); - } - } + int ind_type_id = array_types.typenum_to_lookup_id(ind_rep.get_typenum()); dpnp::tensor::validation::AmpleMemory::throw_if_not_ample( dst, orthog_nelems * ind_nelems); @@ -529,11 +549,10 @@ std::pair return std::make_pair(sycl::event{}, sycl::event{}); } + int orthog_sh_elems = std::max(src_nd - k, 1); + auto packed_ind_ptrs_owner = dpnp::tensor::alloc_utils::smart_malloc_device(k, exec_q); - char **packed_ind_ptrs = packed_ind_ptrs_owner.get(); - - // rearrange to past where indices shapes are checked // packed_ind_shapes_strides = [ind_shape, // ind[0] strides, // ..., @@ -541,15 +560,8 @@ std::pair auto packed_ind_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( (k + 1) * ind_sh_elems, exec_q); - py::ssize_t *packed_ind_shapes_strides = - packed_ind_shapes_strides_owner.get(); - auto packed_ind_offsets_owner = dpnp::tensor::alloc_utils::smart_malloc_device(k, exec_q); - py::ssize_t *packed_ind_offsets = packed_ind_offsets_owner.get(); - - int orthog_sh_elems = std::max(src_nd - k, 1); - // packed_shapes_strides = [src_shape[:axis] + src_shape[axis+k:], // src_strides[:axis] + src_strides[axis+k:], // dst_strides[:axis] + @@ -557,8 +569,6 @@ std::pair auto packed_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( 3 * orthog_sh_elems, exec_q); - py::ssize_t *packed_shapes_strides = packed_shapes_strides_owner.get(); - // packed_axes_shapes_strides = [src_shape[axis:axis+k], // src_strides[axis:axis+k], // dst_shape[axis:axis+ind.ndim], @@ -566,8 +576,6 @@ std::pair auto packed_axes_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( 2 * (k + ind_sh_elems), exec_q); - py::ssize_t *packed_axes_shapes_strides = - packed_axes_shapes_strides_owner.get(); auto src_strides = src.get_strides_vector(); auto dst_strides = dst.get_strides_vector(); @@ -576,11 +584,12 @@ std::pair host_task_events.reserve(2); std::vector pack_deps = _populate_kernel_params( - exec_q, host_task_events, packed_ind_ptrs, packed_ind_shapes_strides, - packed_ind_offsets, packed_shapes_strides, packed_axes_shapes_strides, - src_shape, dst_shape, src_strides, dst_strides, ind_sh_sts, ind_ptrs, - ind_offsets, axis_start, k, ind_nd, src_nd, orthog_sh_elems, - ind_sh_elems); + exec_q, host_task_events, packed_ind_ptrs_owner.get(), + packed_ind_shapes_strides_owner.get(), packed_ind_offsets_owner.get(), + packed_shapes_strides_owner.get(), + packed_axes_shapes_strides_owner.get(), src_shape, dst_shape, + src_strides, dst_strides, ind_sh_sts, ind_ptrs, ind_offsets, axis_start, + k, ind_nd, src_nd, orthog_sh_elems, ind_sh_elems); std::vector all_deps; all_deps.reserve(depends.size() + pack_deps.size()); @@ -597,13 +606,15 @@ std::pair std::to_string(ind_type_id)); } + static constexpr py::ssize_t zero_offset(0); sycl::event take_generic_ev = fn(exec_q, orthog_nelems, ind_nelems, orthog_sh_elems, ind_sh_elems, k, - packed_shapes_strides, packed_axes_shapes_strides, - packed_ind_shapes_strides, src_data, dst_data, packed_ind_ptrs, - src_offset, dst_offset, packed_ind_offsets, all_deps); + packed_shapes_strides_owner.get(), + packed_axes_shapes_strides_owner.get(), + packed_ind_shapes_strides_owner.get(), src.get_data(), + dst.get_data(), packed_ind_ptrs_owner.get(), zero_offset, + zero_offset, packed_ind_offsets_owner.get(), all_deps); - // free packed temporaries sycl::event temporaries_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free( exec_q, {take_generic_ev}, packed_shapes_strides_owner, @@ -629,85 +640,45 @@ std::pair const std::vector &depends) { std::vector ind = parse_py_ind(exec_q, py_ind); - int k = ind.size(); + int k = ind.size(); if (k == 0) { throw py::value_error("List of indices is empty."); } - if (axis_start < 0) { - throw py::value_error("Axis start cannot be negative."); - } - - if (axis_end < axis_start) { - throw py::value_error( - "Axis end must be greater than or equal to axis start."); - } - - if (k != (axis_end - axis_start)) { - throw py::value_error("Number of indices must match axis range."); - } - if (mode != 0 && mode != 1) { throw py::value_error("Mode must be 0 or 1."); } dpnp::tensor::validation::CheckWritable::throw_if_not_writable(dst); - const dpnp::tensor::usm_ndarray ind_rep = ind[0]; - int dst_nd = dst.get_ndim(); int val_nd = val.get_ndim(); - int ind_nd = ind_rep.get_ndim(); - auto sh_elems = std::max(dst_nd, 1); - - if (axis_end > sh_elems) { - throw py::value_error( - "Axis end is out of range for array of dimension " + - std::to_string(dst_nd)); - } - if (dst_nd == 0) { - if (val_nd != ind_nd) { - throw py::value_error("Destination is not of appropriate dimension " - "for put function."); - } - } - else { - if (val_nd != (dst_nd - k + ind_nd)) { - throw py::value_error("Destination is not of appropriate dimension " - "for put function."); - } - } + const dpnp::tensor::usm_ndarray &ind_rep = ind[0]; + int ind_nd = ind_rep.get_ndim(); + const py::ssize_t *ind_shape = ind_rep.get_shape_raw(); - std::size_t dst_nelems = dst.get_size(); + detail::validate_axis_range(axis_start, axis_end, k, dst_nd); + detail::validate_output_shape(dst_nd, val_nd, k, ind_nd); const py::ssize_t *dst_shape = dst.get_shape_raw(); const py::ssize_t *val_shape = val.get_shape_raw(); - bool orthog_shapes_equal(true); - std::size_t orthog_nelems(1); - for (int i = 0; i < (dst_nd - k); ++i) { - auto idx1 = (i < axis_start) ? i : i + k; - auto idx2 = (i < axis_start) ? i : i + ind_nd; + std::size_t orthog_nelems = detail::validate_and_compute_orthog_shape( + dst_nd, k, ind_nd, axis_start, dst_shape, val_shape); - orthog_nelems *= static_cast(dst_shape[idx1]); - orthog_shapes_equal = - orthog_shapes_equal && (dst_shape[idx1] == val_shape[idx2]); + if (orthog_nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); } - if (!orthog_shapes_equal) { - throw py::value_error( - "Axes of basic indices are not of matching shapes."); - } + std::size_t ind_nelems = detail::validate_and_compute_ind_nelems( + ind_nd, axis_start, ind_shape, val_shape); - if (orthog_nelems == 0) { - return std::make_pair(sycl::event(), sycl::event()); + if (ind_nelems == 0) { + return std::make_pair(sycl::event{}, sycl::event{}); } - char *dst_data = dst.get_data(); - char *val_data = val.get_data(); - if (!dpnp::utils::queues_are_compatible(exec_q, {dst, val})) { throw py::value_error( "Execution queue is not compatible with allocation queues"); @@ -718,36 +689,18 @@ std::pair throw py::value_error("Arrays index overlapping segments of memory"); } - py::ssize_t dst_offset = py::ssize_t(0); - py::ssize_t val_offset = py::ssize_t(0); - - dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst, dst_nelems); - - int dst_typenum = dst.get_typenum(); - int val_typenum = val.get_typenum(); - auto array_types = td_ns::usm_ndarray_types(); - int dst_type_id = array_types.typenum_to_lookup_id(dst_typenum); - int val_type_id = array_types.typenum_to_lookup_id(val_typenum); + int dst_type_id = array_types.typenum_to_lookup_id(dst.get_typenum()); + int val_type_id = array_types.typenum_to_lookup_id(val.get_typenum()); if (dst_type_id != val_type_id) { throw py::type_error("Array data types are not the same."); } - const py::ssize_t *ind_shape = ind_rep.get_shape_raw(); - - int ind_typenum = ind_rep.get_typenum(); - int ind_type_id = array_types.typenum_to_lookup_id(ind_typenum); + int ind_type_id = array_types.typenum_to_lookup_id(ind_rep.get_typenum()); - std::size_t ind_nelems(1); - for (int i = 0; i < ind_nd; ++i) { - ind_nelems *= static_cast(ind_shape[i]); - - if (!(ind_shape[i] == val_shape[axis_start + i])) { - throw py::value_error( - "Indices shapes does not match shape of axis in vals."); - } - } + dpnp::tensor::validation::AmpleMemory::throw_if_not_ample(dst, + dst.get_size()); auto ind_sh_elems = std::max(ind_nd, 1); @@ -768,10 +721,10 @@ std::pair return std::make_pair(sycl::event{}, sycl::event{}); } + int orthog_sh_elems = std::max(dst_nd - k, 1); + auto packed_ind_ptrs_owner = dpnp::tensor::alloc_utils::smart_malloc_device(k, exec_q); - char **packed_ind_ptrs = packed_ind_ptrs_owner.get(); - // packed_ind_shapes_strides = [ind_shape, // ind[0] strides, // ..., @@ -779,15 +732,8 @@ std::pair auto packed_ind_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( (k + 1) * ind_sh_elems, exec_q); - py::ssize_t *packed_ind_shapes_strides = - packed_ind_shapes_strides_owner.get(); - auto packed_ind_offsets_owner = dpnp::tensor::alloc_utils::smart_malloc_device(k, exec_q); - py::ssize_t *packed_ind_offsets = packed_ind_offsets_owner.get(); - - int orthog_sh_elems = std::max(dst_nd - k, 1); - // packed_shapes_strides = [dst_shape[:axis] + dst_shape[axis+k:], // dst_strides[:axis] + dst_strides[axis+k:], // val_strides[:axis] + @@ -795,8 +741,6 @@ std::pair auto packed_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( 3 * orthog_sh_elems, exec_q); - py::ssize_t *packed_shapes_strides = packed_shapes_strides_owner.get(); - // packed_axes_shapes_strides = [dst_shape[axis:axis+k], // dst_strides[axis:axis+k], // val_shape[axis:axis+ind.ndim], @@ -804,8 +748,6 @@ std::pair auto packed_axes_shapes_strides_owner = dpnp::tensor::alloc_utils::smart_malloc_device( 2 * (k + ind_sh_elems), exec_q); - py::ssize_t *packed_axes_shapes_strides = - packed_axes_shapes_strides_owner.get(); auto dst_strides = dst.get_strides_vector(); auto val_strides = val.get_strides_vector(); @@ -814,11 +756,12 @@ std::pair host_task_events.reserve(2); std::vector pack_deps = _populate_kernel_params( - exec_q, host_task_events, packed_ind_ptrs, packed_ind_shapes_strides, - packed_ind_offsets, packed_shapes_strides, packed_axes_shapes_strides, - dst_shape, val_shape, dst_strides, val_strides, ind_sh_sts, ind_ptrs, - ind_offsets, axis_start, k, ind_nd, dst_nd, orthog_sh_elems, - ind_sh_elems); + exec_q, host_task_events, packed_ind_ptrs_owner.get(), + packed_ind_shapes_strides_owner.get(), packed_ind_offsets_owner.get(), + packed_shapes_strides_owner.get(), + packed_axes_shapes_strides_owner.get(), dst_shape, val_shape, + dst_strides, val_strides, ind_sh_sts, ind_ptrs, ind_offsets, axis_start, + k, ind_nd, dst_nd, orthog_sh_elems, ind_sh_elems); std::vector all_deps; all_deps.reserve(depends.size() + pack_deps.size()); @@ -835,13 +778,15 @@ std::pair std::to_string(ind_type_id)); } + static constexpr py::ssize_t zero_offset(0); sycl::event put_generic_ev = fn(exec_q, orthog_nelems, ind_nelems, orthog_sh_elems, ind_sh_elems, k, - packed_shapes_strides, packed_axes_shapes_strides, - packed_ind_shapes_strides, dst_data, val_data, packed_ind_ptrs, - dst_offset, val_offset, packed_ind_offsets, all_deps); + packed_shapes_strides_owner.get(), + packed_axes_shapes_strides_owner.get(), + packed_ind_shapes_strides_owner.get(), dst.get_data(), + val.get_data(), packed_ind_ptrs_owner.get(), zero_offset, + zero_offset, packed_ind_offsets_owner.get(), all_deps); - // free packed temporaries sycl::event temporaries_cleanup_ev = dpnp::tensor::alloc_utils::async_smart_free( exec_q, {put_generic_ev}, packed_shapes_strides_owner, From c3eeb0134b0f5fdf7c3e320d964c8f5e5bd52e74 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 29 May 2026 21:06:21 -0700 Subject: [PATCH 7/7] add axis_end to _take_index --- dpnp/dpnp_iface_indexing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 333dbc2349c..90e76a6c9e9 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -312,6 +312,7 @@ def _take_index(x, inds, axis, q, usm_type, out=None, mode=0): ind=(inds,), dst=out, axis_start=axis, + axis_end=axis_end, mode=mode, sycl_queue=q, depends=dep_evs,