Callback that logs sample predicted images to Weights & Biases at the end of each epoch
class Images(Callback):
    """Log a random sample of validation inputs next to the model's predictions.

    At the end of every epoch, up to ``num_samples`` distinct validation
    inputs are drawn, run through ``self.model.predict``, clipped to
    ``[0, 1]`` and logged to Weights & Biases under the key ``"examples"``
    as side-by-side images (input | prediction), captioned by their position
    in the sample.

    NOTE(review): ``np.hstack([input, prediction])`` only works when the model
    output has the same height/channel layout as its input (e.g. an
    autoencoder or image-to-image model) — TODO confirm for the model in use.
    """

    def __init__(self, val_data, num_samples=8, **kwargs):
        """
        Args:
            val_data: validation *inputs*; converted with ``np.array`` and
                sampled along axis 0.
            num_samples: maximum number of examples to log per epoch
                (default 8, the previously hard-coded value).
            **kwargs: forwarded to the ``Callback`` base class.
        """
        super().__init__(**kwargs)
        # Stored as a 1-tuple so that ``self.validation_data[0]`` is the
        # inputs array.
        self.validation_data = (np.array(val_data),)
        self.num_samples = num_samples

    def on_epoch_end(self, epoch, logs=None):
        """Sample validation inputs, predict on them and log the image pairs."""
        inputs = self.validation_data[0]
        n_available = inputs.shape[0]
        if n_available == 0 or self.num_samples <= 0:
            # Nothing to draw (np.random.randint(0, ...) would raise ValueError).
            return

        # Sample without replacement so the same example is never shown twice.
        indices = np.random.choice(
            n_available, size=min(self.num_samples, n_available), replace=False
        )
        test_data = inputs[indices]
        pred_data = np.clip(self.model.predict(test_data), 0, 1)

        wandb.log(
            {
                "examples": [
                    wandb.Image(np.hstack([data, pred]), caption=str(i))
                    for i, (data, pred) in enumerate(zip(test_data, pred_data))
                ]
            },
            # Do not advance the W&B step here; this entry is committed
            # together with the next committing wandb.log call.
            commit=False,
        )
Pass the callbacks to `model.fit`
# Train the model, logging standard metrics with WandbCallback and sample
# prediction images with the Images callback.
# Images requires the validation inputs as its first argument; calling
# Images() with no argument raises TypeError (missing 'val_data').
model.fit(
    X_train,
    is_five_train,
    epochs=config.epochs,
    validation_data=(X_test, is_five_test),
    callbacks=[
        WandbCallback(labels=labels, input_type="image"),
        Images(X_test),
    ],
)
Example callback that stops training early as soon as a batch's accuracy exceeds 99%
class HighAccuracyStoppage(tf.keras.callbacks.Callback):
    """Stop training as soon as a reported accuracy exceeds a threshold.

    The check runs after every *batch* (not every epoch), so a single batch
    above the threshold ends training. The comparison is strict (``>``).
    """

    def __init__(self, threshold=0.99, monitor="accuracy"):
        """
        Args:
            threshold: value of the monitored metric above which training is
                stopped (default 0.99, the previously hard-coded value).
            monitor: key to read from the ``logs`` dict (default
                ``"accuracy"``).
        """
        super().__init__()
        self.threshold = threshold
        self.monitor = monitor

    def on_batch_end(self, epoch, logs=None):
        # NOTE: despite its name, ``epoch`` receives the batch index for this
        # hook; the parameter name is kept for backward compatibility.
        # ``logs=None`` replaces the previous shared mutable ``{}`` default.
        value = (logs or {}).get(self.monitor)
        if value is None:
            # Metric not reported (e.g. the model was compiled without it).
            # The old code crashed here comparing None > float; skip instead.
            return
        if value > self.threshold:
            # ':g' renders the default 0.99 as "99", keeping the original message.
            print(f"\nReached {self.threshold * 100:g}% accuracy so cancelling training!")
            self.model.stop_training = True