mutable
A Database System for Research and Fast Prototyping
Loading...
Searching...
No Matches
GridSearch.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <array>
5#include <cfenv>
6#include <cmath>
7#include <functional>
8#include <iostream>
10#include <stdexcept>
11#include <tuple>
12#include <type_traits>
13#include <utility>
14#include <vector>
15
16
17namespace m {
18
19namespace gs {
20
/** CRTP base class for one-dimensional search spaces.
 *
 * Every member forwards to the member of the same name in `Derived<T>` (the concrete space), reached via a
 * `static_cast` of `this`.  `Derived<T>` must therefore implement `lo()`, `hi()`, `step()`, `num_steps()`,
 * `at(unsigned)`, `sequence()`, and a stream insertion operator for itself; otherwise the forwarding member
 * would call itself.
 *
 * @tparam T        the value type of points in the space
 * @tparam Derived  the class template of the concrete space, instantiated as `Derived<T>`
 */
template<typename T, template<typename> typename Derived>
struct Space
{
    using derived_type = Derived<T>;
    using value_type = T;

    value_type lo() const { return derived().lo(); }
    value_type hi() const { return derived().hi(); }
    double step() const { return derived().step(); }
    unsigned num_steps() const { return derived().num_steps(); }

    value_type at(unsigned n) const { return derived().at(n); }
    value_type operator()(unsigned n) const { return at(n); }

    std::vector<value_type> sequence() const { return derived().sequence(); }

    /// Prints the space by delegating to the stream insertion operator of the concrete space.
    friend std::ostream & operator<<(std::ostream &out, const Space &S) {
        return out << S.derived();
    }

    /// Writes the space followed by a newline (and a flush) to `out`.
    void dump(std::ostream &out) const { out << *this << std::endl; }
    /// Writes the space to `std::cerr`.
    void dump() const { dump(std::cerr); }

    private:
    /// Downcast to the concrete space; replaces the former function-like `CDERIVED` macro.
    const derived_type & derived() const { return *static_cast<const derived_type*>(this); }
};
48
49template<typename T>
50struct LinearSpace : Space<T, LinearSpace>
51{
52 static_assert(std::is_arithmetic_v<T>, "type T must be an arithmetic type");
53 using value_type = T;
54 using difference_type = typename std::conditional_t<std::is_integral_v<T>,
55 std::make_signed<T>,
56 std::common_type<T>>::type;
57
58 private:
61 double step_;
62 unsigned num_steps_;
64
65 public:
66 LinearSpace(value_type lowest, value_type highest, unsigned num_steps, bool is_ascending = true)
67 : lo_(lowest), hi_(highest), num_steps_(num_steps), is_ascending_(is_ascending)
68 {
69 if (lo_ > hi_)
70 throw std::invalid_argument("invalid range");
71 if (num_steps_ == 0)
72 throw std::invalid_argument("number of steps must not be zero");
73
74 const int save_round = std::fegetround();
75 std::fesetround(FE_TOWARDZERO);
76 step_ = (double(hi_) - double(lo_)) / num_steps_;
77 std::fesetround(save_round);
78 }
79
80 static LinearSpace Ascending(value_type lowest, value_type highest, unsigned num_steps) {
81 return LinearSpace(lowest, highest, num_steps, true);
82 }
83
84 static LinearSpace Descending(value_type lowest, value_type highest, unsigned num_steps) {
85 return LinearSpace(lowest, highest, num_steps, false);
86 }
87
88 value_type lo() const { return lo_; }
89 value_type hi() const { return hi_; }
90 double step() const { return step_; }
91 unsigned num_steps() const { return num_steps_; }
92 difference_type delta() const { return hi_ - lo_; }
93 bool ascending() const { return is_ascending_; }
94 bool descending() const { return not is_ascending_; }
95
96 value_type at(unsigned n) const {
97 if (n > num_steps_)
98 throw std::out_of_range("n must be between 0 and num_steps()");
99 if constexpr (std::is_integral_v<value_type>) {
100 const typename std::make_unsigned_t<value_type> delta = std::round(n * step());
101 if (ascending())
102 return lo() + delta;
103 else
104 return hi() - delta;
105 } else {
106 if (ascending())
107 return std::clamp<value_type>(value_type(double(lo()) + n * step_), lo_, hi_);
108 else
109 return std::clamp<value_type>(value_type(double(hi()) - n * step_), lo_, hi_);
110 }
111 }
112 value_type operator()(unsigned n) const { return at(n); }
113
114 std::vector<value_type> sequence() const {
115 std::vector<value_type> vec;
116 vec.reserve(num_steps());
117
118 for (unsigned i = 0; i <= num_steps(); ++i)
119 vec.push_back(at(i));
120
121 return vec;
122 }
123
125 friend std::ostream & operator<<(std::ostream &out, const LinearSpace &S) {
126 return out << "linear space from " << S.lo() << " to " << S.hi() << " with " << S.num_steps() << " steps of "
127 << S.step();
128 }
129
130 void dump(std::ostream &out) const { out << *this << std::endl; }
131 void dump() const { dump(std::cerr); }
133};
134
/** Exhaustive search over the Cartesian product of several one-dimensional spaces.
 *
 * Each type in `Spaces` is expected to provide `value_type`, `num_steps()`, `operator()(unsigned)` returning the
 * point at an index, and a stream insertion operator (the latter only when the grid is printed).  A space with
 * `num_steps()` steps contributes `num_steps() + 1` points, so the grid has the product of those counts as points.
 *
 * @tparam Spaces the types of the spaces spanning the grid, in argument order of the callback
 */
template<typename... Spaces>
struct GridSearch
{
    /// Callback invoked once per grid point with one value per space.
    using callback_type = std::function<void(typename Spaces::value_type...)>;
    static constexpr std::size_t NUM_SPACES = sizeof...(Spaces);

    private:
    std::tuple<Spaces...> spaces_;

    public:
    /// Creates a grid spanned by `spaces`; the spaces are moved into the grid.
    GridSearch(Spaces... spaces) : spaces_(std::move(spaces)...) { }

    constexpr std::size_t num_spaces() const { return NUM_SPACES; }

    /** Returns the number of grid points, i.e. the product of `num_steps() + 1` over all spaces.
     *
     * Computed in `std::size_t` (not `unsigned`) so large grids do not overflow; the empty grid has one point.
     */
    std::size_t num_points() const {
        return std::apply([](auto&... space) {
            return (std::size_t(1) * ... * (std::size_t(space.num_steps()) + 1));
        }, spaces_);
    }

    /// Invokes `fn` once for every point of the grid.
    void search(callback_type fn) const;
    /// Same as `search(fn)`.
    void operator()(callback_type fn) const { search(std::move(fn)); }

    /// Prints a header followed by one line per space.
    friend std::ostream & operator<<(std::ostream &out, const GridSearch &GS) {
        out << "grid search with";

        std::apply([&out](auto&... space) {
            ((out << "\n " << space), ...); // use C++17 fold-expression
        }, GS.spaces_);

        return out;
    }

    /// Writes the grid followed by a newline (and a flush) to `out`.
    void dump(std::ostream &out) const { out << *this << std::endl; }
    /// Writes the grid to `std::cerr`.
    void dump() const { dump(std::cerr); }

    private:
    /// Builds the callback arguments: the `I`-th space evaluated at `counters[I]`, for every space.
    template<std::size_t... I>
    std::tuple<typename Spaces::value_type...>
    make_args(const std::array<unsigned, NUM_SPACES> &counters, std::index_sequence<I...>) const {
        return std::apply([&counters](auto&... space) {
            return std::make_tuple(space(counters[I])...);
        }, spaces_);
    }
};
182
183template<typename... Spaces>
185{
186 std::array<unsigned, NUM_SPACES> counters;
187 std::fill(counters.begin(), counters.end(), 0U);
188 const std::array<unsigned, NUM_SPACES> num_steps = std::apply([](auto&... space) {
189 return std::array<unsigned, NUM_SPACES>{ space.num_steps()... };
190 }, spaces_);
191
192 for (;;) {
193 auto args = make_args(counters, std::index_sequence_for<Spaces...>{});
194 std::apply(fn, args);
195
196 std::size_t idx = NUM_SPACES - 1;
197
198 while (counters[idx] == num_steps[idx]) {
199 if (idx == 0) goto finished;
200 counters[idx] = 0;
201 --idx;
202 }
203 ++counters[idx];
204 }
205finished:;
206}
207
208}
209
210}
#define CDERIVED
Definition: GridSearch.hpp:27
struct @5 args
mutable namespace
Definition: Backend.hpp:10
T(x)
and arithmetic< U > and same_signedness< T, U > U
Definition: concepts.hpp:90
STL namespace.
static constexpr std::size_t NUM_SPACES
Definition: GridSearch.hpp:139
std::function< void(typename Spaces::value_type...)> callback_type
Definition: GridSearch.hpp:138
M_LCOV_EXCL_START friend std::ostream & operator<<(std::ostream &out, const GridSearch &GS)
Definition: GridSearch.hpp:159
std::tuple< typename Spaces::value_type... > make_args(std::array< unsigned, NUM_SPACES > &counters, std::index_sequence< I... >) const
Definition: GridSearch.hpp:176
void dump() const
Definition: GridSearch.hpp:170
std::size_t num_points() const
Definition: GridSearch.hpp:149
GridSearch(Spaces... spaces)
Definition: GridSearch.hpp:145
void dump(std::ostream &out) const
Definition: GridSearch.hpp:169
constexpr std::size_t num_spaces() const
Definition: GridSearch.hpp:147
void operator()(callback_type fn) const
Definition: GridSearch.hpp:156
std::tuple< Spaces... > spaces_
Definition: GridSearch.hpp:142
void search(callback_type fn) const
Definition: GridSearch.hpp:184
M_LCOV_EXCL_START friend std::ostream & operator<<(std::ostream &out, const LinearSpace &S)
Definition: GridSearch.hpp:125
bool ascending() const
Definition: GridSearch.hpp:93
static LinearSpace Ascending(value_type lowest, value_type highest, unsigned num_steps)
Definition: GridSearch.hpp:80
LinearSpace(value_type lowest, value_type highest, unsigned num_steps, bool is_ascending=true)
Definition: GridSearch.hpp:66
double step() const
Definition: GridSearch.hpp:90
void dump(std::ostream &out) const
Definition: GridSearch.hpp:130
std::vector< value_type > sequence() const
Definition: GridSearch.hpp:114
value_type lo() const
Definition: GridSearch.hpp:88
static LinearSpace Descending(value_type lowest, value_type highest, unsigned num_steps)
Definition: GridSearch.hpp:84
value_type at(unsigned n) const
Definition: GridSearch.hpp:96
difference_type delta() const
Definition: GridSearch.hpp:92
value_type hi() const
Definition: GridSearch.hpp:89
unsigned num_steps() const
Definition: GridSearch.hpp:91
bool descending() const
Definition: GridSearch.hpp:94
typename std::conditional_t< std::is_integral_v< T >, std::make_signed< T >, std::common_type< T > >::type difference_type
Definition: GridSearch.hpp:56
void dump() const
Definition: GridSearch.hpp:131
value_type operator()(unsigned n) const
Definition: GridSearch.hpp:112
value_type lo() const
Definition: GridSearch.hpp:28
double step() const
Definition: GridSearch.hpp:30
unsigned num_steps() const
Definition: GridSearch.hpp:31
M_LCOV_EXCL_START friend std::ostream & operator<<(std::ostream &out, const Space &S)
Definition: GridSearch.hpp:40
std::vector< value_type > sequence() const
Definition: GridSearch.hpp:36
Derived< T > derived_type
Definition: GridSearch.hpp:24
value_type operator()(unsigned n) const
Definition: GridSearch.hpp:34
void dump(std::ostream &out) const
Definition: GridSearch.hpp:44
value_type at(unsigned n) const
Definition: GridSearch.hpp:33
void dump() const
Definition: GridSearch.hpp:45
value_type hi() const
Definition: GridSearch.hpp:29