XCSF 1.4.8
XCSF learning classifier system
Loading...
Searching...
No Matches
pybind_callback_checkpoint.h
Go to the documentation of this file.
1/*
2 * This program is free software: you can redistribute it and/or modify
3 * it under the terms of the GNU General Public License as published by
4 * the Free Software Foundation, either version 3 of the License, or
5 * (at your option) any later version.
6 *
7 * This program is distributed in the hope that it will be useful,
8 * but WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 * GNU General Public License for more details.
11 *
12 * You should have received a copy of the GNU General Public License
13 * along with this program. If not, see <http://www.gnu.org/licenses/>.
14 */
15
24#pragma once
25
26#include <limits>
27#include <pybind11/numpy.h>
28#include <pybind11/pybind11.h>
29
30namespace py = pybind11;
31
32extern "C" {
33#include "xcsf.h"
34}
35
36#include "pybind_callback.h"
37#include "pybind_utils.h"
38
43{
44 public:
/**
 * @brief Constructs a new checkpoint callback.
 * @param [in] monitor Name of the metric to monitor; must be "train" or "val".
 * @param [in] filename Name of the file to save XCSF.
 * @param [in] save_best_only Whether to only save the best population.
 * @param [in] save_freq Trial frequency to (possibly) make checkpoints; must
 * not be negative.
 * @param [in] verbose Whether to display messages when an action is taken
 * (NOTE(review): confirm the save path actually honors this flag).
 * @throws std::invalid_argument if monitor is neither "train" nor "val", or
 * if save_freq is negative.
 * NOTE(review): the member-initializer list following the ':' is not visible
 * in this extraction.
 */
53 CheckpointCallback(py::str monitor, std::string filename,
54 bool save_best_only, int save_freq, bool verbose) :
60 {
61 std::ostringstream err;
62 std::string str = monitor.cast<std::string>();
// Only the training or validation error series can be monitored.
63 if (str != "train" && str != "val") {
64 err << "invalid metric to monitor: " << str << std::endl;
65 throw std::invalid_argument(err.str());
66 }
// A negative frequency would make the save threshold in run() move backwards.
67 if (save_freq < 0) {
68 err << "save_freq cannot be negative" << std::endl;
69 throw std::invalid_argument(err.str());
70 }
71 }
72
77 void
78 save(struct XCSF *xcsf)
79 {
80 xcsf_save(xcsf, filename.c_str());
81 std::ostringstream status;
82 status << get_timestamp() << " CheckpointCallback: ";
83 status << "saved " << filename;
84 py::print(status.str());
85 }
86
93 bool
94 run(struct XCSF *xcsf, py::dict metrics) override
95 {
96 py::list data = metrics[monitor];
97 py::list trials = metrics["trials"];
98 const double current_error = py::cast<double>(data[data.size() - 1]);
99 const int current_trial = py::cast<int>(trials[trials.size() - 1]);
100 if (current_trial >= save_trial + save_freq) {
101 if (!save_best_only || (current_error < best_error)) {
102 save_trial = current_trial;
103 save(xcsf);
104 }
105 if (current_error < best_error) {
106 best_error = current_error;
107 }
108 }
109 return false;
110 }
111
116 void
117 finish(struct XCSF *xcsf) override
118 {
119 if (!save_best_only) {
120 save(xcsf);
121 }
122 }
123
124 private:
/** Name of the metric to monitor ("train" or "val"). */
125 py::str monitor;
/** Name of the file to save XCSF. */
126 std::string filename;
// NOTE(review): the save_best_only and save_freq member declarations are not
// visible in this extraction.
/** Whether to display messages when an action is taken. */
129 bool verbose;
130
/**
 * Lowest monitored error observed while a checkpoint was due. It is
 * initialised to the largest double so that any smaller error counts as an
 * improvement.
 */
131 double best_error = std::numeric_limits<double>::max();
/** Trial number at which the last checkpoint was made (0 before any save). */
132 int save_trial = 0;
133};
Interface for callbacks.
Callback to save XCSF at some frequency.
bool save_best_only
Whether to only save the best population.
int save_trial
Trial number at which the last checkpoint was made.
void finish(struct XCSF *xcsf) override
Executes any tasks at the end of fitting.
CheckpointCallback(py::str monitor, std::string filename, bool save_best_only, int save_freq, bool verbose)
Constructs a new checkpoint callback.
bool run(struct XCSF *xcsf, py::dict metrics) override
Performs callback operations.
int save_freq
Trial frequency at which to (possibly) make checkpoints.
std::string filename
Name of the file to save XCSF.
bool verbose
Whether to display messages when an action is taken.
void save(struct XCSF *xcsf)
Saves the state of XCSF.
py::str monitor
Name of the metric to monitor.
Interface for callbacks.
Utilities for Python library.
std::string get_timestamp()
Returns a formatted string for displaying time.
XCSF data structure.
Definition xcsf.h:85
size_t xcsf_save(const struct XCSF *xcsf, const char *filename)
Writes the current state of XCSF to a file.
Definition xcsf.c:90
XCSF data structures.