Ginkgo Generated from branch based on main. Ginkgo version 1.11.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
range.hpp
1// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
2//
3// SPDX-License-Identifier: BSD-3-Clause
4
5#ifndef GKO_PUBLIC_CORE_BASE_RANGE_HPP_
6#define GKO_PUBLIC_CORE_BASE_RANGE_HPP_
7
8
9#include <type_traits>
10
11#include <ginkgo/core/base/math.hpp>
12#include <ginkgo/core/base/types.hpp>
13#include <ginkgo/core/base/utils.hpp>
14
15
16namespace gko {
17
18
46struct span {
54 GKO_ATTRIBUTES constexpr span(size_type point) noexcept
55 : span{point, point + 1}
56 {}
57
64 GKO_ATTRIBUTES constexpr span(size_type begin, size_type end) noexcept
65 : begin{begin}, end{end}
66 {}
67
73 GKO_ATTRIBUTES constexpr bool is_valid() const { return begin <= end; }
74
80 GKO_ATTRIBUTES constexpr size_type length() const { return end - begin; }
81
86
91};
92
93
94GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<(const span& first,
95 const span& second)
96{
97 return first.end < second.begin;
98}
99
100
101GKO_ATTRIBUTES GKO_INLINE constexpr bool operator<=(const span& first,
102 const span& second)
103{
104 return first.end <= second.begin;
105}
106
107
108GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>(const span& first,
109 const span& second)
110{
111 return second < first;
112}
113
114
115GKO_ATTRIBUTES GKO_INLINE constexpr bool operator>=(const span& first,
116 const span& second)
117{
118 return second <= first;
119}
120
121
122GKO_ATTRIBUTES GKO_INLINE constexpr bool operator==(const span& first,
123 const span& second)
124{
125 return first.begin == second.begin && first.end == second.end;
126}
127
128
129GKO_ATTRIBUTES GKO_INLINE constexpr bool operator!=(const span& first,
130 const span& second)
131{
132 return !(first == second);
133}
134
138struct local_span : span {
139 using span::span;
140};
141
142
143namespace detail {
144
145
146template <size_type CurrentDimension = 0, typename FirstRange,
147 typename SecondRange>
148GKO_ATTRIBUTES constexpr GKO_INLINE
149 std::enable_if_t<(CurrentDimension >= max(FirstRange::dimensionality,
150 SecondRange::dimensionality)),
151 bool>
152 equal_dimensions(const FirstRange&, const SecondRange&)
153{
154 return true;
155}
156
157template <size_type CurrentDimension = 0, typename FirstRange,
158 typename SecondRange>
159GKO_ATTRIBUTES constexpr GKO_INLINE
160 std::enable_if_t<(CurrentDimension < max(FirstRange::dimensionality,
161 SecondRange::dimensionality)),
162 bool>
163 equal_dimensions(const FirstRange& first, const SecondRange& second)
164{
165 return first.length(CurrentDimension) == second.length(CurrentDimension) &&
166 equal_dimensions<CurrentDimension + 1>(first, second);
167}
168
173template <class...>
174struct head;
175
179template <class First, class... Rest>
180struct head<First, Rest...> {
181 using type = First;
182};
183
187template <class... T>
188using head_t = typename head<T...>::type;
189
190
191} // namespace detail
192
193
303template <typename Accessor>
304class range {
305public:
309 using accessor = Accessor;
310
315
319 ~range() = default;
320
329 template <
330 typename... AccessorParams,
331 typename = std::enable_if_t<
332 sizeof...(AccessorParams) != 1 ||
333 !std::is_same<
334 range, std::decay<detail::head_t<AccessorParams...>>>::value>>
335 GKO_ATTRIBUTES constexpr explicit range(AccessorParams&&... params)
336 : accessor_{std::forward<AccessorParams>(params)...}
337 {}
338
351 template <typename... DimensionTypes>
352 GKO_ATTRIBUTES constexpr auto operator()(DimensionTypes&&... dimensions)
353 const -> decltype(std::declval<accessor>()(
354 std::forward<DimensionTypes>(dimensions)...))
355 {
356 static_assert(sizeof...(DimensionTypes) <= dimensionality,
357 "Too many dimensions in range call");
358 return accessor_(std::forward<DimensionTypes>(dimensions)...);
359 }
360
369 template <typename OtherAccessor>
370 GKO_ATTRIBUTES const range& operator=(
371 const range<OtherAccessor>& other) const
372 {
373 GKO_ASSERT(detail::equal_dimensions(*this, other));
374 accessor_.copy_from(other);
375 return *this;
376 }
377
391 GKO_ATTRIBUTES const range& operator=(const range& other) const
392 {
393 GKO_ASSERT(detail::equal_dimensions(*this, other));
394 accessor_.copy_from(other.get_accessor());
395 return *this;
396 }
397
398 range(const range& other) = default;
399
407 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
408 {
409 return accessor_.length(dimension);
410 }
411
419 GKO_ATTRIBUTES constexpr const accessor* operator->() const noexcept
420 {
421 return &accessor_;
422 }
423
429 GKO_ATTRIBUTES constexpr const accessor& get_accessor() const noexcept
430 {
431 return accessor_;
432 }
433
434private:
435 accessor accessor_;
436};
437
438
439// implementation of range operations follows
440// (you probably should not have to look at this unless you're interested in the
441// gory details)
442
443
444namespace detail {
445
446
/**
 * Describes which operands of a binary range operation are ranges and which
 * are scalars; used to select the matching `implement_binary_operation`
 * specialization.
 */
enum class operation_kind {
    range_by_range,   // both operands are ranges
    scalar_by_range,  // scalar first operand, range second operand
    range_by_scalar   // range first operand, scalar second operand
};
448
449
450template <typename Accessor, typename Operation>
451struct implement_unary_operation {
452 using accessor = Accessor;
453 static constexpr size_type dimensionality = accessor::dimensionality;
454
455 GKO_ATTRIBUTES constexpr explicit implement_unary_operation(
456 const Accessor& operand)
457 : operand{operand}
458 {}
459
460 template <typename... DimensionTypes>
461 GKO_ATTRIBUTES constexpr auto operator()(
462 const DimensionTypes&... dimensions) const
463 -> decltype(Operation::evaluate(std::declval<accessor>(),
464 dimensions...))
465 {
466 return Operation::evaluate(operand, dimensions...);
467 }
468
469 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
470 {
471 return operand.length(dimension);
472 }
473
474 template <typename OtherAccessor>
475 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
476
477 const accessor operand;
478};
479
480
481template <operation_kind Kind, typename FirstOperand, typename SecondOperand,
482 typename Operation>
483struct implement_binary_operation {};
484
485template <typename FirstAccessor, typename SecondAccessor, typename Operation>
486struct implement_binary_operation<operation_kind::range_by_range, FirstAccessor,
487 SecondAccessor, Operation> {
488 using first_accessor = FirstAccessor;
489 using second_accessor = SecondAccessor;
490 static_assert(first_accessor::dimensionality ==
491 second_accessor::dimensionality,
492 "Both ranges need to have the same number of dimensions");
493 static constexpr size_type dimensionality = first_accessor::dimensionality;
494
495 GKO_ATTRIBUTES explicit implement_binary_operation(
496 const FirstAccessor& first, const SecondAccessor& second)
497 : first{first}, second{second}
498 {
499 GKO_ASSERT(gko::detail::equal_dimensions(first, second));
500 }
501
502 template <typename... DimensionTypes>
503 GKO_ATTRIBUTES constexpr auto operator()(
504 const DimensionTypes&... dimensions) const
505 -> decltype(Operation::evaluate_range_by_range(
506 std::declval<first_accessor>(), std::declval<second_accessor>(),
507 dimensions...))
508 {
509 return Operation::evaluate_range_by_range(first, second, dimensions...);
510 }
511
512 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
513 {
514 return first.length(dimension);
515 }
516
517 template <typename OtherAccessor>
518 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
519
520 const first_accessor first;
521 const second_accessor second;
522};
523
524template <typename FirstOperand, typename SecondAccessor, typename Operation>
525struct implement_binary_operation<operation_kind::scalar_by_range, FirstOperand,
526 SecondAccessor, Operation> {
527 using second_accessor = SecondAccessor;
528 static constexpr size_type dimensionality = second_accessor::dimensionality;
529
530 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
531 const FirstOperand& first, const SecondAccessor& second)
532 : first{first}, second{second}
533 {}
534
535 template <typename... DimensionTypes>
536 GKO_ATTRIBUTES constexpr auto operator()(
537 const DimensionTypes&... dimensions) const
538 -> decltype(Operation::evaluate_scalar_by_range(
539 std::declval<FirstOperand>(), std::declval<second_accessor>(),
540 dimensions...))
541 {
542 return Operation::evaluate_scalar_by_range(first, second,
543 dimensions...);
544 }
545
546 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
547 {
548 return second.length(dimension);
549 }
550
551 template <typename OtherAccessor>
552 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
553
554 const FirstOperand first;
555 const second_accessor second;
556};
557
558template <typename FirstAccessor, typename SecondOperand, typename Operation>
559struct implement_binary_operation<operation_kind::range_by_scalar,
560 FirstAccessor, SecondOperand, Operation> {
561 using first_accessor = FirstAccessor;
562 static constexpr size_type dimensionality = first_accessor::dimensionality;
563
564 GKO_ATTRIBUTES constexpr explicit implement_binary_operation(
565 const FirstAccessor& first, const SecondOperand& second)
566 : first{first}, second{second}
567 {}
568
569 template <typename... DimensionTypes>
570 GKO_ATTRIBUTES constexpr auto operator()(
571 const DimensionTypes&... dimensions) const
572 -> decltype(Operation::evaluate_range_by_scalar(
573 std::declval<first_accessor>(), std::declval<SecondOperand>(),
574 dimensions...))
575 {
576 return Operation::evaluate_range_by_scalar(first, second,
577 dimensions...);
578 }
579
580 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
581 {
582 return first.length(dimension);
583 }
584
585 template <typename OtherAccessor>
586 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
587
588 const first_accessor first;
589 const SecondOperand second;
590};
591
592
593} // namespace detail
594
/*
 * Declares `accessor::_operation_deprecated_name<Operand>` as a deprecated
 * alias of the accessor `accessor::_operation_name<Operand>`.
 *
 * Fix: the alias now inherits the constructors of `_operation_name`.
 * Previously the base's (explicit) constructor was not reachable through
 * the alias, so it could not be built from an operand accessor, e.g. as the
 * accessor of a `range`.
 */
#define GKO_DEPRECATED_UNARY_RANGE_OPERATION(_operation_deprecated_name,     \
                                             _operation_name)                \
    namespace accessor {                                                     \
    template <typename Operand>                                              \
    struct GKO_DEPRECATED("Please use " #_operation_name)                    \
        _operation_deprecated_name : _operation_name<Operand> {              \
        using _operation_name<Operand>::_operation_name;                     \
    };                                                                       \
    }                                                                        \
    static_assert(true,                                                      \
                  "This assert is used to counter the false positive extra " \
                  "semi-colon warnings")
605
606
/*
 * Defines the accessor `accessor::_operation_name<Operand>` on top of
 * `implement_unary_operation` with the operation type `::gko::_operator`.
 * It then binds the accessor to the function/operator `_operator_name`
 * via GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR.
 */
#define GKO_ENABLE_UNARY_RANGE_OPERATION(_operation_name, _operator_name,  \
                                         _operator)                        \
    namespace accessor {                                                   \
    template <class Operand>                                               \
    struct _operation_name                                                 \
        : ::gko::detail::implement_unary_operation<Operand,                \
                                                   ::gko::_operator> {     \
        using ::gko::detail::implement_unary_operation<                    \
            Operand, ::gko::_operator>::implement_unary_operation;         \
    };                                                                     \
    }                                                                      \
    GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name)
619
620
/*
 * Defines `_operator_name(range<Accessor>)`. It returns a new range whose
 * accessor is `accessor::_operation_name<Accessor>`, built from the
 * operand's accessor.
 */
#define GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(_operation_name,          \
                                                   _operator_name)           \
    template <class Accessor>                                                \
    GKO_ATTRIBUTES constexpr GKO_INLINE                                      \
        range<accessor::_operation_name<Accessor>>                           \
        _operator_name(const range<Accessor>& operand)                       \
    {                                                                        \
        using result_range = range<accessor::_operation_name<Accessor>>;     \
        return result_range(operand.get_accessor());                         \
    }                                                                        \
    static_assert(true,                                                      \
                  "This assert is used to counter the false positive extra " \
                  "semi-colon warnings")
634
635
/*
 * Defines the element-wise unary operation type `_name`.
 *
 * The variadic arguments form an expression in a single element named
 * `operand`. `_name::evaluate(acc, idx...)` applies that expression to
 * `acc(idx...)`. The private helper must be declared before `evaluate`,
 * because `evaluate`'s trailing return type refers to it.
 */
#define GKO_DEFINE_SIMPLE_UNARY_OPERATION(_name, ...)               \
    struct _name {                                                  \
    private:                                                        \
        template <class Operand>                                    \
        GKO_ATTRIBUTES static constexpr auto apply_to_element(      \
            const Operand& operand) -> decltype(__VA_ARGS__)        \
        {                                                           \
            return __VA_ARGS__;                                     \
        }                                                           \
                                                                    \
    public:                                                         \
        template <class AccessorType, class... DimensionTypes>      \
        GKO_ATTRIBUTES static constexpr auto evaluate(              \
            const AccessorType& acc, const DimensionTypes&... idx)  \
            -> decltype(apply_to_element(acc(idx...)))              \
        {                                                           \
            return apply_to_element(acc(idx...));                   \
        }                                                           \
    }
655
656
657namespace accessor {
658namespace detail {
659
660
661// unary arithmetic
662GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_plus, +operand);
663GKO_DEFINE_SIMPLE_UNARY_OPERATION(unary_minus, -operand);
664
665// unary logical
666GKO_DEFINE_SIMPLE_UNARY_OPERATION(logical_not, !operand);
667
668// unary bitwise
669GKO_DEFINE_SIMPLE_UNARY_OPERATION(bitwise_not, ~(operand));
670
671// common functions
672GKO_DEFINE_SIMPLE_UNARY_OPERATION(zero_operation, zero(operand));
673GKO_DEFINE_SIMPLE_UNARY_OPERATION(one_operation, one(operand));
674GKO_DEFINE_SIMPLE_UNARY_OPERATION(abs_operation, abs(operand));
675GKO_DEFINE_SIMPLE_UNARY_OPERATION(real_operation, real(operand));
676GKO_DEFINE_SIMPLE_UNARY_OPERATION(imag_operation, imag(operand));
677GKO_DEFINE_SIMPLE_UNARY_OPERATION(conj_operation, conj(operand));
678GKO_DEFINE_SIMPLE_UNARY_OPERATION(squared_norm_operation,
679 squared_norm(operand));
680
681} // namespace detail
682} // namespace accessor
683
684
685// unary arithmetic
686GKO_ENABLE_UNARY_RANGE_OPERATION(unary_plus, operator+,
687 accessor::detail::unary_plus);
688GKO_ENABLE_UNARY_RANGE_OPERATION(unary_minus, operator-,
689 accessor::detail::unary_minus);
690
691// unary logical
692GKO_ENABLE_UNARY_RANGE_OPERATION(logical_not, operator!,
693 accessor::detail::logical_not);
694
695// unary bitwise
696GKO_ENABLE_UNARY_RANGE_OPERATION(bitwise_not, operator~,
697 accessor::detail::bitwise_not);
698
699// common unary functions
700
701GKO_ENABLE_UNARY_RANGE_OPERATION(zero_operation, zero,
702 accessor::detail::zero_operation);
703GKO_ENABLE_UNARY_RANGE_OPERATION(one_operation, one,
704 accessor::detail::one_operation);
705GKO_ENABLE_UNARY_RANGE_OPERATION(abs_operation, abs,
706 accessor::detail::abs_operation);
707GKO_ENABLE_UNARY_RANGE_OPERATION(real_operation, real,
708 accessor::detail::real_operation);
709GKO_ENABLE_UNARY_RANGE_OPERATION(imag_operation, imag,
710 accessor::detail::imag_operation);
711GKO_ENABLE_UNARY_RANGE_OPERATION(conj_operation, conj,
712 accessor::detail::conj_operation);
713GKO_ENABLE_UNARY_RANGE_OPERATION(squared_norm_operation, squared_norm,
714 accessor::detail::squared_norm_operation);
715
716GKO_DEPRECATED_UNARY_RANGE_OPERATION(one_operaton, one_operation);
717GKO_DEPRECATED_UNARY_RANGE_OPERATION(abs_operaton, abs_operation);
718GKO_DEPRECATED_UNARY_RANGE_OPERATION(real_operaton, real_operation);
719GKO_DEPRECATED_UNARY_RANGE_OPERATION(imag_operaton, imag_operation);
720GKO_DEPRECATED_UNARY_RANGE_OPERATION(conj_operaton, conj_operation);
721GKO_DEPRECATED_UNARY_RANGE_OPERATION(squared_norm_operaton,
723
724namespace accessor {
725
726
727template <typename Accessor>
728struct transpose_operation {
729 using accessor = Accessor;
730 static constexpr size_type dimensionality = accessor::dimensionality;
731
732 GKO_ATTRIBUTES constexpr explicit transpose_operation(
733 const Accessor& operand)
734 : operand{operand}
735 {}
736
737 template <typename FirstDimensionType, typename SecondDimensionType,
738 typename... DimensionTypes>
739 GKO_ATTRIBUTES constexpr auto operator()(
740 const FirstDimensionType& first_dim,
741 const SecondDimensionType& second_dim,
742 const DimensionTypes&... dims) const
743 -> decltype(std::declval<accessor>()(second_dim, first_dim, dims...))
744 {
745 return operand(second_dim, first_dim, dims...);
746 }
747
748 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
749 {
750 return dimension < 2 ? operand.length(dimension ^ 1)
751 : operand.length(dimension);
752 }
753
754 template <typename OtherAccessor>
755 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
756
757 const accessor operand;
758};
759
760
761} // namespace accessor
762
763
764GKO_BIND_UNARY_RANGE_OPERATION_TO_OPERATOR(transpose_operation, transpose);
765
766
767#undef GKO_DEPRECATED_UNARY_RANGE_OPERATION
768#undef GKO_DEFINE_SIMPLE_UNARY_OPERATION
769#undef GKO_ENABLE_UNARY_RANGE_OPERATION
770
771
/*
 * Defines the accessor template `accessor::_operation_name<Kind, First,
 * Second>` on top of `implement_binary_operation` with the operation type
 * `::gko::_operator`. It then binds the accessor to the function/operator
 * `_operator_name` via GKO_BIND_RANGE_OPERATION_TO_OPERATOR.
 */
#define GKO_ENABLE_BINARY_RANGE_OPERATION(_operation_name, _operator_name,   \
                                          _operator)                         \
    namespace accessor {                                                     \
    template <::gko::detail::operation_kind Kind, class FirstOperand,        \
              class SecondOperand>                                           \
    struct _operation_name                                                   \
        : ::gko::detail::implement_binary_operation<                         \
              Kind, FirstOperand, SecondOperand, ::gko::_operator> {         \
        using ::gko::detail::implement_binary_operation<                     \
            Kind, FirstOperand, SecondOperand,                               \
            ::gko::_operator>::implement_binary_operation;                   \
    };                                                                       \
    }                                                                        \
    GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name);   \
    static_assert(true,                                                      \
                  "This assert is used to counter the false positive extra " \
                  "semi-colon warnings")
789
790
/*
 * Defines four overloads of `_operator_name` returning a range whose
 * accessor is `accessor::_operation_name<Kind, ...>`:
 *   - range<A>  op range<A>   (range_by_range, same accessor type)
 *   - range<A>  op range<B>   (range_by_range)
 *   - range<A>  op scalar     (range_by_scalar)
 *   - scalar    op range<B>   (scalar_by_range)
 */
#define GKO_BIND_RANGE_OPERATION_TO_OPERATOR(_operation_name, _operator_name) \
    template <class Accessor>                                                 \
    GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name<      \
        ::gko::detail::operation_kind::range_by_range, Accessor, Accessor>>   \
    _operator_name(const range<Accessor>& first,                              \
                   const range<Accessor>& second)                             \
    {                                                                         \
        using result_range = range<accessor::_operation_name<                 \
            ::gko::detail::operation_kind::range_by_range, Accessor,          \
            Accessor>>;                                                       \
        return result_range(first.get_accessor(), second.get_accessor());     \
    }                                                                         \
                                                                              \
    template <class FirstAccessor, class SecondAccessor>                      \
    GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name<      \
        ::gko::detail::operation_kind::range_by_range, FirstAccessor,         \
        SecondAccessor>>                                                      \
    _operator_name(const range<FirstAccessor>& first,                         \
                   const range<SecondAccessor>& second)                       \
    {                                                                         \
        using result_range = range<accessor::_operation_name<                 \
            ::gko::detail::operation_kind::range_by_range, FirstAccessor,     \
            SecondAccessor>>;                                                 \
        return result_range(first.get_accessor(), second.get_accessor());     \
    }                                                                         \
                                                                              \
    template <class FirstAccessor, class SecondOperand>                       \
    GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name<      \
        ::gko::detail::operation_kind::range_by_scalar, FirstAccessor,        \
        SecondOperand>>                                                       \
    _operator_name(const range<FirstAccessor>& first,                         \
                   const SecondOperand& second)                               \
    {                                                                         \
        using result_range = range<accessor::_operation_name<                 \
            ::gko::detail::operation_kind::range_by_scalar, FirstAccessor,    \
            SecondOperand>>;                                                  \
        return result_range(first.get_accessor(), second);                    \
    }                                                                         \
                                                                              \
    template <class FirstOperand, class SecondAccessor>                       \
    GKO_ATTRIBUTES constexpr GKO_INLINE range<accessor::_operation_name<      \
        ::gko::detail::operation_kind::scalar_by_range, FirstOperand,         \
        SecondAccessor>>                                                      \
    _operator_name(const FirstOperand& first,                                 \
                   const range<SecondAccessor>& second)                       \
    {                                                                         \
        using result_range = range<accessor::_operation_name<                 \
            ::gko::detail::operation_kind::scalar_by_range, FirstOperand,     \
            SecondAccessor>>;                                                 \
        return result_range(first, second.get_accessor());                    \
    }                                                                         \
    static_assert(true,                                                       \
                  "This assert is used to counter the false positive extra "  \
                  "semi-colon warnings")
841
842
/*
 * Declares `_deprecated_name` as a deprecated alias of the binary operation
 * type `_name`; its static `evaluate_*` functions are inherited as they are.
 */
#define GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(_deprecated_name, _name) \
    struct GKO_DEPRECATED("Please use " #_name) _deprecated_name        \
        : public _name {}
845
/*
 * Defines the element-wise binary operation type `_name`.
 *
 * The variadic arguments form an expression in two elements named `first`
 * and `second`. The three public `evaluate_*` functions apply it to:
 *   - two range elements (range_by_range),
 *   - a scalar and a range element (scalar_by_range),
 *   - a range element and a scalar (range_by_scalar).
 * The private helper must be declared first, because the trailing return
 * types of the public functions refer to it.
 */
#define GKO_DEFINE_SIMPLE_BINARY_OPERATION(_name, ...)                     \
    struct _name {                                                         \
    private:                                                               \
        template <class FirstOperand, class SecondOperand>                 \
        GKO_ATTRIBUTES constexpr static auto apply_to_elements(            \
            const FirstOperand& first, const SecondOperand& second)        \
            -> decltype(__VA_ARGS__)                                       \
        {                                                                  \
            return __VA_ARGS__;                                            \
        }                                                                  \
                                                                           \
    public:                                                                \
        template <class FirstAccessor, class SecondAccessor,               \
                  class... DimensionTypes>                                 \
        GKO_ATTRIBUTES static constexpr auto evaluate_range_by_range(      \
            const FirstAccessor& lhs, const SecondAccessor& rhs,           \
            const DimensionTypes&... idx)                                  \
            -> decltype(apply_to_elements(lhs(idx...), rhs(idx...)))       \
        {                                                                  \
            return apply_to_elements(lhs(idx...), rhs(idx...));            \
        }                                                                  \
                                                                           \
        template <class FirstOperand, class SecondAccessor,                \
                  class... DimensionTypes>                                 \
        GKO_ATTRIBUTES static constexpr auto evaluate_scalar_by_range(     \
            const FirstOperand& lhs, const SecondAccessor& rhs,            \
            const DimensionTypes&... idx)                                  \
            -> decltype(apply_to_elements(lhs, rhs(idx...)))               \
        {                                                                  \
            return apply_to_elements(lhs, rhs(idx...));                    \
        }                                                                  \
                                                                           \
        template <class FirstAccessor, class SecondOperand,                \
                  class... DimensionTypes>                                 \
        GKO_ATTRIBUTES static constexpr auto evaluate_range_by_scalar(     \
            const FirstAccessor& lhs, const SecondOperand& rhs,            \
            const DimensionTypes&... idx)                                  \
            -> decltype(apply_to_elements(lhs(idx...), rhs))               \
        {                                                                  \
            return apply_to_elements(lhs(idx...), rhs);                    \
        }                                                                  \
    }
888
889
890namespace accessor {
891namespace detail {
892
893
894// binary arithmetic
895GKO_DEFINE_SIMPLE_BINARY_OPERATION(add, first + second);
896GKO_DEFINE_SIMPLE_BINARY_OPERATION(sub, first - second);
897GKO_DEFINE_SIMPLE_BINARY_OPERATION(mul, first* second);
898GKO_DEFINE_SIMPLE_BINARY_OPERATION(div, first / second);
899GKO_DEFINE_SIMPLE_BINARY_OPERATION(mod, first % second);
900
901// relational
902GKO_DEFINE_SIMPLE_BINARY_OPERATION(less, first < second);
903GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater, first > second);
904GKO_DEFINE_SIMPLE_BINARY_OPERATION(less_or_equal, first <= second);
905GKO_DEFINE_SIMPLE_BINARY_OPERATION(greater_or_equal, first >= second);
906GKO_DEFINE_SIMPLE_BINARY_OPERATION(equal, first == second);
907GKO_DEFINE_SIMPLE_BINARY_OPERATION(not_equal, first != second);
908
909// binary logical
910GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_or, first || second);
911GKO_DEFINE_SIMPLE_BINARY_OPERATION(logical_and, first&& second);
912
913// binary bitwise
914GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_or, first | second);
915GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_and, first& second);
916GKO_DEFINE_SIMPLE_BINARY_OPERATION(bitwise_xor, first ^ second);
917GKO_DEFINE_SIMPLE_BINARY_OPERATION(left_shift, first << second);
918GKO_DEFINE_SIMPLE_BINARY_OPERATION(right_shift, first >> second);
919
920// common binary functions
921GKO_DEFINE_SIMPLE_BINARY_OPERATION(max_operation, max(first, second));
922GKO_DEFINE_SIMPLE_BINARY_OPERATION(min_operation, min(first, second));
923
924GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(max_operaton, max_operation);
925GKO_DEPRECATED_SIMPLE_BINARY_OPERATION(min_operaton, min_operation);
926} // namespace detail
927} // namespace accessor
928
929
930// binary arithmetic
931GKO_ENABLE_BINARY_RANGE_OPERATION(add, operator+, accessor::detail::add);
932GKO_ENABLE_BINARY_RANGE_OPERATION(sub, operator-, accessor::detail::sub);
933GKO_ENABLE_BINARY_RANGE_OPERATION(mul, operator*, accessor::detail::mul);
934GKO_ENABLE_BINARY_RANGE_OPERATION(div, operator/, accessor::detail::div);
935GKO_ENABLE_BINARY_RANGE_OPERATION(mod, operator%, accessor::detail::mod);
936
937// relational
938GKO_ENABLE_BINARY_RANGE_OPERATION(less, operator<, accessor::detail::less);
939GKO_ENABLE_BINARY_RANGE_OPERATION(greater, operator>,
940 accessor::detail::greater);
941GKO_ENABLE_BINARY_RANGE_OPERATION(less_or_equal, operator<=,
942 accessor::detail::less_or_equal);
943GKO_ENABLE_BINARY_RANGE_OPERATION(greater_or_equal, operator>=,
944 accessor::detail::greater_or_equal);
945GKO_ENABLE_BINARY_RANGE_OPERATION(equal, operator==, accessor::detail::equal);
946GKO_ENABLE_BINARY_RANGE_OPERATION(not_equal, operator!=,
947 accessor::detail::not_equal);
948
949// binary logical
950GKO_ENABLE_BINARY_RANGE_OPERATION(logical_or, operator||,
951 accessor::detail::logical_or);
952GKO_ENABLE_BINARY_RANGE_OPERATION(logical_and, operator&&,
953 accessor::detail::logical_and);
954
955// binary bitwise
956GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_or, operator|,
957 accessor::detail::bitwise_or);
958GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_and, operator&,
959 accessor::detail::bitwise_and);
960GKO_ENABLE_BINARY_RANGE_OPERATION(bitwise_xor, operator^,
961 accessor::detail::bitwise_xor);
962GKO_ENABLE_BINARY_RANGE_OPERATION(left_shift, operator<<,
963 accessor::detail::left_shift);
964GKO_ENABLE_BINARY_RANGE_OPERATION(right_shift, operator>>,
965 accessor::detail::right_shift);
966
967// common binary functions
968GKO_ENABLE_BINARY_RANGE_OPERATION(max_operation, max,
969 accessor::detail::max_operation);
970GKO_ENABLE_BINARY_RANGE_OPERATION(min_operation, min,
971 accessor::detail::min_operation);
972
973
974// special binary range functions
975namespace accessor {
976
977
978template <gko::detail::operation_kind Kind, typename FirstAccessor,
979 typename SecondAccessor>
980struct mmul_operation {
981 static_assert(Kind == gko::detail::operation_kind::range_by_range,
982 "Matrix multiplication expects both operands to be ranges");
983 using first_accessor = FirstAccessor;
984 using second_accessor = SecondAccessor;
985 static_assert(first_accessor::dimensionality ==
986 second_accessor::dimensionality,
987 "Both ranges need to have the same number of dimensions");
988 static constexpr size_type dimensionality = first_accessor::dimensionality;
989
990 GKO_ATTRIBUTES explicit mmul_operation(const FirstAccessor& first,
991 const SecondAccessor& second)
992 : first{first}, second{second}
993 {
994 GKO_ASSERT(first.length(1) == second.length(0));
995 GKO_ASSERT(gko::detail::equal_dimensions<2>(first, second));
996 }
997
998 template <typename FirstDimension, typename SecondDimension,
999 typename... DimensionTypes>
1000 GKO_ATTRIBUTES auto operator()(const FirstDimension& row,
1001 const SecondDimension& col,
1002 const DimensionTypes&... rest) const
1003 -> decltype(std::declval<FirstAccessor>()(row, 0, rest...) *
1004 std::declval<SecondAccessor>()(0, col, rest...) +
1005 std::declval<FirstAccessor>()(row, 1, rest...) *
1006 std::declval<SecondAccessor>()(1, col, rest...))
1007 {
1008 using result_type =
1009 decltype(first(row, 0, rest...) * second(0, col, rest...) +
1010 first(row, 1, rest...) * second(1, col, rest...));
1011 GKO_ASSERT(first.length(1) == second.length(0));
1012 auto result = zero<result_type>();
1013 const auto size = first.length(1);
1014 for (auto i = zero(size); i < size; ++i) {
1015 result += first(row, i, rest...) * second(i, col, rest...);
1016 }
1017 return result;
1018 }
1019
1020 GKO_ATTRIBUTES constexpr size_type length(size_type dimension) const
1021 {
1022 return dimension == 1 ? second.length(1) : first.length(dimension);
1023 }
1024
1025 template <typename OtherAccessor>
1026 GKO_ATTRIBUTES void copy_from(const OtherAccessor& other) const = delete;
1027
1028 const first_accessor first;
1029 const second_accessor second;
1030};
1031
1032
1033} // namespace accessor
1034
1035
1036GKO_BIND_RANGE_OPERATION_TO_OPERATOR(mmul_operation, mmul);
1037
1038
1039#undef GKO_DEFINE_SIMPLE_BINARY_OPERATION
1040#undef GKO_ENABLE_BINARY_RANGE_OPERATION
1041
1042
1043} // namespace gko
1044
1045
1046#endif // GKO_PUBLIC_CORE_BASE_RANGE_HPP_
static constexpr size_type dimensionality
Number of dimensions of the accessor.
Definition range_accessors.hpp:60
A range is a multidimensional view of the memory.
Definition range.hpp:304
Accessor accessor
The type of the underlying accessor.
Definition range.hpp:309
constexpr auto operator()(DimensionTypes &&... dimensions) const -> decltype(std::declval< accessor >()(std::forward< DimensionTypes >(dimensions)...))
Returns a value (or a sub-range) with the specified indexes.
Definition range.hpp:352
constexpr size_type length(size_type dimension) const
Returns the length of the specified dimension of the range.
Definition range.hpp:407
constexpr const accessor * operator->() const noexcept
Returns a pointer to the accessor.
Definition range.hpp:419
static constexpr size_type dimensionality
Definition range.hpp:314
const range & operator=(const range &other) const
Assigns another range to this range.
Definition range.hpp:391
~range()=default
Use the default destructor.
const range & operator=(const range< OtherAccessor > &other) const
Definition range.hpp:370
constexpr range(AccessorParams &&... params)
Creates a new range.
Definition range.hpp:335
The accessor namespace.
Definition range.hpp:657
The Ginkgo namespace.
Definition abstract_factory.hpp:20
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:654
constexpr std::enable_if_t<!is_complex_s< T >::value, T > abs(const T &x)
Returns the absolute value of the object.
Definition math.hpp:962
constexpr T zero()
Returns the additive identity for T.
Definition math.hpp:626
constexpr auto imag(const T &x)
Returns the imaginary part of the object.
Definition math.hpp:916
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:90
constexpr T min(const T &x, const T &y)
Returns the smaller of the arguments.
Definition math.hpp:750
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:119
constexpr auto squared_norm(const T &x) -> decltype(real(conj(x) *x))
Returns the squared norm of the object.
Definition math.hpp:944
constexpr auto conj(const T &x)
Returns the conjugate of an object.
Definition math.hpp:930
constexpr T max(const T &x, const T &y)
Returns the larger of the arguments.
Definition math.hpp:732
constexpr auto real(const T &x)
Returns the real part of the object.
Definition math.hpp:900
Definition range.hpp:706
Definition range.hpp:717
Definition range.hpp:931
Definition range.hpp:959
Definition range.hpp:697
Definition range.hpp:957
Definition range.hpp:961
Definition range.hpp:712
Definition range.hpp:720
Definition range.hpp:934
Definition range.hpp:945
Definition range.hpp:944
Definition range.hpp:940
Definition range.hpp:710
Definition range.hpp:719
Definition range.hpp:963
Definition range.hpp:942
Definition range.hpp:938
Definition range.hpp:953
Definition range.hpp:693
Definition range.hpp:951
Definition range.hpp:969
Definition range.hpp:971
Definition range.hpp:935
Definition range.hpp:933
Definition range.hpp:947
Definition range.hpp:704
Definition range.hpp:716
Definition range.hpp:708
Definition range.hpp:718
Definition range.hpp:965
Definition range.hpp:722
Definition range.hpp:932
Definition range.hpp:689
Definition range.hpp:687
Definition range.hpp:702
A span that is used exclusively for local numbering.
Definition range.hpp:138
constexpr span(size_type point) noexcept
Creates a span representing the single point `point`.
Definition range.hpp:54
A span is a lightweight structure used to create sub-ranges from other ranges.
Definition range.hpp:46
constexpr span(size_type begin, size_type end) noexcept
Creates a span.
Definition range.hpp:64
constexpr span(size_type point) noexcept
Creates a span representing the single point `point`.
Definition range.hpp:54
constexpr bool is_valid() const
Checks if a span is valid.
Definition range.hpp:73
constexpr size_type length() const
Returns the length of a span.
Definition range.hpp:80
const size_type begin
Beginning of the span.
Definition range.hpp:85
const size_type end
End of the span.
Definition range.hpp:90