/**
 * Forward declaration of asTensor(), so that it can be named before its definition.
 *
 * @tparam T Type of the input array.
 * @tparam R Type of the output array.
 * @param arg Array to read from.
 * @param ret Array that is written to.
 */
template <class T, class R>
inline void asTensor(const T& arg, R& ret);
/**
 * Check that the shape of `t` is a leading part of the shape of `s`.
 *
 * @param t Object exposing `dimension()` and `shape()`.
 * @param s Object exposing `dimension()` and `shape()`.
 * @return `true` if `s` has at least as many dimensions as `t` and every entry of `t.shape()`
 *         equals the entry of `s.shape()` at the same position; `false` otherwise.
 */
template <class T, class S>
inline bool has_shape_begin(const T& t, const S& s)
{
    // The rank comparison is evaluated first: std::equal walks s.shape() for as many entries
    // as t.shape() has, so it may only run when `s` is known to be long enough.
    return s.dimension() >= t.dimension() &&
           std::equal(t.shape().cbegin(), t.shape().cend(), s.shape().begin());
}
/**
 * Compile-time test whether a type is a `std::array`.
 * Primary template: `value` is `false`.
 *
 * Only the exact type `std::array<T, N>` is matched; a cv-qualified or reference type is not.
 */
template <class T>
struct is_std_array : std::false_type {};

/**
 * Specialisation for `std::array<T, N>`: `value` is `true`.
 */
template <class T, std::size_t N>
struct is_std_array<std::array<T, N>> : std::true_type {};
/**
 * Helper for std_array_size: declared only, never defined.
 * It is meant to be used in an unevaluated context (`decltype`), where its return type
 * carries the number of elements `N` of the `std::array` argument.
 */
template <class T, std::size_t N>
auto std_array_size_impl(const std::array<T, N>&) -> std::integral_constant<std::size_t, N>;

/**
 * Number of elements of a `std::array` type, as a `std::integral_constant`.
 * Use as `std_array_size<A>::value`.
 *
 * @tparam T A `std::array<U, N>` (the lookup goes through `const T&`).
 */
template <class T>
using std_array_size = decltype(std_array_size_impl(std::declval<const T&>()));
/**
 * Copy a built-in array into a `std::array` of the same element type and length.
 *
 * @tparam I Element type.
 * @tparam L Number of elements (deduced from the argument).
 * @param shape Built-in array to copy.
 * @return `std::array<I, L>` holding the same elements in the same order.
 */
template <class I, std::size_t L>
std::array<I, L> to_std_array(const I (&shape)[L])
{
    std::array<I, L> r;
    std::copy(shape, shape + L, r.begin());
    return r;
}
67template <
class T,
class R,
typename =
void>
68struct asTensor_write {
69 static void impl(
const T& arg, R& ret)
73 using strides_type =
typename T::strides_type::value_type;
74 std::vector<strides_type> ret_strides(ret.dimension());
75 std::copy(arg.strides().begin(), arg.strides().end(), ret_strides.begin());
76 std::fill(ret_strides.begin() + arg.dimension(), ret_strides.end(), 0);
77 ret = xt::strided_view(
78 arg, ret.shape(), std::move(ret_strides), 0ul, xt::layout_type::dynamic
86template <
class T,
class R>
90 typename std::enable_if_t<xt::has_fixed_rank_t<T>::value && xt::has_fixed_rank_t<R>::value>> {
91 static void impl(
const T& arg, R& ret)
93 static_assert(T::rank <= R::rank,
"Return must be fixed rank too");
95 using strides_type =
typename T::strides_type::value_type;
96 std::array<strides_type, R::rank> ret_strides;
97 std::copy(arg.strides().begin(), arg.strides().end(), ret_strides.begin());
98 std::fill(ret_strides.begin() + T::rank, ret_strides.end(), 0);
99 ret = xt::strided_view(
100 arg, ret.shape(), std::move(ret_strides), 0ul, xt::layout_type::dynamic
108template <
class T,
class S,
typename =
void>
109struct asTensor_allocate {
110 static auto impl(
const T& arg,
const S& shape)
112 using value_type =
typename T::value_type;
113 size_t dim = arg.dimension();
114 size_t rank = shape.size();
115 std::vector<size_t> ret_shape(dim + rank);
116 std::copy(arg.shape().begin(), arg.shape().end(), ret_shape.begin());
117 std::copy(shape.begin(), shape.end(), ret_shape.begin() + dim);
118 xt::xarray<value_type> ret(ret_shape);
127template <
class T,
class S>
128struct asTensor_allocate<T, S, typename std::enable_if_t<detail::is_std_array<S>::value>> {
129 static auto impl(
const T& arg,
const S& shape)
131 using value_type =
typename T::value_type;
132 static constexpr size_t dim = T::rank;
133 static constexpr size_t rank = std_array_size<S>::value;
134 std::array<size_t, dim + rank> ret_shape;
135 std::copy(arg.shape().begin(), arg.shape().end(), ret_shape.begin());
136 std::copy(shape.begin(), shape.end(), ret_shape.begin() + dim);
138 detail::asTensor_write<std::decay_t<T>,
decltype(ret)>::impl(arg, ret);
152template <
class T,
class R>
155 detail::asTensor_write<std::decay_t<T>, std::decay_t<R>>::impl(arg, ret);
166template <
class T,
class S>
169 return detail::asTensor_allocate<std::decay_t<T>, std::decay_t<S>>::impl(arg, shape);
175template <
class T,
class I,
size_t L>
176inline auto AsTensor(
const T& arg,
const I (&shape)[L])
178 auto s = detail::to_std_array(shape);
179 return detail::asTensor_allocate<std::decay_t<T>,
decltype(s)>::impl(arg, s);
191template <
size_t rank,
class T>
194 std::array<size_t, rank> shape;
195 std::fill(shape.begin(), shape.end(), n);
196 return detail::asTensor_allocate<std::decay_t<T>,
decltype(shape)>::impl(arg, shape);
209inline auto AsTensor(
size_t rank,
const T& arg,
size_t n)
211 std::vector<size_t> shape(rank);
212 std::fill(shape.begin(), shape.end(), n);
213 return detail::asTensor_allocate<std::decay_t<T>,
decltype(shape)>::impl(arg, shape);
228 if (arg.shape(1) == 3ul) {
232 T ret = xt::zeros<typename T::value_type>(std::array<size_t, 2>{arg.shape(0), 3ul});
234 if (arg.shape(1) == 2ul) {
235 xt::view(ret, xt::all(), xt::keep(0, 1)) = arg;
238 if (arg.shape(1) == 1ul) {
239 xt::view(ret, xt::all(), xt::keep(0)) = arg;