[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_earlystopping.hxx VIGRA

1 #ifndef RF_EARLY_STOPPING_P_HXX
2 #define RF_EARLY_STOPPING_P_HXX
3 #include <cmath>
4 #include <stdexcept>
5 #include "rf_common.hxx"
6 
7 namespace vigra
8 {
9 
10 #if 0
11 namespace es_detail
12 {
13  template<class T>
14  T power(T const & in, int n)
15  {
16  T result = NumericTraits<T>::one();
17  for(int ii = 0; ii < n ;++ii)
18  result *= in;
19  return result;
20  }
21 }
22 #endif
23 
24 /**Base class from which all EarlyStopping Functors derive.
25  */
26 class StopBase
27 {
28 protected:
29  ProblemSpec<> ext_param_;
30  int tree_count_ ;
31  bool is_weighted_;
32 
33 public:
34  template<class T>
35  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
36  {
37  ext_param_ = prob;
38  is_weighted_ = is_weighted;
39  tree_count_ = tree_count;
40  }
41 
42 #ifdef DOXYGEN
43  /** called after the prediction of a tree was added to the total prediction
44  * \param weightIter Iterator to the weights delivered by current tree.
45  * \param k after kth tree
46  * \param prob Total probability array
47  * \param totalCt sum of probability array.
48  */
49  template<class WeightIter, class T, class C>
50  bool after_prediction(WeightIter weightIter, int k, MultiArrayView<2, T, C> const & prob , double totalCt)
51 #else
52  template<class WeightIter, class T, class C>
53  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
54  {return false;}
55 #endif //DOXYGEN
56 };
57 
58 
59 /**Stop predicting after a set number of trees
60  */
61 class StopAfterTree : public StopBase
62 {
63 public:
64  double max_tree_p;
65  int max_tree_;
66  typedef StopBase SB;
67 
68  ArrayVector<double> depths;
69 
70  /** Constructor
71  * \param max_tree number of trees to be used for prediction
72  */
73  StopAfterTree(double max_tree)
74  :
75  max_tree_p(max_tree)
76  {}
77 
78  template<class T>
79  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
80  {
81  max_tree_ = ceil(max_tree_p * tree_count);
82  SB::set_external_parameters(prob, tree_count, is_weighted);
83  }
84 
85  template<class WeightIter, class T, class C>
86  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
87  {
88  if(k == SB::tree_count_ -1)
89  {
90  depths.push_back(double(k+1)/double(SB::tree_count_));
91  return false;
92  }
93  if(k < max_tree_)
94  return false;
95  depths.push_back(double(k+1)/double(SB::tree_count_));
96  return true;
97  }
98 };
99 
100 /** Stop predicting after a certain amount of votes exceed certain proportion.
101  * case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_
102  * case weighted voting: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
103  * (maximal number of votes possible in both cases)
104  */
106 {
107 public:
108  double proportion_;
109  typedef StopBase SB;
110  ArrayVector<double> depths;
111 
112  /** Constructor
113  * \param proportion specify proportion to be used.
114  */
115  StopAfterVoteCount(double proportion)
116  :
117  proportion_(proportion)
118  {}
119 
120  template<class WeightIter, class T, class C>
121  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
122  {
123  if(k == SB::tree_count_ -1)
124  {
125  depths.push_back(double(k+1)/double(SB::tree_count_));
126  return false;
127  }
128 
129 
130  if(SB::is_weighted_)
131  {
132  if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
133  {
134  depths.push_back(double(k+1)/double(SB::tree_count_));
135  return true;
136  }
137  }
138  else
139  {
140  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
141  {
142  depths.push_back(double(k+1)/double(SB::tree_count_));
143  return true;
144  }
145  }
146  return false;
147  }
148 
149 };
150 
151 
152 /** Stop predicting if the 2norm of the probabilities does not change*/
154 
155 {
156 public:
157  double thresh_;
158  int num_;
159  MultiArray<2, double> last_;
161  ArrayVector<double> depths;
162  typedef StopBase SB;
163 
164  /** Constructor
165  * \param thresh: If the two norm of the probabilities changes less then thresh then stop
166  * \param num : look at atleast num trees before stopping
167  */
168  StopIfConverging(double thresh, int num = 10)
169  :
170  thresh_(thresh),
171  num_(num)
172  {}
173 
174  template<class T>
175  void set_external_parameters(ProblemSpec<T> const &prob, int tree_count = 0, bool is_weighted = false)
176  {
177  last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
178  cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
179  SB::set_external_parameters(prob, tree_count, is_weighted);
180  }
181  template<class WeightIter, class T, class C>
182  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> const & prob, double)
183  {
184  if(k == SB::tree_count_ -1)
185  {
186  depths.push_back(double(k+1)/double(SB::tree_count_));
187  return false;
188  }
189  if(k <= num_)
190  {
191  last_ = prob;
192  last_/= last_.norm(1);
193  return false;
194  }
195  else
196  {
197  cur_ = prob;
198  cur_ /= cur_.norm(1);
199  last_ -= cur_;
200  double nrm = last_.norm();
201  if(nrm < thresh_)
202  {
203  depths.push_back(double(k+1)/double(SB::tree_count_));
204  return true;
205  }
206  else
207  {
208  last_ = cur_;
209  }
210  }
211  return false;
212  }
213 };
214 
215 
216 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
217  * case unweighted voting: stop if margin exceeds proportion * SB::tree_count_
218  * case weighted voting: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
219  * (maximal number of votes possible in both cases)
220  */
221 class StopIfMargin : public StopBase
222 {
223 public:
224  double proportion_;
225  typedef StopBase SB;
226  ArrayVector<double> depths;
227 
228  /** Constructor
229  * \param proportion specify proportion to be used.
230  */
231  StopIfMargin(double proportion)
232  :
233  proportion_(proportion)
234  {}
235 
236  template<class WeightIter, class T, class C>
237  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
238  {
239  if(k == SB::tree_count_ -1)
240  {
241  depths.push_back(double(k+1)/double(SB::tree_count_));
242  return false;
243  }
244  int index = argMax(prob);
245  double a = prob[argMax(prob)];
246  prob[argMax(prob)] = 0;
247  double b = prob[argMax(prob)];
248  prob[index] = a;
249  double margin = a - b;
250  if(SB::is_weighted_)
251  {
252  if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
253  {
254  depths.push_back(double(k+1)/double(SB::tree_count_));
255  return true;
256  }
257  }
258  else
259  {
260  if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
261  {
262  depths.push_back(double(k+1)/double(SB::tree_count_));
263  return true;
264  }
265  }
266  return false;
267  }
268 };
269 
270 
271 /**Probabilistic Stopping criterion (binomial test)
272  *
273  * Can only be used in a two class setting
274  *
275  * Stop if the Parameters estimated for the underlying binomial distribution
276  * can be estimated with certainty over 1-alpha.
277  * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
278  */
279 class StopIfBinTest : public StopBase
280 {
281 public:
282  double alpha_;
283  MultiArrayView<2, double> n_choose_k;
284  /** Constructor
285  * \param alpha specify alpha (=proportion) value for binomial test.
286  * \param nck_ Matrix with precomputed values for n choose k
287  * nck_(n, k) is n choose k.
288  */
290  :
291  alpha_(alpha),
292  n_choose_k(nck_)
293  {}
294  typedef StopBase SB;
295 
296  /**ArrayVector that will contain the fraction of trees that was visited before terminating
297  */
299 
300  double binomial(int N, int k, double p)
301  {
302 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
303  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
304  }
305 
306  template<class WeightIter, class T, class C>
307  bool after_prediction(WeightIter, int k,
308  MultiArrayView<2, T, C> const &prob, double)
309  {
310  if(k == SB::tree_count_ -1)
311  {
312  depths.push_back(double(k+1)/double(SB::tree_count_));
313  return false;
314  }
315  if(k < 10)
316  {
317  return false;
318  }
319  int index = argMax(prob);
320  int n_a = prob[index];
321  int n_b = prob[(index+1)%2];
322  int n_tilde = (SB::tree_count_ - n_a + n_b);
323  double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
324  vigra_precondition(p_a <= 1, "probability should be smaller than 1");
325  double cum_val = 0;
326  int c = 0;
327  // std::cerr << "prob: " << p_a << std::endl;
328  if(n_a <= 0)n_a = 0;
329  if(n_b <= 0)n_b = 0;
330  for(int ii = 0; ii <= n_b + n_a;++ii)
331  {
332 // std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
333  cum_val += binomial(n_b + n_a, ii, p_a);
334  if(cum_val >= 1 -alpha_)
335  {
336  c = ii;
337  break;
338  }
339  }
340 // std::cerr << c << " " << n_a << " " << n_b << " " << p_a << alpha_ << std::endl;
341  if(c < n_a)
342  {
343  depths.push_back(double(k+1)/double(SB::tree_count_));
344  return true;
345  }
346 
347  return false;
348  }
349 };
350 
351 /**Probabilistic Stopping criteria. (toChange)
352  *
353  * Can only be used in a two class setting
354  *
355  * Stop if the probability that the decision will change after seeing all trees falls under
356  * a specified value alpha.
357  * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
358  */
359 class StopIfProb : public StopBase
360 {
361 public:
362  double alpha_;
363  MultiArrayView<2, double> n_choose_k;
364 
365 
366  /** Constructor
367  * \param alpha specify alpha (=proportion) value
368  * \param nck_ Matrix with precomputed values for n choose k
369  * nck_(n, k) is n choose k.
370  */
372  :
373  alpha_(alpha),
374  n_choose_k(nck_)
375  {}
376  typedef StopBase SB;
377  /**ArrayVector that will contain the fraction of trees that was visited before terminating
378  */
380 
381  double binomial(int N, int k, double p)
382  {
383 // return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
384  return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
385  }
386 
387  template<class WeightIter, class T, class C>
388  bool after_prediction(WeightIter, int k, MultiArrayView<2, T, C> prob, double)
389  {
390  if(k == SB::tree_count_ -1)
391  {
392  depths.push_back(double(k+1)/double(SB::tree_count_));
393  return false;
394  }
395  if(k <= 10)
396  {
397  return false;
398  }
399  int index = argMax(prob);
400  int n_a = prob[index];
401  int n_b = prob[(index+1)%2];
402  int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
403  int n_tilde = SB::tree_count_ - (n_a +n_b);
404  if(n_tilde <= 0) n_tilde = 0;
405  if(n_needed <= 0) n_needed = 0;
406  double p = 0;
407  for(int ii = n_needed; ii < n_tilde; ++ii)
408  p += binomial(n_tilde, ii, 0.5);
409 
410  if(p >= 1-alpha_)
411  {
412  depths.push_back(double(k+1)/double(SB::tree_count_));
413  return true;
414  }
415 
416  return false;
417  }
418 };
419 
420 
421 class DepthAndSizeStopping: public StopBase
422 {
423 public:
424  int max_depth_;
425  int min_size_;
426 
427  int max_depth_reached; //for debug maximum reached depth
428 
429  DepthAndSizeStopping()
430  : max_depth_(NumericTraits<int>::max()), min_size_(0)
431  {}
432 
433  /** Constructor DepthAndSize Criterion
434  * Stop growing the tree if a certain depth or size is reached or make a
435  * leaf if the node is smaller than a certain size. Note this is checked
436  * before the split so it is still possible that smaller leafs are created
437  */
438 
439  DepthAndSizeStopping(int depth, int size) :
440  max_depth_(depth <= 0 ? NumericTraits<int>::max() : depth),
441  min_size_(size)
442  {}
443 
444  template<class T>
445  void set_external_parameters(ProblemSpec<T> const &,
446  int /*tree_count*/ = 0, bool /* is_weighted_ */= false)
447  {}
448 
449  template<class Region>
450  bool operator()(Region& region)
451  {
452  if (region.depth() > max_depth_)
453  throw std::runtime_error("violation in the stopping criterion");
454 
455  return (region.depth() >= max_depth_) || (region.size() < min_size_) ;
456  }
457 
458  template<class WeightIter, class T, class C>
459  bool after_prediction(WeightIter, int /* k */,
460  MultiArrayView<2, T, C> const &/* prob */, double /* totalCt */)
461  {
462  return true;
463  }
464 };
465 
466 } //namespace vigra;
467 #endif //RF_EARLY_STOPPING_P_HXX
Definition: rf_earlystopping.hxx:279
problem specification class for the random forest.
Definition: rf_common.hxx:538
StopAfterVoteCount(double proportion)
Definition: rf_earlystopping.hxx:115
V power(const V &x)
Exponentiation to a positive integer power by squaring.
Definition: mathutil.hxx:427
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:379
NormTraits< MultiArrayView >::NormType norm(int type=2, bool useSquaredNorm=true) const
Definition: multi_array.hxx:2372
Definition: rf_earlystopping.hxx:61
Definition: multi_fwd.hxx:63
StopIfConverging(double thresh, int num=10)
Definition: rf_earlystopping.hxx:168
StopIfProb(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:371
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_earlystopping.hxx:153
Definition: rf_earlystopping.hxx:26
Definition: rf_earlystopping.hxx:105
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Definition: rf_earlystopping.hxx:221
StopAfterTree(double max_tree)
Definition: rf_earlystopping.hxx:73
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
Definition: rf_earlystopping.hxx:359
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:298
StopIfBinTest(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:289
StopIfMargin(double proportion)
Definition: rf_earlystopping.hxx:231

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.1 (Fri May 19 2017)