1 #ifndef RF_EARLY_STOPPING_P_HXX
2 #define RF_EARLY_STOPPING_P_HXX
5 #include "rf_common.hxx"
14 T
power(T
const & in,
int n)
16 T result = NumericTraits<T>::one();
17 for(
int ii = 0; ii < n ;++ii)
35 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
38 is_weighted_ = is_weighted;
39 tree_count_ = tree_count;
49 template<
class WeightIter,
class T,
class C>
52 template<
class WeightIter,
class T,
class C>
79 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
81 max_tree_ =
ceil(max_tree_p * tree_count);
82 SB::set_external_parameters(prob, tree_count, is_weighted);
85 template<
class WeightIter,
class T,
class C>
86 bool after_prediction(WeightIter,
int k, MultiArrayView<2, T, C>
const & ,
double )
88 if(k == SB::tree_count_ -1)
90 depths.push_back(
double(k+1)/
double(SB::tree_count_));
95 depths.push_back(
double(k+1)/
double(SB::tree_count_));
117 proportion_(proportion)
120 template<
class WeightIter,
class T,
class C>
123 if(k == SB::tree_count_ -1)
125 depths.push_back(
double(k+1)/
double(SB::tree_count_));
132 if(prob[
argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
134 depths.push_back(
double(k+1)/
double(SB::tree_count_));
140 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
142 depths.push_back(
double(k+1)/
double(SB::tree_count_));
175 void set_external_parameters(
ProblemSpec<T> const &prob,
int tree_count = 0,
bool is_weighted =
false)
179 SB::set_external_parameters(prob, tree_count, is_weighted);
181 template<
class WeightIter,
class T,
class C>
182 bool after_prediction(WeightIter,
int k, MultiArrayView<2, T, C>
const & prob,
double)
184 if(k == SB::tree_count_ -1)
186 depths.push_back(
double(k+1)/
double(SB::tree_count_));
192 last_/= last_.
norm(1);
198 cur_ /= cur_.
norm(1);
200 double nrm = last_.
norm();
203 depths.push_back(
double(k+1)/
double(SB::tree_count_));
233 proportion_(proportion)
236 template<
class WeightIter,
class T,
class C>
239 if(k == SB::tree_count_ -1)
241 depths.push_back(
double(k+1)/
double(SB::tree_count_));
245 double a = prob[
argMax(prob)];
247 double b = prob[
argMax(prob)];
249 double margin = a - b;
252 if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
254 depths.push_back(
double(k+1)/
double(SB::tree_count_));
260 if(prob[
argMax(prob)] > proportion_ * SB::tree_count_)
262 depths.push_back(
double(k+1)/
double(SB::tree_count_));
300 double binomial(
int N,
int k,
double p)
303 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
306 template<
class WeightIter,
class T,
class C>
307 bool after_prediction(WeightIter,
int k,
310 if(k == SB::tree_count_ -1)
312 depths.push_back(
double(k+1)/
double(SB::tree_count_));
320 int n_a = prob[index];
321 int n_b = prob[(index+1)%2];
322 int n_tilde = (SB::tree_count_ - n_a + n_b);
323 double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
324 vigra_precondition(p_a <= 1,
"probability should be smaller than 1");
330 for(
int ii = 0; ii <= n_b + n_a;++ii)
333 cum_val += binomial(n_b + n_a, ii, p_a);
334 if(cum_val >= 1 -alpha_)
343 depths.push_back(
double(k+1)/
double(SB::tree_count_));
381 double binomial(
int N,
int k,
double p)
384 return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
387 template<
class WeightIter,
class T,
class C>
390 if(k == SB::tree_count_ -1)
392 depths.push_back(
double(k+1)/
double(SB::tree_count_));
400 int n_a = prob[index];
401 int n_b = prob[(index+1)%2];
402 int n_needed =
ceil(
double(SB::tree_count_)/2.0)-n_a;
403 int n_tilde = SB::tree_count_ - (n_a +n_b);
404 if(n_tilde <= 0) n_tilde = 0;
405 if(n_needed <= 0) n_needed = 0;
407 for(
int ii = n_needed; ii < n_tilde; ++ii)
408 p += binomial(n_tilde, ii, 0.5);
412 depths.push_back(
double(k+1)/
double(SB::tree_count_));
421 class DepthAndSizeStopping:
public StopBase
427 int max_depth_reached;
429 DepthAndSizeStopping()
430 : max_depth_(NumericTraits<int>::max()), min_size_(0)
439 DepthAndSizeStopping(
int depth,
int size) :
440 max_depth_(depth <= 0 ? NumericTraits<int>::max() : depth),
445 void set_external_parameters(ProblemSpec<T>
const &,
446 int = 0,
bool =
false)
449 template<
class Region>
450 bool operator()(Region& region)
452 if (region.depth() > max_depth_)
453 throw std::runtime_error(
"violation in the stopping criterion");
455 return (region.depth() >= max_depth_) || (region.size() < min_size_) ;
458 template<
class WeightIter,
class T,
class C>
459 bool after_prediction(WeightIter,
int ,
460 MultiArrayView<2, T, C>
const &,
double )
467 #endif //RF_EARLY_STOPPING_P_HXX
Definition: rf_earlystopping.hxx:279
problem specification class for the random forest.
Definition: rf_common.hxx:538
StopAfterVoteCount(double proportion)
Definition: rf_earlystopping.hxx:115
V power(const V &x)
Exponentiation to a positive integer power by squaring.
Definition: mathutil.hxx:427
void reshape(const difference_type &shape)
Definition: multi_array.hxx:2861
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:379
NormTraits< MultiArrayView >::NormType norm(int type=2, bool useSquaredNorm=true) const
Definition: multi_array.hxx:2372
Definition: rf_earlystopping.hxx:61
Definition: multi_fwd.hxx:63
StopIfConverging(double thresh, int num=10)
Definition: rf_earlystopping.hxx:168
StopIfProb(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:371
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_earlystopping.hxx:153
Definition: rf_earlystopping.hxx:26
Definition: rf_earlystopping.hxx:105
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:704
Definition: rf_earlystopping.hxx:221
StopAfterTree(double max_tree)
Definition: rf_earlystopping.hxx:73
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
Definition: rf_earlystopping.hxx:359
ArrayVector< double > depths
Definition: rf_earlystopping.hxx:298
StopIfBinTest(double alpha, MultiArrayView< 2, double > nck_)
Definition: rf_earlystopping.hxx:289
StopIfMargin(double proportion)
Definition: rf_earlystopping.hxx:231