XCSF 1.4.8
XCSF learning classifier system
Loading...
Searching...
No Matches
pybind_callback_earlystop.h
Go to the documentation of this file.
1/*
2 * This program is free software: you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation, either version 3 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program. If not, see <http://www.gnu.org/licenses/>.
14 */
15
24#pragma once
25
26#include <limits>
27#include <pybind11/numpy.h>
28#include <pybind11/pybind11.h>
29
30namespace py = pybind11;
31
32extern "C" {
33#include "xcsf.h"
34}
35
36#include "pybind_callback.h"
37#include "pybind_utils.h"
38
43{
44 public:
57 double min_delta, int start_from, bool verbose) :
64 {
65 std::ostringstream err;
66 std::string str = monitor.cast<std::string>();
67 if (str != "train" && str != "val") {
68 err << "invalid metric to monitor: " << str << std::endl;
69 throw std::invalid_argument(err.str());
70 }
71 if (patience < 0) {
72 err << "patience cannot be negative" << std::endl;
73 throw std::invalid_argument(err.str());
74 }
75 if (min_delta < 0) {
76 err << "min_delta cannot be negative" << std::endl;
77 throw std::invalid_argument(err.str());
78 }
79 }
80
85 void
86 store(struct XCSF *xcsf)
87 {
88 do_restore = true;
90 if (verbose) {
91 std::ostringstream status;
92 status << get_timestamp() << " EarlyStoppingCallback: ";
93 status << std::fixed << std::setprecision(5) << best_error;
94 status << " best error at " << best_trial << " trials";
95 py::print(status.str());
96 }
97 }
98
103 void
105 {
106 do_restore = false;
108 if (verbose) {
109 std::ostringstream status;
110 status << get_timestamp() << " EarlyStoppingCallback: ";
111 status << "restoring system from trial " << best_trial;
112 status << " with error=";
113 status << std::fixed << std::setprecision(5) << best_error;
114 py::print(status.str());
115 }
116 }
117
124 bool
125 run(struct XCSF *xcsf, py::dict metrics) override
126 {
127 py::list data = metrics[monitor];
128 py::list trials = metrics["trials"];
129 const double current_error = py::cast<double>(data[data.size() - 1]);
130 const int current_trial = py::cast<int>(trials[trials.size() - 1]);
131 if (current_trial < start_from) {
132 return false;
133 }
134 if (current_error < best_error - min_delta) {
135 best_error = current_error;
136 best_trial = current_trial;
137 if (restore) {
138 store(xcsf);
139 }
140 }
141 if (current_trial - patience > best_trial) {
142 if (verbose) {
143 std::ostringstream status;
144 status << get_timestamp() << " EarlyStoppingCallback: stopping";
145 py::print(status.str());
146 }
147 if (restore) {
148 retrieve(xcsf);
149 }
150 return true;
151 }
152 return false;
153 }
154
159 void
160 finish(struct XCSF *xcsf) override
161 {
162 if (restore && do_restore) {
163 retrieve(xcsf);
164 }
165 }
166
167 private:
168 py::str monitor;
170 bool restore;
171 double min_delta;
173 bool verbose;
174
175 double best_error = std::numeric_limits<double>::max();
176 int best_trial = 0;
177 bool do_restore = false;
178};
Interface for Callbacks.
Callback to stop training when a certain metric has stopped improving.
py::str monitor
Name of the metric to monitor.
bool do_restore
Whether the population needs to be restored.
int patience
Stop training after this many trials with no improvement.
double min_delta
Minimum change to qualify as an improvement.
void retrieve(struct XCSF *xcsf)
Retrieves the best XCSF population from memory.
int start_from
Trials to wait before starting to monitor.
bool verbose
Whether to display messages when an action is taken.
void finish(struct XCSF *xcsf) override
Executes any tasks at the end of fitting.
bool restore
Whether to restore the best population.
int best_trial
Trial number at which the best error was observed.
EarlyStoppingCallback(py::str monitor, int patience, bool restore, double min_delta, int start_from, bool verbose)
Constructs a new early stopping callback.
bool run(struct XCSF *xcsf, py::dict metrics) override
Checks whether the early stopping criteria have been met.
void store(struct XCSF *xcsf)
Stores the best XCSF population in memory.
Interface for callbacks.
Utilities for Python library.
std::string get_timestamp()
Returns a formatted string for displaying time.
XCSF data structure.
Definition xcsf.h:85
void xcsf_store_pset(struct XCSF *xcsf)
Stores the current population.
Definition xcsf.c:195
void xcsf_retrieve_pset(struct XCSF *xcsf)
Retrieves the previously stored population.
Definition xcsf.c:213
XCSF data structures.