From d951f51e01c768526f208c3922f82a00cb1b28f7 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sat, 30 May 2026 02:25:06 -0700 Subject: [PATCH] Rewrite eye implementation --- .../include/kernels/constructors.hpp | 64 +++++++++++-------- dpnp/tensor/libtensor/source/eye_ctor.cpp | 42 +++++------- 2 files changed, 53 insertions(+), 53 deletions(-) diff --git a/dpnp/tensor/libtensor/include/kernels/constructors.hpp b/dpnp/tensor/libtensor/include/kernels/constructors.hpp index a4b2ade646c..f8ef825e3be 100644 --- a/dpnp/tensor/libtensor/include/kernels/constructors.hpp +++ b/dpnp/tensor/libtensor/include/kernels/constructors.hpp @@ -345,9 +345,11 @@ sycl::event full_strided_impl(sycl::queue &q, typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &, std::size_t nelems, // num_elements - ssize_t start, - ssize_t end, - ssize_t step, + ssize_t rows, + ssize_t cols, + ssize_t k, + ssize_t stride0, + ssize_t stride1, char *, // dst_data_ptr const std::vector &); @@ -356,29 +358,30 @@ class EyeFunctor { private: Ty *p = nullptr; - ssize_t start_v; - ssize_t end_v; - ssize_t step_v; + ssize_t k_; + ssize_t stride0_; + ssize_t stride1_; public: EyeFunctor(char *dst_p, - const ssize_t v0, - const ssize_t v1, - const ssize_t dv) - : p(reinterpret_cast(dst_p)), start_v(v0), end_v(v1), step_v(dv) + const ssize_t k, + const ssize_t stride0, + const ssize_t stride1) + : p(reinterpret_cast(dst_p)), k_(k), stride0_(stride0), + stride1_(stride1) { } - void operator()(sycl::id<1> wiid) const + void operator()(sycl::id<2> idx) const { - Ty set_v = 0; - ssize_t i = static_cast(wiid.get(0)); - if (i >= start_v and i <= end_v) { - if ((i - start_v) % step_v == 0) { - set_v = 1; - } - } - p[i] = set_v; + const ssize_t row = static_cast(idx[0]); + const ssize_t col = static_cast(idx[1]); + + // k-th diagonal: col - row == k + const Ty set_v = static_cast(col - row == k_); + + const ssize_t offset = row * stride0_ + col * stride1_; + p[offset] = set_v; } }; @@ -387,9 +390,12 @@ class EyeFunctor * * @param exec_q Sycl queue to which kernel is submitted for execution. * @param nelems Number of elements to assign. - * @param start Position of the first non-zero value. - * @param end Position of the last non-zero value. - * @param step Number of array elements between non-zeros. + * @param rows Number of rows in the matrix. + * @param cols Number of columns in the matrix. + * @param k Diagonal offset (0 for main diagonal, positive for upper, + * negative for lower). + * @param stride0 Stride for the first dimension (rows). + * @param stride1 Stride for the second dimension (columns). * @param array_data Kernel accessible USM pointer for the destination array. * @param depends List of events to wait for before starting computations, if * any. @@ -400,9 +406,11 @@ class EyeFunctor template sycl::event eye_impl(sycl::queue &exec_q, std::size_t nelems, - const ssize_t start, - const ssize_t end, - const ssize_t step, + const ssize_t rows, + const ssize_t cols, + const ssize_t k, + const ssize_t stride0, + const ssize_t stride1, char *array_data, const std::vector &depends) { @@ -413,8 +421,10 @@ sycl::event eye_impl(sycl::queue &exec_q, using KernelName = eye_kernel; using Impl = EyeFunctor; - cgh.parallel_for(sycl::range<1>{nelems}, - Impl(array_data, start, end, step)); + cgh.parallel_for( + sycl::range<2>{static_cast(rows), + static_cast(cols)}, + Impl(array_data, k, stride0, stride1)); }); return eye_event; diff --git a/dpnp/tensor/libtensor/source/eye_ctor.cpp b/dpnp/tensor/libtensor/source/eye_ctor.cpp index ea3765d04e1..edac74000b8 100644 --- a/dpnp/tensor/libtensor/source/eye_ctor.cpp +++ b/dpnp/tensor/libtensor/source/eye_ctor.cpp @@ -32,7 +32,6 @@ /// This file defines functions of dpnp.tensor._tensor_impl extensions //===--------------------------------------------------------------------===// -#include #include #include #include @@ -53,8 +52,6 @@ namespace td_ns = dpnp::tensor::type_dispatch; namespace dpnp::tensor::py_internal { -using dpnp::utils::keep_args_alive; - using dpnp::tensor::kernels::constructors::eye_fn_ptr_t; static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types]; @@ -64,8 +61,6 @@ std::pair sycl::queue &exec_q, const std::vector &depends) { - // dst must be 2D - if (dst.get_ndim() != 2) { throw py::value_error( "usm_ndarray_eye: Expecting 2D array to populate"); @@ -82,7 +77,7 @@ std::pair int dst_typenum = dst.get_typenum(); int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum); - const py::ssize_t nelem = dst.get_size(); + const py::ssize_t nelems = dst.get_size(); const py::ssize_t rows = dst.get_shape(0); const py::ssize_t cols = dst.get_shape(1); if (rows == 0 || cols == 0) { @@ -96,34 +91,29 @@ std::pair throw py::value_error("USM array is not contiguous"); } - py::ssize_t start; - if (is_dst_c_contig) { - start = (k < 0) ? -k * cols : k; - } - else { - start = (k < 0) ? -k : k * rows; - } - const py::ssize_t *strides = dst.get_strides_raw(); - py::ssize_t step; + py::ssize_t stride0, stride1; if (strides == nullptr) { - step = (is_dst_c_contig) ? cols + 1 : rows + 1; + if (is_dst_c_contig) { + stride0 = cols; + stride1 = 1; + } + else { + stride0 = 1; + stride1 = rows; + } } else { - step = strides[0] + strides[1]; + stride0 = strides[0]; + stride1 = strides[1]; } - const py::ssize_t length = std::min({rows, cols, rows + k, cols - k}); - const py::ssize_t end = start + step * (length - 1); - - char *dst_data = dst.get_data(); - sycl::event eye_event; - auto fn = eye_dispatch_vector[dst_typeid]; + sycl::event eye_event = + fn(exec_q, static_cast(nelems), rows, cols, k, stride0, + stride1, dst.get_data(), depends); - eye_event = fn(exec_q, static_cast(nelem), start, end, step, - dst_data, depends); - + using dpnp::utils::keep_args_alive; return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}), eye_event); }