---
title: Writing your own callbacks
authors: Rick Chao, Francois Chollet
date-created: 2019/03/20
last-modified: 2023/06/25
description: Complete guide to writing new Keras callbacks.
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Writing your own callbacks}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

## Introduction

A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include `callback_tensorboard()` to visualize training progress and results with TensorBoard, or `callback_model_checkpoint()` to periodically save your model during training.

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.

## Setup

``` r
library(keras3)
```

## Keras callbacks overview

All callbacks subclass the base Keras `Callback` class (created in R with the `Callback()` function) and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful to get a view on internal states and statistics of the model during training.

You can pass a list of callbacks (as the keyword argument `callbacks`) to the following model methods:

- `fit()`
- `evaluate()`
- `predict()`

## An overview of callback methods

### Global methods

#### `on_(train|test|predict)_begin(logs = NULL)`

Called at the beginning of `fit`/`evaluate`/`predict`.

#### `on_(train|test|predict)_end(logs = NULL)`

Called at the end of `fit`/`evaluate`/`predict`.

### Batch-level methods for training/testing/predicting

#### `on_(train|test|predict)_batch_begin(batch, logs = NULL)`

Called right before processing a batch during training/testing/predicting.

#### `on_(train|test|predict)_batch_end(batch, logs = NULL)`

Called at the end of training/testing/predicting a batch. Within this method, `logs` is a named list containing the metrics results.

### Epoch-level methods (training only)

#### `on_epoch_begin(epoch, logs = NULL)`

Called at the beginning of an epoch during training.

#### `on_epoch_end(epoch, logs = NULL)`

Called at the end of an epoch during training.
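All of these methods are optional; you only override the ones you need. To make the shape of the API concrete, here is a minimal sketch (the class name, the method chosen, and the message are placeholders for this illustration):

``` r
# Minimal sketch: override a single method; all other callback
# methods keep their default no-op behavior.
callback_epoch_logger <- Callback(
  "EpochLogger",
  on_epoch_end = function(epoch, logs = NULL) {
    cat("Finished epoch", epoch, "\n")
  }
)
```

Calling `callback_epoch_logger()` then instantiates the callback, ready to be passed to `fit()`. The example below walks through the full pattern.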
## A basic example

Let's take a look at a concrete example. To get started, let's define a simple sequential Keras model:

``` r
# Define the Keras model to add callbacks to
get_model <- function() {
  model <- keras_model_sequential()
  model |> layer_dense(units = 1)
  model |> compile(
    optimizer = optimizer_rmsprop(learning_rate = 0.1),
    loss = "mean_squared_error",
    metrics = "mean_absolute_error"
  )
  model
}
```

Then, load the MNIST data for training and testing from the Keras datasets API:

``` r
# Load example MNIST data and pre-process it
mnist <- dataset_mnist()

flatten_and_rescale <- function(x) {
  x <- array_reshape(x, c(-1, 784))
  x <- x / 255
  x
}

mnist$train$x <- flatten_and_rescale(mnist$train$x)
mnist$test$x <- flatten_and_rescale(mnist$test$x)

# limit to 1000 samples
n <- 1000
mnist$train$x <- mnist$train$x[1:n, ]
mnist$train$y <- mnist$train$y[1:n]
mnist$test$x <- mnist$test$x[1:n, ]
mnist$test$y <- mnist$test$y[1:n]
```

Now, define a simple custom callback that logs:

- When `fit`/`evaluate`/`predict` starts & ends
- When each epoch starts & ends
- When each training batch starts & ends
- When each evaluation (test) batch starts & ends
- When each inference (prediction) batch starts & ends

``` r
show <- function(msg, logs) {
  cat(glue::glue(msg, .envir = parent.frame()), "got logs: ", sep = "; ")
  str(logs); cat("\n")
}

callback_custom <- Callback(
  "CustomCallback",
  on_train_begin = \(logs = NULL) show("Starting training", logs),
  on_epoch_begin = \(epoch, logs = NULL) show("Start epoch {epoch} of training", logs),
  on_train_batch_begin = \(batch, logs = NULL) show("...Training: start of batch {batch}", logs),
  on_train_batch_end = \(batch, logs = NULL) show("...Training: end of batch {batch}", logs),
  on_epoch_end = \(epoch, logs = NULL) show("End epoch {epoch} of training", logs),
  on_train_end = \(logs = NULL) show("Stop training", logs),
  on_test_begin = \(logs = NULL) show("Start testing", logs),
  on_test_batch_begin = \(batch, logs = NULL) show("...Evaluating: start of batch {batch}", logs),
  on_test_batch_end = \(batch, logs = NULL) show("...Evaluating: end of batch {batch}", logs),
  on_test_end = \(logs = NULL) show("Stop testing", logs),
  on_predict_begin = \(logs = NULL) show("Start predicting", logs),
  on_predict_end = \(logs = NULL) show("Stop predicting", logs),
  on_predict_batch_begin = \(batch, logs = NULL) show("...Predicting: start of batch {batch}", logs),
  on_predict_batch_end = \(batch, logs = NULL) show("...Predicting: end of batch {batch}", logs)
)
```

Let's try it out:

``` r
model <- get_model()
model |> fit(
  mnist$train$x, mnist$train$y,
  batch_size = 128,
  epochs = 2,
  verbose = 0,
  validation_split = 0.5,
  callbacks = list(callback_custom())
)
```

```
## Starting training; got logs: Named list()
##
## Start epoch 1 of training; got logs: Named list()
##
## ...Training: start of batch 1; got logs: Named list()
##
## ...Training: end of batch 1; got logs: List of 2
## $ loss : num 25.9
## $ mean_absolute_error: num 4.19
##
## ...Training: start of batch 2; got logs: Named list()
##
## ...Training: end of batch 2; got logs: List of 2
## $ loss : num 433
## $ mean_absolute_error: num 15.5
##
## ...Training: start of batch 3; got logs: Named list()
##
## ...Training: end of batch 3; got logs: List of 2
## $ loss : num 297
## $ mean_absolute_error: num 11.8
##
## ...Training: start of batch 4; got logs: Named list()
##
## ...Training: end of batch 4; got logs: List of 2
## $ loss : num 231
## $ mean_absolute_error: num 9.68
##
## Start testing; got logs: Named list()
##
## ...Evaluating: start of batch 1; got logs: Named list()
##
## ...Evaluating: end of batch 1; got logs: List of 2
## $ loss : num 8.1
## $ mean_absolute_error: num 2.3
##
## ...Evaluating: start of batch 2; got logs: Named list()
##
## ...Evaluating: end of batch 2; got logs: List of 2
## $ loss : num 7.58
## $ mean_absolute_error: num 2.23
##
## ...Evaluating: start of batch 3; got logs: Named list()
##
## ...Evaluating: end of batch 3; got logs: List of 2
## $ loss : num 7.38
## $ mean_absolute_error: num 2.21
##
## ...Evaluating: start of batch 4; got logs: Named list()
##
## ...Evaluating: end of batch 4; got logs: List of 2
## $ loss : num 7.3
## $ mean_absolute_error: num 2.21
##
## Stop testing; got logs: List of 2
## $ loss : num 7.3
## $ mean_absolute_error: num 2.21
##
## End epoch 1 of training; got logs: List of 4
## $ loss : num 231
## $ mean_absolute_error : num 9.68
## $ val_loss : num 7.3
## $ val_mean_absolute_error: num 2.21
##
## Start epoch 2 of training; got logs: Named list()
##
## ...Training: start of batch 1; got logs: Named list()
##
## ...Training: end of batch 1; got logs: List of 2
## $ loss : num 7.44
## $ mean_absolute_error: num 2.27
##
## ...Training: start of batch 2; got logs: Named list()
##
## ...Training: end of batch 2; got logs: List of 2
## $ loss : num 6.81
## $ mean_absolute_error: num 2.16
##
## ...Training: start of batch 3; got logs: Named list()
##
## ...Training: end of batch 3; got logs: List of 2
## $ loss : num 6.12
## $ mean_absolute_error: num 2.06
##
## ...Training: start of batch 4; got logs: Named list()
##
## ...Training: end of batch 4; got logs: List of 2
## $ loss : num 6.08
## $ mean_absolute_error: num 2.04
##
## Start testing; got logs: Named list()
##
## ...Evaluating: start of batch 1; got logs: Named list()
##
## ...Evaluating: end of batch 1; got logs: List of 2
## $ loss : num 5.54
## $ mean_absolute_error: num 1.92
##
## ...Evaluating: start of batch 2; got logs: Named list()
##
## ...Evaluating: end of batch 2; got logs: List of 2
## $ loss : num 5.31
## $ mean_absolute_error: num 1.87
##
## ...Evaluating: start of batch 3; got logs: Named list()
##
## ...Evaluating: end of batch 3; got logs: List of 2
## $ loss : num 5.11
## $ mean_absolute_error: num 1.8
##
## ...Evaluating: start of batch 4; got logs: Named list()
##
## ...Evaluating: end of batch 4; got logs: List of 2
## $ loss : num 5.15
## $ mean_absolute_error: num 1.82
##
## Stop testing; got logs: List of 2
## $ loss : num 5.15
## $ mean_absolute_error: num 1.82
##
## End epoch 2 of training; got logs: List of 4
## $ loss : num 6.08
## $ mean_absolute_error : num 2.04
## $ val_loss : num 5.15
## $ val_mean_absolute_error: num 1.82
##
## Stop training; got logs: List of 4
## $ loss : num 6.08
## $ mean_absolute_error : num 2.04
## $ val_loss : num 5.15
## $ val_mean_absolute_error: num 1.82
```

``` r
res <- model |> evaluate(
  mnist$test$x, mnist$test$y,
  batch_size = 128,
  verbose = 0,
  callbacks = list(callback_custom())
)
```

```
## Start testing; got logs: Named list()
##
## ...Evaluating: start of batch 1; got logs: Named list()
##
## ...Evaluating: end of batch 1; got logs: List of 2
## $ loss : num 5.2
## $ mean_absolute_error: num 1.84
##
## ...Evaluating: start of batch 2; got logs: Named list()
##
## ...Evaluating: end of batch 2; got logs: List of 2
## $ loss : num 4.62
## $ mean_absolute_error: num 1.73
##
## ...Evaluating: start of batch 3; got logs: Named list()
##
## ...Evaluating: end of batch 3; got logs: List of 2
## $ loss : num 4.61
## $ mean_absolute_error: num 1.74
##
## ...Evaluating: start of batch 4; got logs: Named list()
##
## ...Evaluating: end of batch 4; got logs: List of 2
## $ loss : num 4.65
## $ mean_absolute_error: num 1.75
##
## ...Evaluating: start of batch 5; got logs: Named list()
##
## ...Evaluating: end of batch 5; got logs: List of 2
## $ loss : num 4.84
## $ mean_absolute_error: num 1.77
##
## ...Evaluating: start of batch 6; got logs: Named list()
##
## ...Evaluating: end of batch 6; got logs: List of 2
## $ loss : num 4.76
## $ mean_absolute_error: num 1.76
##
## ...Evaluating: start of batch 7; got logs: Named list()
##
## ...Evaluating: end of batch 7; got logs: List of 2
## $ loss : num 4.74
## $ mean_absolute_error: num 1.76
##
## ...Evaluating: start of batch 8; got logs: Named list()
##
## ...Evaluating: end of batch 8; got logs: List of 2
## $ loss : num 4.67
## $ mean_absolute_error: num 1.75
##
## Stop testing; got logs: List of 2
## $ loss : num 4.67
## $ mean_absolute_error: num 1.75
```

``` r
res <- model |> predict(
  mnist$test$x,
  batch_size = 128,
  verbose = 0,
  callbacks = list(callback_custom())
)
```

```
## Start predicting; got logs: Named list()
##
## ...Predicting: start of batch 1; got logs: Named list()
##
## ...Predicting: end of batch 1; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 2; got logs: Named list()
##
## ...Predicting: end of batch 2; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 3; got logs: Named list()
##
## ...Predicting: end of batch 3; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 4; got logs: Named list()
##
## ...Predicting: end of batch 4; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 5; got logs: Named list()
##
## ...Predicting: end of batch 5; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 6; got logs: Named list()
##
## ...Predicting: end of batch 6; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 7; got logs: Named list()
##
## ...Predicting: end of batch 7; got logs: List of 1
## $ outputs:
##
## ...Predicting: start of batch 8; got logs: Named list()
##
## ...Predicting: end of batch 8; got logs: List of 1
## $ outputs:
##
## Stop predicting; got logs: Named list()
```

### Usage of `logs` list

The `logs` named list contains the loss value and all the metrics at the end of a batch or epoch. Examples include the loss and the mean absolute error.

``` r
callback_print_loss_and_mae <- Callback(
  "LossAndErrorPrintingCallback",

  on_train_batch_end = function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n", batch, logs$loss)),

  on_test_batch_end = function(batch, logs = NULL)
    cat(sprintf("Up to batch %i, the average loss is %7.2f.\n", batch, logs$loss)),

  on_epoch_end = function(epoch, logs = NULL)
    cat(sprintf(
      "The average loss for epoch %2i is %9.2f and mean absolute error is %7.2f.\n",
      epoch, logs$loss, logs$mean_absolute_error
    ))
)

model <- get_model()
model |> fit(
  mnist$train$x, mnist$train$y,
  epochs = 2,
  verbose = 0,
  batch_size = 128,
  callbacks = list(callback_print_loss_and_mae())
)
```

```
## Up to batch 1, the average loss is 25.12.
## Up to batch 2, the average loss is 398.92.
## Up to batch 3, the average loss is 274.04.
## Up to batch 4, the average loss is 208.32.
## Up to batch 5, the average loss is 168.15.
## Up to batch 6, the average loss is 141.31.
## Up to batch 7, the average loss is 122.19.
## Up to batch 8, the average loss is 110.05.
## The average loss for epoch 1 is 110.05 and mean absolute error is 5.79.
## Up to batch 1, the average loss is 4.71.
## Up to batch 2, the average loss is 4.74.
## Up to batch 3, the average loss is 4.81.
## Up to batch 4, the average loss is 5.07.
## Up to batch 5, the average loss is 5.08.
## Up to batch 6, the average loss is 5.09.
## Up to batch 7, the average loss is 5.19.
## Up to batch 8, the average loss is 5.51.
## The average loss for epoch 2 is 5.51 and mean absolute error is 1.90.
```

``` r
res <- model |> evaluate(
  mnist$test$x, mnist$test$y,
  verbose = 0,
  batch_size = 128,
  callbacks = list(callback_print_loss_and_mae())
)
```

```
## Up to batch 1, the average loss is 15.86.
## Up to batch 2, the average loss is 16.13.
## Up to batch 3, the average loss is 16.02.
## Up to batch 4, the average loss is 16.11.
## Up to batch 5, the average loss is 16.23.
## Up to batch 6, the average loss is 16.68.
## Up to batch 7, the average loss is 16.61.
## Up to batch 8, the average loss is 16.54.
```

For more information about callbacks, you can check out the [Keras callback API documentation](https://keras3.posit.co/reference/index.html#callbacks).

## Usage of `self$model` attribute

In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: `self$model`.

Here are a few of the things you can do with `self$model` in a callback:

- Set `self$model$stop_training <- TRUE` to immediately interrupt training.
- Mutate hyperparameters of the optimizer (available as `self$model$optimizer`), such as `self$model$optimizer$learning_rate`.
- Save the model at periodic intervals.
- Record the output of `model |> predict()` on a few test samples at the end of each epoch, to use as a sanity check during training.
- Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
- etc.

Let's see this in action in a couple of examples.

## Examples of Keras callback applications

### Early stopping at minimum loss

This first example shows how to create a `Callback` that stops training when the loss reaches its minimum, by setting the attribute `self$model$stop_training` (boolean). Optionally, you can provide an argument `patience` to specify how many epochs to wait before stopping once a local minimum has been reached.

`callback_early_stopping()` provides a more complete and general implementation.

``` r
callback_early_stopping_at_min_loss <- Callback(
  "EarlyStoppingAtMinLoss",
  `__doc__` =
    "Stop training when the loss is at its min, i.e., the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after the minimum has been hit.
          After this many epochs with no improvement, training stops.
    ",

  initialize = function(patience = 0) {
    super$initialize()
    self$patience <- patience
    # best_weights to store the weights at which the minimum loss occurs.
    self$best_weights <- NULL
  },

  on_train_begin = function(logs = NULL) {
    # The number of epochs waited while the loss is no longer at a minimum.
    self$wait <- 0
    # The epoch the training stops at.
    self$stopped_epoch <- 0
    # Initialize the best as infinity.
    self$best <- Inf
  },

  on_epoch_end = function(epoch, logs = NULL) {
    current <- logs$loss
    if (current < self$best) {
      self$best <- current
      self$wait <- 0L
      # Record the best weights if the current result is better (lower).
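      # get_weights() returns plain R arrays, so this snapshot is decoupled
      # from the live model; it is restored with set_weights() below once
      # patience runs out.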
      self$best_weights <- get_weights(self$model)
    } else {
      self$wait <- self$wait + 1L
      if (self$wait >= self$patience) {
        self$stopped_epoch <- epoch
        self$model$stop_training <- TRUE
        cat("Restoring model weights from the end of the best epoch.\n")
        self$model$set_weights(self$best_weights)
      }
    }
  },

  on_train_end = function(logs = NULL)
    if (self$stopped_epoch > 0)
      cat(sprintf("Epoch %05d: early stopping\n", self$stopped_epoch + 1))
)

model <- get_model()
model |> fit(
  mnist$train$x, mnist$train$y,
  epochs = 30,
  batch_size = 64,
  verbose = 0,
  callbacks = list(callback_print_loss_and_mae(),
                   callback_early_stopping_at_min_loss())
)
```

```
## Up to batch 1, the average loss is 30.54.
## Up to batch 2, the average loss is 513.27.
## Up to batch 3, the average loss is 352.60.
## Up to batch 4, the average loss is 266.37.
## Up to batch 5, the average loss is 214.68.
## Up to batch 6, the average loss is 179.97.
## Up to batch 7, the average loss is 155.06.
## Up to batch 8, the average loss is 136.59.
## Up to batch 9, the average loss is 121.96.
## Up to batch 10, the average loss is 110.28.
## Up to batch 11, the average loss is 100.72.
## Up to batch 12, the average loss is 92.71.
## Up to batch 13, the average loss is 85.95.
## Up to batch 14, the average loss is 80.21.
## Up to batch 15, the average loss is 75.17.
## Up to batch 16, the average loss is 72.48.
## The average loss for epoch 1 is 72.48 and mean absolute error is 4.08.
## Up to batch 1, the average loss is 7.98.
## Up to batch 2, the average loss is 9.92.
## Up to batch 3, the average loss is 12.88.
## Up to batch 4, the average loss is 16.61.
## Up to batch 5, the average loss is 20.49.
## Up to batch 6, the average loss is 26.14.
## Up to batch 7, the average loss is 30.44.
## Up to batch 8, the average loss is 33.76.
## Up to batch 9, the average loss is 36.32.
## Up to batch 10, the average loss is 35.26.
## Up to batch 11, the average loss is 34.22.
## Up to batch 12, the average loss is 33.53.
## Up to batch 13, the average loss is 32.84.
## Up to batch 14, the average loss is 31.80.
## Up to batch 15, the average loss is 31.39.
## Up to batch 16, the average loss is 31.45.
## The average loss for epoch 2 is 31.45 and mean absolute error is 4.82.
## Up to batch 1, the average loss is 39.60.
## Up to batch 2, the average loss is 41.95.
## Up to batch 3, the average loss is 41.29.
## Up to batch 4, the average loss is 36.77.
## Up to batch 5, the average loss is 32.08.
## Up to batch 6, the average loss is 28.17.
## Up to batch 7, the average loss is 25.33.
## Up to batch 8, the average loss is 23.56.
## Up to batch 9, the average loss is 22.28.
## Up to batch 10, the average loss is 21.22.
## Up to batch 11, the average loss is 20.87.
## Up to batch 12, the average loss is 22.25.
## Up to batch 13, the average loss is 25.08.
## Up to batch 14, the average loss is 27.87.
## Up to batch 15, the average loss is 31.72.
## Up to batch 16, the average loss is 33.21.
## The average loss for epoch 3 is 33.21 and mean absolute error is 4.79.
## Restoring model weights from the end of the best epoch.
## Epoch 00004: early stopping
```
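In practice, the built-in `callback_early_stopping()` covers this pattern. A roughly equivalent configuration (monitoring the training loss, as above; the `patience` value is illustrative) would be:

``` r
# Roughly equivalent built-in callback: stop when `loss` stops
# improving and restore the best weights seen so far.
callback_early_stopping(
  monitor = "loss",
  patience = 2,
  restore_best_weights = TRUE
)
```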
### Learning rate scheduling

In this example, we show how a custom callback can be used to dynamically change the learning rate of the optimizer during the course of training. See `callback_learning_rate_scheduler()` for a more general implementation (in RStudio, press F1 while the cursor is over the function name to open its help page).

``` r
callback_custom_learning_rate_scheduler <- Callback(
  "CustomLearningRateScheduler",
  `__doc__` =
    "Learning rate scheduler which sets the learning rate according to a schedule.

    Arguments:
        schedule: a function that takes an epoch index (integer, indexed from 1)
          and the current learning rate as inputs and returns a new learning
          rate as output (float).
    ",

  initialize = function(schedule) {
    super$initialize()
    self$schedule <- schedule
  },

  on_epoch_begin = function(epoch, logs = NULL) {
    ## When in doubt about what types of objects are in scope (e.g., self$model),
    ## use a debugger to interact with the actual objects at the console!
    # browser()

    if (!"learning_rate" %in% names(self$model$optimizer))
      stop('Optimizer must have a "learning_rate" attribute.')

    # Get the current learning rate from the model's optimizer;
    # use as.numeric() to convert the Keras variable to an R numeric.
    lr <- as.numeric(self$model$optimizer$learning_rate)

    # Call the schedule function to get the scheduled learning rate.
    scheduled_lr <- self$schedule(epoch, lr)

    # Set the value back on the optimizer before this epoch starts.
    optimizer <- self$model$optimizer
    optimizer$learning_rate <- scheduled_lr
    cat(sprintf("\nEpoch %03d: Learning rate is %6.4f.\n", epoch, scheduled_lr))
  }
)

LR_SCHEDULE <- tibble::tribble(
  ~start_epoch, ~learning_rate,
             0,            0.1,
             3,           0.05,
             6,           0.01,
             9,          0.005,
            12,          0.001
)

last <- function(x) x[length(x)]

lr_schedule <- function(epoch, learning_rate) {
  "Helper function to retrieve the scheduled learning rate based on epoch."
  with(LR_SCHEDULE, learning_rate[last(which(epoch >= start_epoch))])
}

model <- get_model()
model |> fit(
  mnist$train$x, mnist$train$y,
  epochs = 14,
  batch_size = 64,
  verbose = 0,
  callbacks = list(
    callback_print_loss_and_mae(),
    callback_custom_learning_rate_scheduler(lr_schedule)
  )
)
```

```
##
## Epoch 001: Learning rate is 0.1000.
## Up to batch 1, the average loss is 29.36.
## Up to batch 2, the average loss is 513.95.
## Up to batch 3, the average loss is 352.70.
## Up to batch 4, the average loss is 266.46.
## Up to batch 5, the average loss is 214.73.
## Up to batch 6, the average loss is 180.00.
## Up to batch 7, the average loss is 155.05.
## Up to batch 8, the average loss is 136.64.
## Up to batch 9, the average loss is 121.97.
## Up to batch 10, the average loss is 110.30.
## Up to batch 11, the average loss is 100.76.
## Up to batch 12, the average loss is 92.74.
## Up to batch 13, the average loss is 85.95.
## Up to batch 14, the average loss is 80.18.
## Up to batch 15, the average loss is 75.11.
## Up to batch 16, the average loss is 72.38.
## The average loss for epoch 1 is 72.38 and mean absolute error is 4.04.
##
## Epoch 002: Learning rate is 0.1000.
## Up to batch 1, the average loss is 6.95.
## Up to batch 2, the average loss is 8.71.
## Up to batch 3, the average loss is 11.42.
## Up to batch 4, the average loss is 15.15.
## Up to batch 5, the average loss is 19.28.
## Up to batch 6, the average loss is 25.54.
## Up to batch 7, the average loss is 30.38.
## Up to batch 8, the average loss is 33.95.
## Up to batch 9, the average loss is 36.58.
## Up to batch 10, the average loss is 35.46.
## Up to batch 11, the average loss is 34.34.
## Up to batch 12, the average loss is 33.51.
## Up to batch 13, the average loss is 32.67.
## Up to batch 14, the average loss is 31.54.
## Up to batch 15, the average loss is 31.05.
## Up to batch 16, the average loss is 31.09.
## The average loss for epoch 2 is 31.09 and mean absolute error is 4.77.
##
## Epoch 003: Learning rate is 0.0500.
## Up to batch 1, the average loss is 40.40.
## Up to batch 2, the average loss is 22.33.
## Up to batch 3, the average loss is 16.18.
## Up to batch 4, the average loss is 13.09.
## Up to batch 5, the average loss is 11.48.
## Up to batch 6, the average loss is 10.21.
## Up to batch 7, the average loss is 9.22.
## Up to batch 8, the average loss is 8.70.
## Up to batch 9, the average loss is 8.16.
## Up to batch 10, the average loss is 7.80.
## Up to batch 11, the average loss is 7.50.
## Up to batch 12, the average loss is 7.17.
## Up to batch 13, the average loss is 6.89.
## Up to batch 14, the average loss is 6.70.
## Up to batch 15, the average loss is 6.52.
## Up to batch 16, the average loss is 6.54.
## The average loss for epoch 3 is 6.54 and mean absolute error is 1.93.
##
## Epoch 004: Learning rate is 0.0500.
## Up to batch 1, the average loss is 8.74.
## Up to batch 2, the average loss is 8.34.
## Up to batch 3, the average loss is 9.09.
## Up to batch 4, the average loss is 9.72.
## Up to batch 5, the average loss is 10.48.
## Up to batch 6, the average loss is 11.69.
## Up to batch 7, the average loss is 11.83.
## Up to batch 8, the average loss is 11.56.
## Up to batch 9, the average loss is 11.24.
## Up to batch 10, the average loss is 10.84.
## Up to batch 11, the average loss is 10.66.
## Up to batch 12, the average loss is 10.44.
## Up to batch 13, the average loss is 10.21.
## Up to batch 14, the average loss is 10.06.
## Up to batch 15, the average loss is 10.00.
## Up to batch 16, the average loss is 10.20.
## The average loss for epoch 4 is 10.20 and mean absolute error is 2.71.
##
## Epoch 005: Learning rate is 0.0500.
## Up to batch 1, the average loss is 17.26.
## Up to batch 2, the average loss is 14.09.
## Up to batch 3, the average loss is 12.67.
## Up to batch 4, the average loss is 11.44.
## Up to batch 5, the average loss is 10.54.
## Up to batch 6, the average loss is 10.10.
## Up to batch 7, the average loss is 9.53.
## Up to batch 8, the average loss is 9.17.
## Up to batch 9, the average loss is 8.78.
## Up to batch 10, the average loss is 8.49.
## Up to batch 11, the average loss is 8.50.
## Up to batch 12, the average loss is 8.59.
## Up to batch 13, the average loss is 8.68.
## Up to batch 14, the average loss is 8.86.
## Up to batch 15, the average loss is 9.17.
## Up to batch 16, the average loss is 9.53.
## The average loss for epoch 5 is 9.53 and mean absolute error is 2.58.
##
## Epoch 006: Learning rate is 0.0100.
## Up to batch 1, the average loss is 17.04.
## Up to batch 2, the average loss is 14.85.
## Up to batch 3, the average loss is 11.53.
## Up to batch 4, the average loss is 9.65.
## Up to batch 5, the average loss is 8.44.
## Up to batch 6, the average loss is 7.50.
## Up to batch 7, the average loss is 6.74.
## Up to batch 8, the average loss is 6.56.
## Up to batch 9, the average loss is 6.18.
## Up to batch 10, the average loss is 5.87.
## Up to batch 11, the average loss is 5.63.
## Up to batch 12, the average loss is 5.45.
## Up to batch 13, the average loss is 5.23.
## Up to batch 14, the average loss is 5.12.
## Up to batch 15, the average loss is 4.96.
## Up to batch 16, the average loss is 4.91.
## The average loss for epoch 6 is 4.91 and mean absolute error is 1.67.
##
## Epoch 007: Learning rate is 0.0100.
## Up to batch 1, the average loss is 3.65.
## Up to batch 2, the average loss is 3.04.
## Up to batch 3, the average loss is 2.88.
## Up to batch 4, the average loss is 2.85.
## Up to batch 5, the average loss is 2.88.
## Up to batch 6, the average loss is 2.81.
## Up to batch 7, the average loss is 2.70.
## Up to batch 8, the average loss is 2.96.
## Up to batch 9, the average loss is 2.96.
## Up to batch 10, the average loss is 2.93.
## Up to batch 11, the average loss is 2.95.
## Up to batch 12, the average loss is 2.98.
## Up to batch 13, the average loss is 2.97.
## Up to batch 14, the average loss is 3.01.
## Up to batch 15, the average loss is 3.00.
## Up to batch 16, the average loss is 3.05.
## The average loss for epoch 7 is 3.05 and mean absolute error is 1.34.
##
## Epoch 008: Learning rate is 0.0100.
## Up to batch 1, the average loss is 3.69.
## Up to batch 2, the average loss is 3.21.
## Up to batch 3, the average loss is 3.00.
## Up to batch 4, the average loss is 2.91.
## Up to batch 5, the average loss is 2.94.
## Up to batch 6, the average loss is 2.85.
## Up to batch 7, the average loss is 2.72.
## Up to batch 8, the average loss is 2.95.
## Up to batch 9, the average loss is 2.97.
## Up to batch 10, the average loss is 2.93.
## Up to batch 11, the average loss is 2.96.
## Up to batch 12, the average loss is 2.98.
## Up to batch 13, the average loss is 2.99.
## Up to batch 14, the average loss is 3.05.
## Up to batch 15, the average loss is 3.08.
## Up to batch 16, the average loss is 3.14.
## The average loss for epoch 8 is 3.14 and mean absolute error is 1.36.
##
## Epoch 009: Learning rate is 0.0050.
## Up to batch 1, the average loss is 3.71.
## Up to batch 2, the average loss is 2.93.
## Up to batch 3, the average loss is 2.76.
## Up to batch 4, the average loss is 2.70.
## Up to batch 5, the average loss is 2.76.
## Up to batch 6, the average loss is 2.69.
## Up to batch 7, the average loss is 2.57.
## Up to batch 8, the average loss is 2.79.
## Up to batch 9, the average loss is 2.80.
## Up to batch 10, the average loss is 2.77.
## Up to batch 11, the average loss is 2.79.
## Up to batch 12, the average loss is 2.80.
## Up to batch 13, the average loss is 2.78.
## Up to batch 14, the average loss is 2.81.
## Up to batch 15, the average loss is 2.80.
## Up to batch 16, the average loss is 2.83.
## The average loss for epoch 9 is 2.83 and mean absolute error is 1.28.
##
## Epoch 010: Learning rate is 0.0050.
## Up to batch 1, the average loss is 3.02.
## Up to batch 2, the average loss is 2.69.
## Up to batch 3, the average loss is 2.58.
## Up to batch 4, the average loss is 2.57.
## Up to batch 5, the average loss is 2.65.
## Up to batch 6, the average loss is 2.60.
## Up to batch 7, the average loss is 2.48.
## Up to batch 8, the average loss is 2.72.
## Up to batch 9, the average loss is 2.74.
## Up to batch 10, the average loss is 2.71.
## Up to batch 11, the average loss is 2.74.
## Up to batch 12, the average loss is 2.75.
## Up to batch 13, the average loss is 2.74.
## Up to batch 14, the average loss is 2.77.
## Up to batch 15, the average loss is 2.77.
## Up to batch 16, the average loss is 2.80.
## The average loss for epoch 10 is 2.80 and mean absolute error is 1.27.
##
## Epoch 011: Learning rate is 0.0050.
## Up to batch 1, the average loss is 3.01.
## Up to batch 2, the average loss is 2.69.
## Up to batch 3, the average loss is 2.58.
## Up to batch 4, the average loss is 2.56.
## Up to batch 5, the average loss is 2.63.
## Up to batch 6, the average loss is 2.58.
## Up to batch 7, the average loss is 2.47.
## Up to batch 8, the average loss is 2.70.
## Up to batch 9, the average loss is 2.72.
## Up to batch 10, the average loss is 2.69.
## Up to batch 11, the average loss is 2.71.
## Up to batch 12, the average loss is 2.72.
## Up to batch 13, the average loss is 2.71.
## Up to batch 14, the average loss is 2.75.
## Up to batch 15, the average loss is 2.74.
## Up to batch 16, the average loss is 2.77.
## The average loss for epoch 11 is 2.77 and mean absolute error is 1.27.
##
## Epoch 012: Learning rate is 0.0010.
## Up to batch 1, the average loss is 2.96.
## Up to batch 2, the average loss is 2.53.
## Up to batch 3, the average loss is 2.47.
## Up to batch 4, the average loss is 2.46.
## Up to batch 5, the average loss is 2.54.
## Up to batch 6, the average loss is 2.48.
## Up to batch 7, the average loss is 2.39.
## Up to batch 8, the average loss is 2.60.
## Up to batch 9, the average loss is 2.62.
## Up to batch 10, the average loss is 2.59.
## Up to batch 11, the average loss is 2.61.
## Up to batch 12, the average loss is 2.62.
## Up to batch 13, the average loss is 2.60.
## Up to batch 14, the average loss is 2.64.
## Up to batch 15, the average loss is 2.62.
## Up to batch 16, the average loss is 2.64.
## The average loss for epoch 12 is 2.64 and mean absolute error is 1.24.
##
## Epoch 013: Learning rate is 0.0010.
## Up to batch 1, the average loss is 2.82.
## Up to batch 2, the average loss is 2.46.
## Up to batch 3, the average loss is 2.42.
## Up to batch 4, the average loss is 2.42.
## Up to batch 5, the average loss is 2.50.
## Up to batch 6, the average loss is 2.45.
## Up to batch 7, the average loss is 2.36.
## Up to batch 8, the average loss is 2.57.
## Up to batch 9, the average loss is 2.59.
## Up to batch 10, the average loss is 2.57.
## Up to batch 11, the average loss is 2.59.
## Up to batch 12, the average loss is 2.60.
## Up to batch 13, the average loss is 2.59.
## Up to batch 14, the average loss is 2.62.
## Up to batch 15, the average loss is 2.61.
## Up to batch 16, the average loss is 2.63.
## The average loss for epoch 13 is 2.63 and mean absolute error is 1.23.
##
## Epoch 014: Learning rate is 0.0010.
## Up to batch 1, the average loss is 2.79.
## Up to batch 2, the average loss is 2.44.
## Up to batch 3, the average loss is 2.40.
## Up to batch 4, the average loss is 2.41.
## Up to batch 5, the average loss is 2.49.
## Up to batch 6, the average loss is 2.44.
## Up to batch 7, the average loss is 2.34.
## Up to batch 8, the average loss is 2.56.
## Up to batch 9, the average loss is 2.58.
## Up to batch 10, the average loss is 2.56.
## Up to batch 11, the average loss is 2.58.
## Up to batch 12, the average loss is 2.59.
## Up to batch 13, the average loss is 2.58.
## Up to batch 14, the average loss is 2.61.
## Up to batch 15, the average loss is 2.60.
## Up to batch 16, the average loss is 2.62.
## The average loss for epoch 14 is 2.62 and mean absolute error is 1.23.
```

### Built-in Keras callbacks

Be sure to check out the existing Keras callbacks by reading the [API docs](https://keras3.posit.co/reference/index.html#callbacks). Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
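For instance (the file paths and settings here are illustrative), several built-in callbacks can be combined in a single `fit()` call:

``` r
# Illustrative sketch: log metrics to a CSV file, checkpoint the model
# after each epoch, and write TensorBoard summaries, all in one fit().
model |> fit(
  mnist$train$x, mnist$train$y,
  epochs = 10,
  batch_size = 64,
  verbose = 0,
  callbacks = list(
    callback_csv_logger("training.log"),
    callback_model_checkpoint("model_at_epoch_{epoch:02d}.keras"),
    callback_tensorboard(log_dir = "./logs")
  )
)
```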