XCSF  1.4.7
XCSF learning classifier system
pybind_callback_earlystop.h
Go to the documentation of this file.
1 /*
2  * This program is free software: you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation, either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * This program is distributed in the hope that it will be useful,
8  * but WITHOUT ANY WARRANTY; without even the implied warranty of
9  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10  * GNU General Public License for more details.
11  *
12  * You should have received a copy of the GNU General Public License
13  * along with this program. If not, see <http://www.gnu.org/licenses/>.
14  */
15 
24 #pragma once
25 
26 #include <limits>
27 #include <pybind11/numpy.h>
28 #include <pybind11/pybind11.h>
29 
30 namespace py = pybind11;
31 
32 extern "C" {
33 #include "xcsf.h"
34 }
35 
36 #include "pybind_callback.h"
37 #include "pybind_utils.h"
38 
43 {
44  public:
57  double min_delta, int start_from, bool verbose) :
64  {
 // Validates the constructor arguments; each failed check throws
 // std::invalid_argument whose message names the offending argument.
 // NOTE(review): the first line of this constructor's signature and its
 // member-initializer list are not shown in this listing.
65  std::ostringstream err;
 // Only the "train" and "val" metrics may be monitored.
66  std::string str = monitor.cast<std::string>();
67  if (str != "train" && str != "val") {
68  err << "invalid metric to monitor: " << str << std::endl;
69  throw std::invalid_argument(err.str());
70  }
 // patience and min_delta are both required to be non-negative.
71  if (patience < 0) {
72  err << "patience cannot be negative" << std::endl;
73  throw std::invalid_argument(err.str());
74  }
75  if (min_delta < 0) {
76  err << "min_delta cannot be negative" << std::endl;
77  throw std::invalid_argument(err.str());
78  }
79  }
80 
/**
 * @brief Stores the best XCSF population in memory.
 * @details Marks that a stored population is pending restoration
 * (do_restore = true) and, when verbose, prints the best error so far and
 * the trial at which it was observed.
 * NOTE(review): the statement that actually saves the population
 * (presumably xcsf_store_pset(xcsf), which appears in this page's
 * cross-references) is not shown in this listing - confirm.
 * @param [in] xcsf The XCSF data structure.
 */
85  void
86  store(struct XCSF *xcsf)
87  {
88  do_restore = true;
90  if (verbose) {
91  std::ostringstream status;
92  status << get_timestamp() << " EarlyStoppingCallback: ";
93  status << std::fixed << std::setprecision(5) << best_error;
94  status << " best error at " << best_trial << " trials";
95  py::print(status.str());
96  }
97  }
98 
/**
 * @brief Retrieves the best XCSF population from memory.
 * @details Clears the pending-restoration flag (do_restore = false) and,
 * when verbose, reports the trial and error of the population restored.
 * NOTE(review): the statement that actually restores the population
 * (presumably xcsf_retrieve_pset(xcsf), which appears in this page's
 * cross-references) is not shown in this listing - confirm.
 * @param [in] xcsf The XCSF data structure.
 */
103  void
104  retrieve(struct XCSF *xcsf)
105  {
106  do_restore = false;
108  if (verbose) {
109  std::ostringstream status;
110  status << get_timestamp() << " EarlyStoppingCallback: ";
111  status << "restoring system from trial " << best_trial;
112  status << " with error=";
113  status << std::fixed << std::setprecision(5) << best_error;
114  py::print(status.str());
115  }
116  }
117 
124  bool
125  run(struct XCSF *xcsf, py::dict metrics) override
126  {
127  py::list data = metrics[monitor];
128  py::list trials = metrics["trials"];
129  const double current_error = py::cast<double>(data[data.size() - 1]);
130  const int current_trial = py::cast<int>(trials[trials.size() - 1]);
131  if (current_trial < start_from) {
132  return false;
133  }
134  if (current_error < best_error - min_delta) {
135  best_error = current_error;
136  best_trial = current_trial;
137  if (restore) {
138  store(xcsf);
139  }
140  }
141  if (current_trial - patience > best_trial) {
142  if (verbose) {
143  std::ostringstream status;
144  status << get_timestamp() << " EarlyStoppingCallback: stopping";
145  py::print(status.str());
146  }
147  if (restore) {
148  retrieve(xcsf);
149  }
150  return true;
151  }
152  return false;
153  }
154 
159  void
160  finish(struct XCSF *xcsf) override
161  {
162  if (restore && do_restore) {
163  retrieve(xcsf);
164  }
165  }
166 
167  private:
168  py::str monitor; // Name of the metric to monitor ("train" or "val").
169  int patience; // Stop training after this many trials with no improvement.
170  bool restore; // Whether to restore the best population.
171  double min_delta; // Minimum decrease in error that counts as an improvement.
 // NOTE(review): the `int start_from;` declaration (used by run() and listed
 // in the cross-references as "Trials to wait before starting to monitor")
 // does not appear in this listing - confirm it is present in the source.
173  bool verbose; // Whether to display messages when an action is taken.
174 
175  double best_error = std::numeric_limits<double>::max(); // Lowest error observed so far.
176  int best_trial = 0; // Trial number at which best_error was observed.
177  bool do_restore = false; // Whether a stored population awaits restoration.
178 };
Interface for Callbacks.
Callback to stop training when a certain metric has stopped improving.
py::str monitor
Name of the metric to monitor.
bool do_restore
Whether the population needs to be restored.
int patience
Stop training after this many trials with no improvement.
double min_delta
Minimum change to qualify as an improvement.
void retrieve(struct XCSF *xcsf)
Retrieves best XCSF population in memory.
int start_from
Trials to wait before starting to monitor.
bool verbose
Whether to display messages when an action is taken.
void finish(struct XCSF *xcsf) override
Executes any tasks at the end of fitting.
bool restore
Whether to restore the best population.
int best_trial
Trial number the best error was observed.
EarlyStoppingCallback(py::str monitor, int patience, bool restore, double min_delta, int start_from, bool verbose)
Constructs a new early stopping callback.
bool run(struct XCSF *xcsf, py::dict metrics) override
Checks whether the early stopping criteria have been met.
void store(struct XCSF *xcsf)
Stores best XCSF population in memory.
Definition: __init__.py:1
Interface for callbacks.
Utilities for Python library.
std::string get_timestamp()
Returns a formatted string for displaying time.
Definition: pybind_utils.h:31
XCSF data structure.
Definition: xcsf.h:85
void xcsf_store_pset(struct XCSF *xcsf)
Stores the current population.
Definition: xcsf.c:195
void xcsf_retrieve_pset(struct XCSF *xcsf)
Retrieves the previously stored population.
Definition: xcsf.c:213
XCSF data structures.