36 #ifndef VIGRA_RF_PREPROCESSING_HXX
37 #define VIGRA_RF_PREPROCESSING_HXX
40 #include <vigra/mathutil.hxx>
41 #include "rf_common.hxx"
62 template<
class Tag,
class LabelType,
class T1,
class C1,
class T2,
class C2>
77 switch(options.mtry_switch_)
80 ext_param.actual_mtry_ =
82 std::sqrt(
double(ext_param.column_count_))
87 ext_param.actual_mtry_ =
88 int(1+(
std::log(
double(ext_param.column_count_))
92 ext_param.actual_mtry_ =
93 options.mtry_func_(ext_param.column_count_);
96 ext_param.actual_mtry_ = ext_param.column_count_;
99 ext_param.actual_mtry_ =
103 switch(options.training_set_calc_switch_)
106 ext_param.actual_msample_ =
107 options.training_set_size_;
109 case RF_PROPORTIONAL:
110 ext_param.actual_msample_ =
111 static_cast<int>(
std::ceil(options.training_set_proportion_ *
112 ext_param.row_count_));
115 ext_param.actual_msample_ =
116 options.training_set_func_(ext_param.row_count_);
119 vigra_precondition(1!= 1,
"unexpected error");
127 template<
unsigned int N,
class T,
class C>
128 bool contains_nan(MultiArrayView<N, T, C>
const & in)
130 typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
131 Iter i = in.begin(), end = in.end();
133 if(isnan(NumericTraits<T>::toRealPromote(*i)))
140 template<
unsigned int N,
class T,
class C>
141 bool contains_inf(MultiArrayView<N, T, C>
const & in)
143 if(!std::numeric_limits<T>::has_infinity)
145 typedef typename MultiArrayView<N, T, C>::const_iterator Iter;
146 Iter i = in.begin(), end = in.end();
148 if(
abs(*i) == std::numeric_limits<T>::infinity())
161 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
162 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
165 typedef Int32 LabelInt;
181 vigra_precondition(!detail::contains_nan(features),
"RandomForest(): Feature matrix "
183 vigra_precondition(!detail::contains_nan(response),
"RandomForest(): Response "
185 vigra_precondition(!detail::contains_inf(features),
"RandomForest(): Feature matrix "
187 vigra_precondition(!detail::contains_inf(response),
"RandomForest(): Response "
190 ext_param.column_count_ = features.
shape(1);
191 ext_param.row_count_ = features.
shape(0);
192 ext_param.problem_type_ = CLASSIFICATION;
193 ext_param.used_ =
true;
194 intLabels_.reshape(response.
shape());
197 if(ext_param.class_count_ == 0)
201 std::set<T2> labelToInt;
203 labelToInt.insert(response(k,0));
204 std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
205 ext_param.
classes_(tmp_.begin(), tmp_.end());
209 if(std::find(ext_param.classes.
begin(), ext_param.classes.
end(), response(k,0)) == ext_param.classes.
end())
211 throw std::runtime_error(
"RandomForest(): invalid label in training data.");
214 intLabels_(k, 0) = std::find(ext_param.classes.
begin(), ext_param.classes.
end(), response(k,0))
215 - ext_param.classes.
begin();
218 if(ext_param.class_weights_.
size() == 0)
221 tmp(static_cast<std::size_t>(ext_param.class_count_),
222 NumericTraits<T2>::one());
227 detail::fill_external_parameters(options, ext_param);
230 strata_ = intLabels_;
268 template<
class LabelType,
class T1,
class C1,
class T2,
class C2>
269 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
292 ext_param_(ext_param)
295 ext_param.column_count_ = features.
shape(1);
296 ext_param.row_count_ = features.
shape(0);
297 ext_param.problem_type_ = REGRESSION;
298 ext_param.used_ =
true;
299 detail::fill_external_parameters(options, ext_param);
300 vigra_precondition(!detail::contains_nan(features),
"Processor(): Feature Matrix "
302 vigra_precondition(!detail::contains_nan(response),
"Processor(): Response "
304 vigra_precondition(!detail::contains_inf(features),
"Processor(): Feature Matrix "
306 vigra_precondition(!detail::contains_inf(response),
"Processor(): Response "
309 ext_param.response_size_ = response.
shape(1);
310 ext_param.class_count_ = response_.shape(1);
311 std::vector<T2> tmp_(ext_param.class_count_, 0);
312 ext_param.
classes_(tmp_.begin(), tmp_.end());
337 #endif //VIGRA_RF_PREPROCESSING_HXX
ArrayVectorView< double > strata_prob()
Definition: rf_preprocessing.hxx:257
MultiArrayView< 2, LabelInt > response()
Definition: rf_preprocessing.hxx:243
Definition: rf_preprocessing.hxx:63
const difference_type & shape() const
Definition: multi_array.hxx:1648
Definition: array_vector.hxx:76
const_iterator begin() const
Definition: array_vector.hxx:223
problem specification class for the random forest.
Definition: rf_common.hxx:538
MultiArrayView< 2, T1, C1 > & features()
Definition: rf_preprocessing.hxx:317
Main MultiArray class containing the memory management.
Definition: multi_array.hxx:2474
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
MultiArrayView< 2, T1, C1 > const & features()
Definition: rf_preprocessing.hxx:236
ProblemSpec & classes_(C_Iter begin, C_Iter end)
supply with class labels -
Definition: rf_common.hxx:828
Definition: array_vector.hxx:58
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
MultiArrayView< 2, T2, C2 > & response()
Definition: rf_preprocessing.hxx:324
TinyVector< MultiArrayIndex, N > type
Definition: multi_shape.hxx:272
ProblemSpec & class_weights(W_Iter begin, W_Iter end)
supply with class weights -
Definition: rf_common.hxx:844
linalg::TemporaryMatrix< T > log(MultiArrayView< 2, T, C > const &v)
MultiArray< 2, int > & strata()
Definition: rf_preprocessing.hxx:331
ArrayVectorView< LabelInt > strata()
Definition: rf_preprocessing.hxx:250
FFTWComplex< R >::NormType abs(const FFTWComplex< R > &a)
absolute value (= magnitude)
Definition: fftw3.hxx:1002
Options object for the random forest.
Definition: rf_common.hxx:170
const_iterator end() const
Definition: array_vector.hxx:237
size_type size() const
Definition: array_vector.hxx:358
int ceil(FixedPoint< IntBits, FracBits > v)
rounding up.
Definition: fixedpoint.hxx:675
int floor(FixedPoint< IntBits, FracBits > v)
rounding down.
Definition: fixedpoint.hxx:667
SquareRootTraits< FixedPoint< IntBits, FracBits > >::SquareRootResult sqrt(FixedPoint< IntBits, FracBits > v)
square root.
Definition: fixedpoint.hxx:616