Callbacks - AI

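Callback that logs sample predictions (each input shown next to the model's prediction) to Weights & Biases at the end of every epoch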

class Images(Callback):
    """Keras callback that logs sample predictions to Weights & Biases.

    At the end of every epoch, up to ``num_examples`` validation inputs are
    drawn at random (without replacement), run through ``self.model``, and
    logged as ``wandb.Image`` objects showing the input and the prediction
    (clipped to [0, 1]) side by side.

    Args:
        val_data: Validation inputs; converted with ``np.array``. Predictions
            are concatenated horizontally with the inputs, so the model output
            is assumed to have the same height/channels as its input
            (TODO confirm for the model in use).
        num_examples: Maximum number of examples logged per epoch.
        **kwargs: Forwarded to the ``Callback`` base class.
    """

    def __init__(self, val_data, *, num_examples=8, **kwargs):
        super().__init__(**kwargs)
        # Index 0 of this tuple holds the validation inputs.
        self.validation_data = (np.array(val_data),)
        self.num_examples = num_examples

    def on_epoch_end(self, epoch, logs=None):
        inputs = self.validation_data[0]
        count = min(self.num_examples, inputs.shape[0])
        if count <= 0:
            # Nothing to visualise; sampling from an empty range would raise.
            return
        # Sample without replacement so the same image is not logged twice.
        indices = np.random.choice(inputs.shape[0], size=count, replace=False)
        test_data = inputs[indices]
        pred_data = np.clip(self.model.predict(test_data), 0, 1)
        # commit=False: don't advance the wandb step; these images are merged
        # into the next committed log call for the same step.
        wandb.log({
            "examples": [
                wandb.Image(np.hstack([data, pred_data[i]]), caption=str(i))
                for i, data in enumerate(test_data)]
        }, commit=False)

        

Pass the callbacks to `model.fit`

model.fit(
    X_train,
    is_five_train,
    epochs=config.epochs,
    validation_data=(X_test, is_five_test),
    # Images needs the validation inputs to sample from; constructing it with
    # no arguments raises TypeError (val_data is required).
    callbacks=[WandbCallback(labels=labels, input_type="image"), Images(X_test)],
)
        

Example callback that stops training early once accuracy exceeds 99%

# Haven't technically tested
class HighAccuracyStoppage(tf.keras.callbacks.Callback):
    """Stop training as soon as the reported accuracy exceeds a threshold.

    Args:
        threshold: Accuracy above which training is stopped (default 0.99).
    """

    def __init__(self, threshold=0.99):
        super().__init__()
        self.threshold = threshold

    def on_batch_end(self, batch, logs=None):
        # Use None as the default instead of a shared mutable dict.
        logs = logs or {}
        # Older Keras versions report the metric as 'acc' instead of 'accuracy'.
        # NOTE(review): in TF2 the per-batch value is believed to be a running
        # average over the current epoch — confirm for the version in use.
        accuracy = logs.get('accuracy', logs.get('acc'))
        # The metric is absent if the model was compiled without it; comparing
        # None with a float would raise TypeError.
        if accuracy is not None and accuracy > self.threshold:
            print(f"\nReached {self.threshold:.0%} accuracy so cancelling training!")
            self.model.stop_training = True