[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest_common.hxx
1/************************************************************************/
2/* */
3/* Copyright 2014-2015 by Ullrich Koethe and Philip Schill */
4/* */
5/* This file is part of the VIGRA computer vision library. */
6/* The VIGRA Website is */
7/* http://hci.iwr.uni-heidelberg.de/vigra/ */
8/* Please direct questions, bug reports, and contributions to */
9/* ullrich.koethe@iwr.uni-heidelberg.de or */
10/* vigra@informatik.uni-hamburg.de */
11/* */
12/* Permission is hereby granted, free of charge, to any person */
13/* obtaining a copy of this software and associated documentation */
14/* files (the "Software"), to deal in the Software without */
15/* restriction, including without limitation the rights to use, */
16/* copy, modify, merge, publish, distribute, sublicense, and/or */
17/* sell copies of the Software, and to permit persons to whom the */
18/* Software is furnished to do so, subject to the following */
19/* conditions: */
20/* */
21/* The above copyright notice and this permission notice shall be */
22/* included in all copies or substantial portions of the */
23/* Software. */
24/* */
25/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27/* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28/* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29/* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30/* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31/* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32/* OTHER DEALINGS IN THE SOFTWARE. */
33/* */
34/************************************************************************/
35#ifndef VIGRA_RF3_COMMON_HXX
36#define VIGRA_RF3_COMMON_HXX
37
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <numeric>
#include <type_traits>
#include <utility>
#include <vector>

#include "../multi_array.hxx"
#include "../mathutil.hxx"
45
46namespace vigra
47{
48
49namespace rf3
50{
51
52/** \addtogroup MachineLearning
53**/
54//@{
55
/// \brief Binary split test of the form <tt>features(dim) <= val</tt>.
///
/// The call operator reads exactly one feature, <tt>features(dim_)</tt>, and maps the
/// comparison against the stored threshold to a child index (0 or 1).
template <typename T>
struct LessEqualSplitTest
{
public:
    /// \param dim index of the feature that is compared
    /// \param val threshold the feature is compared against
    LessEqualSplitTest(size_t dim = 0, T const & val = 0)
    :
        dim_(dim),
        val_(val)
    {}

    /// \param features callable such that <tt>features(dim_)</tt> yields the feature value
    /// \return 0 if <tt>features(dim_) <= val_</tt>, 1 otherwise
    template<typename FEATURES>
    size_t operator()(FEATURES const & features) const
    {
        return features(dim_) <= val_ ? 0 : 1;
    }

    size_t dim_; // index of the tested feature
    T val_;      // split threshold
};
75
76
77
/// \brief Accumulator that turns a sequence of label indices into relative label frequencies.
struct ArgMaxAcc
{
public:
    typedef size_t input_type;

    /// Count how often each label occurs in [begin, end) and write, for the labels
    /// 0, 1, ..., L (L being the largest label in the range), the fraction of inputs
    /// carrying that label to \a out.
    ///
    /// An empty input range writes nothing: there is no label to report and the
    /// relative frequency would be 0/0.
    ///
    /// \param begin, end range of label indices (convertible to size_t)
    /// \param out        output iterator with room for L+1 values
    template <typename ITER, typename OUTITER>
    void operator()(ITER begin, ITER end, OUTITER out)
    {
        // The histogram is a member so that its storage is reused between calls;
        // it only ever grows, hence the reset instead of a clear().
        std::fill(buffer_.begin(), buffer_.end(), 0);
        size_t max_v = 0;
        size_t n = 0;
        for (ITER it = begin; it != end; ++it)
        {
            size_t const v = *it;
            if (v >= buffer_.size())
            {
                buffer_.resize(v+1, 0);
            }
            ++buffer_[v];
            ++n;
            max_v = std::max(max_v, v);
        }

        // Without this guard an empty range would read buffer_[0] (possibly out of
        // bounds on a fresh accumulator) and divide by zero.
        if (n == 0)
            return;

        double const total = static_cast<double>(n);
        for (size_t i = 0; i <= max_v; ++i)
        {
            *out = buffer_[i] / total;
            ++out;
        }
    }
private:
    std::vector<size_t> buffer_; // label histogram, reused between calls
};
109
110
111
/// \brief Accumulator that sums normalized per-class vectors.
///
/// Every input is a vector of per-class values. Each vector is divided by the sum of
/// its own entries and the normalized vectors are added up element-wise.
template <typename VALUETYPE>
struct ArgMaxVectorAcc
{
public:
    typedef VALUETYPE value_type;
    typedef std::vector<value_type> input_type;

    /// Write the element-wise sum of the normalized input vectors to \a out.
    /// The number of values written equals the length of the longest input vector.
    ///
    /// Vectors whose entries sum to zero (including empty vectors) contribute nothing;
    /// normalizing them would divide by zero. An empty input range writes nothing.
    ///
    /// \param begin, end range of input_type vectors
    /// \param out        output iterator with room for the longest vector's length
    template <typename ITER, typename OUTITER>
    void operator()(ITER begin, ITER end, OUTITER out)
    {
        // The sum buffer is a member so that its storage is reused between calls.
        std::fill(buffer_.begin(), buffer_.end(), 0.0);

        // Track the output length directly. Tracking "largest index" as size()-1
        // wraps around for an empty vector.
        size_t out_len = 0;
        for (ITER it = begin; it != end; ++it)
        {
            input_type const & vec = *it;
            if (vec.size() > buffer_.size())
            {
                buffer_.resize(vec.size(), 0.0);
            }
            out_len = std::max(out_len, vec.size());

            value_type const n = std::accumulate(vec.begin(), vec.end(), static_cast<value_type>(0));
            if (n == static_cast<value_type>(0))
                continue; // nothing to normalize; avoid x/0

            double const norm = static_cast<double>(n);
            for (size_t i = 0; i < vec.size(); ++i)
            {
                buffer_[i] += vec[i] / norm;
            }
        }
        for (size_t i = 0; i < out_len; ++i)
        {
            *out = buffer_[i];
            ++out;
        }
    }
private:
    std::vector<double> buffer_; // element-wise sum, reused between calls
};
146
147
148
149// struct LargestSumAcc
150// {
151// public:
152// typedef std::vector<size_t> input_type;
153// template <typename ITER>
154// size_t operator()(ITER begin, ITER end)
155// {
156// std::fill(buffer_.begin(), buffer_.end(), 0);
157// for (ITER it = begin; it != end; ++it)
158// {
159// auto const & v = *it;
160// if (v.size() > buffer_.size())
161// {
162// buffer_.resize(v.size(), 0);
163// }
164// for (size_t i = 0; i < v.size(); ++i)
165// {
166// buffer_[i] += v[i];
167// }
168// }
169// size_t max_label = 0;
170// size_t max_count = 0;
171// for (size_t i = 0; i < buffer_.size(); ++i)
172// {
173// if (buffer_[i] > max_count)
174// {
175// max_count = buffer_[i];
176// max_label = i;
177// }
178// }
179// return max_label;
180// }
181// private:
182// std::vector<size_t> buffer_;
183// };
184
185
186
187// struct ForestGarroteAcc
188// {
189// public:
190// typedef double input_type;
191// template <typename ITER, typename OUTITER>
192// void operator()(ITER begin, ITER end, OUTITER out)
193// {
194// double s = 0.0;
195// for (ITER it = begin; it != end; ++it)
196// {
197// s += *it;
198// }
199// if (s < 0.0)
200// s = 0.0;
201// else if (s > 1.0)
202// s = 1.0;
203// *out = 1.0-s;
204// ++out;
205// *out = s;
206// }
207// };
208
209
210
namespace detail
{

    /// Abstract scorer that iterates over all split candidates, uses FUNCTOR to compute a score,
    /// and saves the split with the minimum score.
    ///
    /// FUNCTOR must be default constructible and callable as
    /// <tt>double score(priors, counts, n_total, n_left)</tt>, where \c priors and \c counts
    /// are <tt>std::vector<double></tt> with one entry per class.
    ///
    /// The best split is kept across calls, so one scorer instance can be fed several
    /// feature dimensions in turn.
    template <typename FUNCTOR>
    class GeneralScorer
    {
    public:

        typedef FUNCTOR Functor;

        /// \param priors weighted number of datapoints per class in the node to be split
        GeneralScorer(std::vector<double> const & priors)
            :
            split_found_(false),
            best_split_(0),
            best_dim_(0),
            best_score_(std::numeric_limits<double>::max()),
            priors_(priors),
            n_total_(std::accumulate(priors.begin(), priors.end(), 0.0))
        {}

        /// Evaluate every split candidate along feature dimension \a dim.
        ///
        /// A candidate lies between each pair of consecutive instances in [begin, end)
        /// whose feature values differ; its threshold is the midpoint of the two values.
        /// NOTE(review): this presumably expects [begin, end) to be ordered by
        /// <tt>features(., dim)</tt> - confirm against the caller.
        ///
        /// \param features callable as <tt>features(instance, dim)</tt>
        /// \param labels   callable as <tt>labels(instance)</tt>, convertible to a class index
        /// \param weights  indexable as <tt>weights[instance]</tt>
        /// \param begin, end range of instance indices
        /// \param dim      feature dimension under consideration
        template <typename FEATURES, typename LABELS, typename WEIGHTS, typename ITER>
        void operator()(
            FEATURES const & features,
            LABELS const & labels,
            WEIGHTS const & weights,
            ITER begin,
            ITER end,
            size_t dim
        ){
            if (begin == end)
                return;

            Functor score;

            // Weighted class histogram and total weight of everything left of the candidate.
            std::vector<double> counts(priors_.size(), 0.0);
            double n_left = 0;
            ITER next = begin;
            ++next;
            for (; next != end; ++begin, ++next)
            {
                // Move the label from the right side to the left side.
                size_t const left_index = *begin;
                size_t const right_index = *next;
                size_t const label = static_cast<size_t>(labels(left_index));
                counts[label] += weights[left_index];
                n_left += weights[left_index];

                // Skip if there is no new split.
                auto const left = features(left_index, dim);
                auto const right = features(right_index, dim);
                if (left == right)
                    continue;

                // Update the score.
                split_found_ = true;
                double const s = score(priors_, counts, n_total_, n_left);
                if (s < best_score_)
                {
                    best_score_ = s;
                    best_split_ = 0.5*(left+right);
                    best_dim_ = dim;
                }
            }
        }

        bool split_found_; // whether a split was found at all
        double best_split_; // the threshold of the best split
        size_t best_dim_; // the dimension of the best split
        double best_score_; // the score of the best split

    private:

        std::vector<double> const priors_; // the weighted number of datapoints per class
        double const n_total_; // the weighted number of datapoints
    };

} // namespace detail
291
/// \brief Functor that computes the gini score.
///
/// This functor is typically selected indirectly by passing the value <tt>RF_GINI</tt>
/// to vigra::rf3::RandomForestOptions::split().
class GiniScore
{
public:
    /// Weighted gini impurity of a split candidate (smaller is better).
    ///
    /// \param priors  weighted number of datapoints per class in the whole node
    /// \param counts  weighted number of datapoints per class on the left side
    /// \param n_total sum of \a priors
    /// \param n_left  sum of \a counts
    /// \return <tt>n_left * gini(left) + n_right * gini(right)</tt>
    double operator()(std::vector<double> const & priors,
                      std::vector<double> const & counts, double n_total, double n_left) const
    {
        double const n_right = n_total - n_left;
        double gini_left = 1.0;
        double gini_right = 1.0;
        for (size_t i = 0; i < counts.size(); ++i)
        {
            double const p_left = counts[i] / n_left;
            double const p_right = (priors[i] - counts[i]) / n_right;
            gini_left -= (p_left*p_left);
            gini_right -= (p_right*p_right);
        }
        return n_left*gini_left + n_right*gini_right;
    }

    /// Weighted gini impurity of a region: <tt>total - sum_k(count_k^2) / total</tt>.
    /// Needed for Gini-based variable importance calculation.
    ///
    /// \param labels   indexable as <tt>labels[instance]</tt>, convertible to a class index
    /// \param weights  indexable as <tt>weights[instance]</tt>
    /// \param begin, end range of instance indices forming the region
    /// \return the impurity, or 0 for a region without any weight
    template <typename LABELS, typename WEIGHTS, typename ITER>
    static double region_score(LABELS const & labels, WEIGHTS const & weights, ITER begin, ITER end)
    {
        // Count the occurrences.
        std::vector<double> counts;
        double total = 0.0;
        for (auto it = begin; it != end; ++it)
        {
            auto const d = *it;
            size_t const lbl = static_cast<size_t>(labels[d]);
            if (counts.size() <= lbl)
            {
                counts.resize(lbl+1, 0.0);
            }
            auto const w = weights[d];
            counts[lbl] += w;
            total += w;
        }

        // A region without weight has no impurity; dividing by total would give 0/0.
        if (total <= 0.0)
            return 0.0;

        // Compute the gini.
        double gini = total;
        for (auto x : counts)
        {
            gini -= x*x/total;
        }
        return gini;
    }
};
343
344/// \brief Functor that computes the entropy score.
345///
346/// This functor is typically selected indirectly by passing the value <tt>RF_ENTROPY</tt>
347/// to vigra::rf3::RandomForestOptions::split().
349{
350public:
351 double operator()(std::vector<double> const & priors, std::vector<double> const & counts, double n_total, double n_left) const
352 {
353 double const n_right = n_total - n_left;
354 double ig = 0;
355 for (size_t i = 0; i < counts.size(); ++i)
356 {
357 double c = counts[i];
358 if (c != 0)
359 ig -= c * std::log(c / n_left);
360
361 c = priors[i] - c;
362 if (c != 0)
363 ig -= c * std::log(c / n_right);
364 }
365 return ig;
366 }
367
368 template <typename LABELS, typename WEIGHTS, typename ITER>
369 double region_score(LABELS const & /*labels*/, WEIGHTS const & /*weights*/, ITER /*begin*/, ITER /*end*/) const
370 {
371 vigra_fail("EntropyScore::region_score(): Not implemented yet.");
372 return 0.0; // FIXME
373 }
374};
375
376/// \brief Functor that computes the Kolmogorov-Smirnov score.
377///
378/// Actually, it reutrns the negated KSD score, because we want to minimize.
379///
380/// This functor is typically selected indirectly by passing the value <tt>RF_KSD</tt>
381/// to vigra::rf3::RandomForestOptions::split().
383{
384public:
385 double operator()(std::vector<double> const & priors, std::vector<double> const & counts, double /*n_total*/, double /*n_left*/) const // Fix unused parameter warning, but leave in to not break compatibility with overall API
386 {
387 double const eps = 1e-10;
388 double nnz = 0;
389 std::vector<double> norm_counts(counts.size(), 0.0);
390 for (size_t i = 0; i < counts.size(); ++i)
391 {
392 if (priors[i] > eps)
393 {
394 norm_counts[i] = counts[i] / priors[i];
395 ++nnz;
396 }
397 }
398 if (nnz < eps)
399 return 0.0;
400
401 // NOTE to future self:
402 // In std::accumulate, it makes a huge difference whether you use 0 or 0.0 as init. Think about that before making changes.
403 double const mean = std::accumulate(norm_counts.begin(), norm_counts.end(), 0.0) / nnz;
404
405 // Compute the sum of the squared distances.
406 double ksd = 0.0;
407 for (size_t i = 0; i < norm_counts.size(); ++i)
408 {
409 if (priors[i] != 0)
410 {
411 double const v = (mean-norm_counts[i]);
412 ksd += v*v;
413 }
414 }
415 return -ksd;
416 }
417
418 template <typename LABELS, typename WEIGHTS, typename ITER>
419 double region_score(LABELS const & /*labels*/, WEIGHTS const & /*weights*/, ITER /*begin*/, ITER /*end*/) const
420 {
421 vigra_fail("KolmogorovSmirnovScore::region_score(): Region score not available for the Kolmogorov-Smirnov split.");
422 return 0.0;
423 }
424};
425
// This struct holds the depth and the weighted number of datapoints per class of a single node.
//
// The priors are held by reference, not copied: the array handed to the constructor
// must stay alive for as long as the description is used.
template <typename ARR>
struct RFNodeDescription
{
public:
    // depth:  depth of the node
    // priors: per-class weighted datapoint counts of the node (referenced, not owned)
    RFNodeDescription(size_t depth, ARR const & priors)
    :
        depth_(depth),
        priors_(priors)
    {}

    size_t depth_;
    ARR const & priors_;
};
439
440
441
442// Return true if the given node is pure.
443template <typename LABELS, typename ITER>
444bool is_pure(LABELS const & /*labels*/, RFNodeDescription<ITER> const & desc)
445{
446 bool found = false;
447 for (auto n : desc.priors_)
448 {
449 if (n > 0)
450 {
451 if (found)
452 return false;
453 else
454 found = true;
455 }
456 }
457 return true;
458}
459
460/// @brief Random forest 'node purity' stop criterion.
461///
462/// Stop splitting a node when it contains only instanes of a single class.
464{
465public:
466 template <typename LABELS, typename ITER>
467 bool operator()(LABELS const & labels, RFNodeDescription<ITER> const & desc) const
468 {
469 return is_pure(labels, desc);
470 }
471};
472
473/// @brief Random forest 'maximum depth' stop criterion.
474///
475/// Stop splitting a node when the its depth reaches a given value or when it is pure.
477{
478public:
479 /// @brief Constructor: terminate tree construction at \a max_depth.
480 DepthStop(size_t max_depth)
481 :
482 max_depth_(max_depth)
483 {}
484
485 template <typename LABELS, typename ITER>
486 bool operator()(LABELS const & labels, RFNodeDescription<ITER> const & desc) const
487 {
488 if (desc.depth_ >= max_depth_)
489 return true;
490 else
491 return is_pure(labels, desc);
492 }
493 size_t max_depth_;
494};
495
496/// @brief Random forest 'number of datapoints' stop criterion.
497///
498/// Stop splitting a node when it contains too few instances or when it is pure.
500{
501public:
502 /// @brief Constructor: terminate tree construction when node contains less than \a min_n instances.
503 NumInstancesStop(size_t min_n)
504 :
505 min_n_(min_n)
506 {}
507
508 template <typename LABELS, typename ARR>
509 bool operator()(LABELS const & labels, RFNodeDescription<ARR> const & desc) const
510 {
511 typedef typename ARR::value_type value_type;
512 if (std::accumulate(desc.priors_.begin(), desc.priors_.end(), static_cast<value_type>(0)) <= min_n_)
513 return true;
514 else
515 return is_pure(labels, desc);
516 }
517 size_t min_n_;
518};
519
520/// @brief Random forest 'node complexity' stop criterion.
521///
522/// Stop splitting a node when it allows for too few different data arrangements.
523/// This includes purity, which offers only a sinlge data arrangement.
525{
526public:
527 /// @brief Constructor: stop when fewer than <tt>1/tau</tt> label arrangements are possible.
528 NodeComplexityStop(double tau = 0.001)
529 :
530 logtau_(std::log(tau))
531 {
532 vigra_precondition(tau > 0 && tau < 1, "NodeComplexityStop(): Tau must be in the open interval (0, 1).");
533 }
534
535 template <typename LABELS, typename ARR>
536 bool operator()(LABELS const & /*labels*/, RFNodeDescription<ARR> const & desc) // Fix unused parameter, but leave in for API compatability
537 {
538 typedef typename ARR::value_type value_type;
539
540 // Count the labels.
541 size_t const total = std::accumulate(desc.priors_.begin(), desc.priors_.end(), static_cast<value_type>(0));
542
543 // Compute log(prod_k(n_k!)).
544 size_t nnz = 0;
545 double lg = 0.0;
546 for (auto v : desc.priors_)
547 {
548 if (v > 0)
549 {
550 ++nnz;
551 lg += loggamma(static_cast<double>(v+1));
552 }
553 }
554 lg += loggamma(static_cast<double>(nnz+1));
555 lg -= loggamma(static_cast<double>(total+1));
556 if (nnz <= 1)
557 return true;
558
559 return lg >= logtau_;
560 }
561
562 double logtau_;
563};
564
/// Tags accepted by vigra::rf3::RandomForestOptions.
///
/// The first four select how many features are considered per node, the last three
/// select the split criterion.
enum RandomForestOptionTags
{
    RF_SQRT,    ///< features per node: square root of the total number of features
    RF_LOG,     ///< features per node: logarithm of the total number of features
    RF_CONST,   ///< features per node: a fixed, user-given number
    RF_ALL,     ///< features per node: all features
    RF_GINI,    ///< split criterion: Gini
    RF_ENTROPY, ///< split criterion: entropy
    RF_KSD      ///< split criterion: Kolmogorov-Smirnov
};
575
576
577/** \brief Options class for \ref vigra::rf3::RandomForest version 3.
578
579 <b>\#include</b> <vigra/random_forest_3.hxx><br/>
580 Namespace: vigra::rf3
581*/
582class RandomForestOptions
583{
584public:
585
586 RandomForestOptions()
587 :
588 tree_count_(255),
589 features_per_node_(0),
590 features_per_node_switch_(RF_SQRT),
591 bootstrap_sampling_(true),
592 resample_count_(0),
593 split_(RF_GINI),
594 max_depth_(0),
595 node_complexity_tau_(-1),
596 min_num_instances_(1),
597 use_stratification_(false),
598 n_threads_(-1),
599 class_weights_()
600 {}
601
602 /**
603 * @brief The number of trees.
604 *
605 * Default: 255
606 */
607 RandomForestOptions & tree_count(int p_tree_count)
608 {
609 tree_count_ = p_tree_count;
610 return *this;
611 }
612
613 /**
614 * @brief The number of features that are considered when computing the split.
615 *
616 * @param p_features_per_node the number of features
617 *
618 * Default: use sqrt of the total number of features.
619 */
620 RandomForestOptions & features_per_node(int p_features_per_node)
621 {
622 features_per_node_switch_ = RF_CONST;
623 features_per_node_ = p_features_per_node;
624 return *this;
625 }
626
627 /**
628 * @brief The number of features that are considered when computing the split.
629 *
630 * @param p_features_per_node_switch possible values: <br/>
631 <tt>vigra::rf3::RF_SQRT</tt> (use square root of total number of features, recommended for classification), <br/>
632 <tt>vigra::rf3::RF_LOG</tt> (use logarithm of total number of features, recommended for regression), <br/>
633 <tt>vigra::rf3::RF_ALL</tt> (use all features).
634 *
635 * Default: <tt>vigra::rf3::RF_SQRT</tt>
636 */
637 RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
638 {
639 vigra_precondition(p_features_per_node_switch == RF_SQRT ||
640 p_features_per_node_switch == RF_LOG ||
641 p_features_per_node_switch == RF_ALL,
642 "RandomForestOptions::features_per_node(): Input must be RF_SQRT, RF_LOG or RF_ALL.");
643 features_per_node_switch_ = p_features_per_node_switch;
644 return *this;
645 }
646
647 /**
648 * @brief Use bootstrap sampling.
649 *
650 * Default: true
651 */
652 RandomForestOptions & bootstrap_sampling(bool b)
653 {
654 bootstrap_sampling_ = b;
655 return *this;
656 }
657
658 /**
659 * @brief If resample_count is greater than zero, the split in each node is computed using only resample_count data points.
660 *
661 * Default: \a n = 0 (don't resample in every node)
662 */
663 RandomForestOptions & resample_count(size_t n)
664 {
665 resample_count_ = n;
666 bootstrap_sampling_ = false;
667 return *this;
668 }
669
670 /**
671 * @brief The split criterion.
672 *
673 * @param p_split possible values: <br/>
674 <tt>vigra::rf3::RF_GINI</tt> (use Gini criterion, \ref vigra::rf3::GiniScorer), <br/>
675 <tt>vigra::rf3::RF_ENTROPY</tt> (use entropy criterion, \ref vigra::rf3::EntropyScorer), <br/>
676 <tt>vigra::rf3::RF_KSD</tt> (use Kolmogorov-Smirnov criterion, \ref vigra::rf3::KSDScorer).
677 *
678 * Default: <tt>vigra::rf3::RF_GINI</tt>
679 */
680 RandomForestOptions & split(RandomForestOptionTags p_split)
681 {
682 vigra_precondition(p_split == RF_GINI ||
683 p_split == RF_ENTROPY ||
684 p_split == RF_KSD,
685 "RandomForestOptions::split(): Input must be RF_GINI, RF_ENTROPY or RF_KSD.");
686 split_ = p_split;
687 return *this;
688 }
689
690 /**
691 * @brief Do not split a node if its depth is greater or equal to max_depth.
692 *
693 * Default: \a d = 0 (don't use depth as a termination criterion)
694 */
695 RandomForestOptions & max_depth(size_t d)
696 {
697 max_depth_ = d;
698 return *this;
699 }
700
701 /**
702 * @brief Value of the node complexity termination criterion.
703 *
704 * Default: \a tau = -1 (don't use complexity as a termination criterion)
705 */
706 RandomForestOptions & node_complexity_tau(double tau)
707 {
708 node_complexity_tau_ = tau;
709 return *this;
710 }
711
712 /**
713 * @brief Do not split a node if it contains less than min_num_instances data points.
714 *
715 * Default: \a n = 1 (don't use instance count as a termination criterion)
716 */
717 RandomForestOptions & min_num_instances(size_t n)
718 {
719 min_num_instances_ = n;
720 return *this;
721 }
722
723 /**
724 * @brief Use stratification when creating the bootstrap samples.
725 *
726 * That is, preserve the proportion between the number of class instances exactly
727 * rather than on average.
728 *
729 * Default: false
730 */
731 RandomForestOptions & use_stratification(bool b)
732 {
733 use_stratification_ = b;
734 return *this;
735 }
736
737 /**
738 * @brief The number of threads that are used in training.
739 *
740 * \a n = -1 means use number of cores, \a n = 0 means single-threaded training.
741 *
742 * Default: \a n = -1 (use as many threads as there are cores in the machine).
743 */
744 RandomForestOptions & n_threads(int n)
745 {
746 n_threads_ = n;
747 return *this;
748 }
749
750 /**
751 * @brief Each datapoint is weighted by its class weight. By default, each class has weight 1.
752 * @details
753 * The classes in the random forest training have to follow a strict ordering. The weights must be given in that order.
754 * Example:
755 * You have the classes 3, 8 and 5 and use the vector {0.2, 0.3, 0.4} for the class weights.
756 * The ordering of the classes is 3, 5, 8, so class 3 will get weight 0.2, class 5 will get weight 0.3
757 * and class 8 will get weight 0.4.
758 */
759 RandomForestOptions & class_weights(std::vector<double> const & v)
760 {
761 class_weights_ = v;
762 return *this;
763 }
764
765 /**
766 * @brief Get the actual number of features per node.
767 *
768 * @param total the total number of features
769 *
770 * This function is normally only called internally before training is started.
771 */
772 size_t get_features_per_node(size_t total) const
773 {
774 if (features_per_node_switch_ == RF_SQRT)
775 return std::ceil(std::sqrt(total));
776 else if (features_per_node_switch_ == RF_LOG)
777 return std::ceil(std::log(total));
778 else if (features_per_node_switch_ == RF_CONST)
779 return features_per_node_;
780 else if (features_per_node_switch_ == RF_ALL)
781 return total;
782 vigra_fail("RandomForestOptions::get_features_per_node(): Unknown switch.");
783 return 0;
784 }
785
786 int tree_count_;
787 int features_per_node_;
788 RandomForestOptionTags features_per_node_switch_;
789 bool bootstrap_sampling_;
790 size_t resample_count_;
791 RandomForestOptionTags split_;
792 size_t max_depth_;
793 double node_complexity_tau_;
794 size_t min_num_instances_;
795 bool use_stratification_;
796 int n_threads_;
797 std::vector<double> class_weights_;
798
799};
800
801
802
/// Plain record describing a learning problem: numbers of features, instances and
/// classes, the distinct class labels, and two further size parameters.
///
/// All setters return <tt>*this</tt>, so that calls can be chained.
/// NOTE(review): actual_mtry_ / actual_msample_ presumably hold the number of features
/// tried per split and the number of samples drawn per tree - confirm against the users
/// of this class.
template <typename LabelType>
class ProblemSpec
{
public:

    /// Create a specification with all counts zero and no classes.
    ProblemSpec()
        :
        num_features_(0),
        num_instances_(0),
        num_classes_(0),
        distinct_classes_(),
        actual_mtry_(0),
        actual_msample_(0)
    {}

    ProblemSpec & num_features(size_t n)
    {
        num_features_ = n;
        return *this;
    }

    ProblemSpec & num_instances(size_t n)
    {
        num_instances_ = n;
        return *this;
    }

    ProblemSpec & num_classes(size_t n)
    {
        num_classes_ = n;
        return *this;
    }

    /// Set the distinct class labels; this also sets num_classes_ to their number.
    ProblemSpec & distinct_classes(std::vector<LabelType> v)
    {
        // Read the size before the by-value argument is moved from.
        num_classes_ = v.size();
        distinct_classes_ = std::move(v);
        return *this;
    }

    ProblemSpec & actual_mtry(size_t m)
    {
        actual_mtry_ = m;
        return *this;
    }

    ProblemSpec & actual_msample(size_t m)
    {
        actual_msample_ = m;
        return *this;
    }

    /// Two specifications are equal if all of their fields are equal.
    bool operator==(ProblemSpec const & other) const
    {
        return !(num_features_ != other.num_features_)
            && !(num_instances_ != other.num_instances_)
            && !(num_classes_ != other.num_classes_)
            && !(distinct_classes_ != other.distinct_classes_)
            && !(actual_mtry_ != other.actual_mtry_)
            && !(actual_msample_ != other.actual_msample_);
    }

    size_t num_features_;
    size_t num_instances_;
    size_t num_classes_;
    std::vector<LabelType> distinct_classes_;
    size_t actual_mtry_;
    size_t actual_msample_;

};
876
877//@}
878
879} // namespace rf3
880
881} // namespace vigra
882
883#endif
884
problem specification class for the random forest.
Definition rf_common.hxx:539
DepthStop(size_t max_depth)
Constructor: terminate tree construction at max_depth.
Definition random_forest_common.hxx:480
Functor that computes the entropy score.
Definition random_forest_common.hxx:349
Functor that computes the gini score.
Definition random_forest_common.hxx:297
Functor that computes the Kolmogorov-Smirnov score.
Definition random_forest_common.hxx:383
NodeComplexityStop(double tau=0.001)
Constructor: stop when fewer than 1/tau label arrangements are possible.
Definition random_forest_common.hxx:528
NumInstancesStop(size_t min_n)
Constructor: terminate tree construction when node contains less than min_n instances.
Definition random_forest_common.hxx:503
Random forest 'node purity' stop criterion.
Definition random_forest_common.hxx:464
RandomForestOptions & features_per_node(int p_features_per_node)
The number of features that are considered when computing the split.
Definition random_forest_common.hxx:620
RandomForestOptions & split(RandomForestOptionTags p_split)
The split criterion.
Definition random_forest_common.hxx:680
RandomForestOptions & resample_count(size_t n)
If resample_count is greater than zero, the split in each node is computed using only resample_count ...
Definition random_forest_common.hxx:663
RandomForestOptions & features_per_node(RandomForestOptionTags p_features_per_node_switch)
The number of features that are considered when computing the split.
Definition random_forest_common.hxx:637
RandomForestOptions & use_stratification(bool b)
Use stratification when creating the bootstrap samples.
Definition random_forest_common.hxx:731
RandomForestOptions & tree_count(int p_tree_count)
The number of trees.
Definition random_forest_common.hxx:607
RandomForestOptions & max_depth(size_t d)
Do not split a node if its depth is greater or equal to max_depth.
Definition random_forest_common.hxx:695
RandomForestOptions & bootstrap_sampling(bool b)
Use bootstrap sampling.
Definition random_forest_common.hxx:652
RandomForestOptions & n_threads(int n)
The number of threads that are used in training.
Definition random_forest_common.hxx:744
RandomForestOptions & node_complexity_tau(double tau)
Value of the node complexity termination criterion.
Definition random_forest_common.hxx:706
RandomForestOptions & class_weights(std::vector< double > const &v)
Each datapoint is weighted by its class weight. By default, each class has weight 1.
Definition random_forest_common.hxx:759
size_t get_features_per_node(size_t total) const
Get the actual number of features per node.
Definition random_forest_common.hxx:772
RandomForestOptions & min_num_instances(size_t n)
Do not split a node if it contains less than min_num_instances data points.
Definition random_forest_common.hxx:717
Random forest version 3.
Definition random_forest_3.hxx:66
double loggamma(double x)
The natural logarithm of the gamma function.
Definition mathutil.hxx:1603

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.12.1 (Thu Feb 27 2025)