27 #include <pybind11/numpy.h>
28 #include <pybind11/pybind11.h>
30 namespace py = pybind11;
65 std::ostringstream err;
66 std::string str =
monitor.cast<std::string>();
67 if (str !=
"train" && str !=
"val") {
68 err <<
"invalid metric to monitor: " << str << std::endl;
69 throw std::invalid_argument(err.str());
72 err <<
"patience cannot be negative" << std::endl;
73 throw std::invalid_argument(err.str());
76 err <<
"min_delta cannot be negative" << std::endl;
77 throw std::invalid_argument(err.str());
91 std::ostringstream status;
93 status << std::fixed << std::setprecision(5) <<
best_error;
94 status <<
" best error at " <<
best_trial <<
" trials";
95 py::print(status.str());
109 std::ostringstream status;
111 status <<
"restoring system from trial " <<
best_trial;
112 status <<
" with error=";
113 status << std::fixed << std::setprecision(5) <<
best_error;
114 py::print(status.str());
127 py::list data = metrics[
monitor];
128 py::list trials = metrics[
"trials"];
129 const double current_error = py::cast<double>(data[data.size() - 1]);
130 const int current_trial = py::cast<int>(trials[trials.size() - 1]);
143 std::ostringstream status;
144 status <<
get_timestamp() <<
" EarlyStoppingCallback: stopping";
145 py::print(status.str());
Callback that stops training when a monitored metric has stopped improving.
py::str monitor
Name of the metric to monitor.
bool do_restore
Whether the population needs to be restored.
int patience
Stop training after this many trials with no improvement.
double min_delta
Minimum change to qualify as an improvement.
void retrieve(struct XCSF *xcsf)
Retrieves the best XCSF population previously stored in memory.
int start_from
Number of trials to wait before starting to monitor.
bool verbose
Whether to display messages when an action is taken.
void finish(struct XCSF *xcsf) override
Executes any tasks at the end of fitting.
bool restore
Whether to restore the best population.
int best_trial
Trial number at which the best error was observed.
EarlyStoppingCallback(py::str monitor, int patience, bool restore, double min_delta, int start_from, bool verbose)
Constructs a new early stopping callback.
bool run(struct XCSF *xcsf, py::dict metrics) override
Checks whether the early stopping criteria have been met.
double best_error
Best error observed so far.
void store(struct XCSF *xcsf)
Stores the best XCSF population in memory.
Utilities for the Python library.
std::string get_timestamp()
Returns a formatted string for displaying time.
void xcsf_store_pset(struct XCSF *xcsf)
Stores the current population.
void xcsf_retrieve_pset(struct XCSF *xcsf)
Retrieves the previously stored population.