// // Copyright (C) 2011-15 DyND Developers // BSD 2-Clause License, see LICENSE.txt // #pragma once #include #include #include namespace dynd { namespace kernels { template struct permute_ck; template struct permute_ck : nd::base_strided_kernel, N> { typedef permute_ck self_type; intptr_t perm[N]; permute_ck(const intptr_t *perm) { memcpy(this->perm, perm, sizeof(this->perm)); } void single(char *dst, char *const *src) { char *src_inv_perm[N]; inv(src_inv_perm, dst, src); nd::kernel_prefix *child = this->get_child(); kernel_single_t single = child->get_function(); single(child, NULL, src_inv_perm); } void strided(char *dst, intptr_t dst_stride, char *const *src, const intptr_t *src_stride, size_t count) { char *src_inv_perm[N]; inv(src_inv_perm, dst, src); intptr_t src_stride_inv_perm[N]; inv(src_stride_inv_perm, dst_stride, src_stride); nd::kernel_prefix *child = this->get_child(); kernel_strided_t strided = child->get_function(); strided(child, NULL, 0, src_inv_perm, src_stride_inv_perm, count); } static void instantiate(char *static_data, char *DYND_UNUSED(data), nd::kernel_builder *ckb, const ndt::type &dst_tp, const char *dst_arrmeta, intptr_t nsrc, const ndt::type *src_tp, const char *const *src_arrmeta, kernel_request_t kernreq, intptr_t nkwd, const nd::array *kwds, const std::map &tp_vars) { const std::pair> *data = reinterpret_cast> *>(static_data); const nd::base_callable *child = data->first.get(); const intptr_t *perm = data->second.data(); ndt::type src_tp_inv[N]; inv(src_tp_inv, dst_tp, src_tp, perm); const char *src_arrmeta_inv[N]; inv(src_arrmeta_inv, dst_arrmeta, src_arrmeta, perm); ckb->emplace_back(kernreq, detail::make_array_wrapper(perm)); child->instantiate(const_cast(child->static_data()), NULL, ckb, ndt::make_type(), NULL, nsrc, src_tp_inv, src_arrmeta_inv, kernreq | kernel_request_data_only, nkwd, kwds, tp_vars); } private: static void inv(ndt::type *src_inv, const ndt::type &dst, const ndt::type *src, const intptr_t *perm) { for (intptr_t i = 0; i < N; ++i) { intptr_t j = perm[i]; if (j == -1) { src_inv[i] = dst; } else { src_inv[i] = src[j]; } } } template static void inv(T *src_inv, const T &dst, const T *src, const intptr_t *perm) { for (intptr_t i = 0; i < N; ++i) { intptr_t j = perm[i]; if (j == -1) { src_inv[i] = dst; } else { src_inv[i] = src[j]; } } } void inv(ndt::type *src_inv, const ndt::type &dst, const ndt::type *src) { return inv(src_inv, dst, src, perm); } template void inv(T *src_inv, const T &dst, const T *src) { return inv(src_inv, dst, src, perm); } }; template struct permute_ck : nd::base_strided_kernel, N> { typedef permute_ck self_type; intptr_t perm[N]; permute_ck(const intptr_t *perm) { memcpy(this->perm, perm, sizeof(this->perm)); } void single(char *dst, char *const *src) { char *src_inv_perm[N]; inv_permute(src_inv_perm, src, perm); nd::kernel_prefix *child = this->get_child(); kernel_single_t single = child->get_function(); single(dst, src_inv_perm, child); } static void instantiate(char *static_data, char *DYND_UNUSED(data), nd::kernel_builder *ckb, const ndt::type &dst_tp, const char *dst_arrmeta, intptr_t nsrc, const ndt::type *src_tp, const char *const *src_arrmeta, kernel_request_t kernreq, intptr_t nkwd, const nd::array *kwds, const std::map &tp_vars) { const std::pair> *data = reinterpret_cast> *>(static_data); const nd::base_callable *child = data->first.get(); const intptr_t *perm = data->second.data(); ndt::type src_tp_inv[N]; inv_permute(src_tp_inv, src_tp, perm); const char *src_arrmeta_inv[N]; inv_permute(src_arrmeta_inv, src_arrmeta, perm); ckb->emplace_back(kernreq, detail::make_array_wrapper(perm)); child->instantiate(const_cast(child->static_data()), NULL, ckb, dst_tp, dst_arrmeta, nsrc, src_tp_inv, src_arrmeta_inv, kernreq, nkwd, kwds, tp_vars); } private: template static void inv_permute(T *src_inv, const T *src, const intptr_t *perm) { for (intptr_t i = 0; i < N; ++i) { src_inv[i] = src[perm[i]]; } } }; } // dynd::kernels namespace nd { namespace functional { DYND_API callable permute(const callable &child, const std::vector &perm); } // namespace dynd::nd::functional } // namespace dynd::nd } // namespace dynd