Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 37 additions & 27 deletions dpnp/tensor/libtensor/include/kernels/constructors.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,11 @@ sycl::event full_strided_impl(sycl::queue &q,

typedef sycl::event (*eye_fn_ptr_t)(sycl::queue &,
std::size_t nelems, // num_elements
ssize_t start,
ssize_t end,
ssize_t step,
ssize_t rows,
ssize_t cols,
ssize_t k,
ssize_t stride0,
ssize_t stride1,
char *, // dst_data_ptr
const std::vector<sycl::event> &);

Expand All @@ -356,29 +358,30 @@ class EyeFunctor
{
private:
Ty *p = nullptr;
ssize_t start_v;
ssize_t end_v;
ssize_t step_v;
ssize_t k_;
ssize_t stride0_;
ssize_t stride1_;

public:
EyeFunctor(char *dst_p,
const ssize_t v0,
const ssize_t v1,
const ssize_t dv)
: p(reinterpret_cast<Ty *>(dst_p)), start_v(v0), end_v(v1), step_v(dv)
const ssize_t k,
const ssize_t stride0,
const ssize_t stride1)
: p(reinterpret_cast<Ty *>(dst_p)), k_(k), stride0_(stride0),
stride1_(stride1)
{
}

void operator()(sycl::id<1> wiid) const
void operator()(sycl::id<2> idx) const
{
Ty set_v = 0;
ssize_t i = static_cast<ssize_t>(wiid.get(0));
if (i >= start_v and i <= end_v) {
if ((i - start_v) % step_v == 0) {
set_v = 1;
}
}
p[i] = set_v;
const ssize_t row = static_cast<ssize_t>(idx[0]);
const ssize_t col = static_cast<ssize_t>(idx[1]);

// k-th diagonal: col - row == k
const Ty set_v = static_cast<Ty>(col - row == k_);

const ssize_t offset = row * stride0_ + col * stride1_;
p[offset] = set_v;
}
};

Expand All @@ -387,9 +390,12 @@ class EyeFunctor
*
* @param exec_q Sycl queue to which kernel is submitted for execution.
* @param nelems Number of elements to assign.
* @param start Position of the first non-zero value.
* @param end Position of the last non-zero value.
* @param step Number of array elements between non-zeros.
* @param rows Number of rows in the matrix.
* @param cols Number of columns in the matrix.
* @param k Diagonal offset (0 for main diagonal, positive for upper,
* negative for lower).
* @param stride0 Stride for the first dimension (rows).
* @param stride1 Stride for the second dimension (columns).
* @param array_data Kernel accessible USM pointer for the destination array.
* @param depends List of events to wait for before starting computations, if
* any.
Expand All @@ -400,9 +406,11 @@ class EyeFunctor
template <typename Ty>
sycl::event eye_impl(sycl::queue &exec_q,
std::size_t nelems,
const ssize_t start,
const ssize_t end,
const ssize_t step,
const ssize_t rows,
const ssize_t cols,
const ssize_t k,
const ssize_t stride0,
const ssize_t stride1,
char *array_data,
const std::vector<sycl::event> &depends)
{
Expand All @@ -413,8 +421,10 @@ sycl::event eye_impl(sycl::queue &exec_q,
using KernelName = eye_kernel<Ty>;
using Impl = EyeFunctor<Ty>;

cgh.parallel_for<KernelName>(sycl::range<1>{nelems},
Impl(array_data, start, end, step));
cgh.parallel_for<KernelName>(
sycl::range<2>{static_cast<std::size_t>(rows),
static_cast<std::size_t>(cols)},
Impl(array_data, k, stride0, stride1));
});

return eye_event;
Expand Down
42 changes: 16 additions & 26 deletions dpnp/tensor/libtensor/source/eye_ctor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
/// This file defines functions of dpnp.tensor._tensor_impl extensions
//===--------------------------------------------------------------------===//

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>
Expand All @@ -53,8 +52,6 @@ namespace td_ns = dpnp::tensor::type_dispatch;
namespace dpnp::tensor::py_internal
{

using dpnp::utils::keep_args_alive;

using dpnp::tensor::kernels::constructors::eye_fn_ptr_t;
static eye_fn_ptr_t eye_dispatch_vector[td_ns::num_types];

Expand All @@ -64,8 +61,6 @@ std::pair<sycl::event, sycl::event>
sycl::queue &exec_q,
const std::vector<sycl::event> &depends)
{
// dst must be 2D

if (dst.get_ndim() != 2) {
throw py::value_error(
"usm_ndarray_eye: Expecting 2D array to populate");
Expand All @@ -82,7 +77,7 @@ std::pair<sycl::event, sycl::event>
int dst_typenum = dst.get_typenum();
int dst_typeid = array_types.typenum_to_lookup_id(dst_typenum);

const py::ssize_t nelem = dst.get_size();
const py::ssize_t nelems = dst.get_size();
const py::ssize_t rows = dst.get_shape(0);
const py::ssize_t cols = dst.get_shape(1);
if (rows == 0 || cols == 0) {
Expand All @@ -96,34 +91,29 @@ std::pair<sycl::event, sycl::event>
throw py::value_error("USM array is not contiguous");
}

py::ssize_t start;
if (is_dst_c_contig) {
start = (k < 0) ? -k * cols : k;
}
else {
start = (k < 0) ? -k : k * rows;
}

const py::ssize_t *strides = dst.get_strides_raw();
py::ssize_t step;
py::ssize_t stride0, stride1;
if (strides == nullptr) {
step = (is_dst_c_contig) ? cols + 1 : rows + 1;
if (is_dst_c_contig) {
stride0 = cols;
stride1 = 1;
}
else {
stride0 = 1;
stride1 = rows;
}
}
else {
step = strides[0] + strides[1];
stride0 = strides[0];
stride1 = strides[1];
}

const py::ssize_t length = std::min({rows, cols, rows + k, cols - k});
const py::ssize_t end = start + step * (length - 1);

char *dst_data = dst.get_data();
sycl::event eye_event;

auto fn = eye_dispatch_vector[dst_typeid];
sycl::event eye_event =
fn(exec_q, static_cast<std::size_t>(nelems), rows, cols, k, stride0,
stride1, dst.get_data(), depends);

eye_event = fn(exec_q, static_cast<std::size_t>(nelem), start, end, step,
dst_data, depends);

using dpnp::utils::keep_args_alive;
return std::make_pair(keep_args_alive(exec_q, {dst}, {eye_event}),
eye_event);
}
Expand Down
Loading