
add keras callback #182

Draft · wants to merge 1 commit into main
Conversation

@lvwerra (Member) commented Jul 11, 2022

This PR adds a KerasCallback class that lets a user wrap an evaluate metric in a Keras-compatible callback, following the discussions in #10 (comment).

import evaluate
import numpy as np

recall = evaluate.load("recall")

transform = lambda x: np.argmax(x, axis=-1) # labels are one-hot encoded but recall needs integer representation

recall_callback = evaluate.KerasCallback(
    model, # model object
    x_test, # model input
    recall, # metric object
    {"references": transform(y_test), "average": "micro"}, # metric input
    predictions_processor=transform # transform between model output/metric inputs
)

model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

hist = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.1,
    callbacks=[recall_callback],
)

If there are multiple metrics, we can either create several callbacks or merge them all into a single metric with evaluate.combine, as sketched below. There are still a few naming questions and rough edges to iron out, but I would be interested in some early feedback.
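
For concreteness, a sketch of the combined-metrics variant. evaluate.combine is the library's existing way to bundle metrics; the KerasCallback signature simply mirrors the draft example above and may still change, and model, x_test, y_test, and transform are reused from that example.

import evaluate

# Bundle two metrics that accept the same inputs into a single metric object.
prec_rec = evaluate.combine(["precision", "recall"])

# Hypothetical usage, following the draft KerasCallback signature above.
prec_rec_callback = evaluate.KerasCallback(
    model,
    x_test,
    prec_rec, # the combined metric behaves like a single metric object
    {"references": transform(y_test), "average": "micro"},
    predictions_processor=transform,
)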

cc @lhoestq @sashavor

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@lvwerra lvwerra requested a review from lhoestq July 12, 2022 11:33
@lhoestq (Member) commented Jul 18, 2022

Nice! Cool that it's similar to https://keras.io/examples/keras_recipes/sklearn_metric_callbacks/ :)

@lvwerra (Member, Author) commented Jul 18, 2022

Yep, I took that as inspiration :) I'll flesh out the PR for review.

@Rocketknight1 (Member) commented
We also have a KerasMetricCallback in transformers, but it might make sense to deprecate it in favour of a class in evaluate!

I definitely think having an easy way to wrap evaluate metrics in a Keras-compatible way would be really useful, and would let me get rid of a lot of boilerplate in all my example code.

@lvwerra
Copy link
Member Author

lvwerra commented Aug 5, 2022

Thanks @Rocketknight1! That looks pretty much like what we want. Maybe we need to change two things:

  1. Pass an evaluate metric instead of a callable, and use the predictions_processor as the piece between model and metric.
  2. We want evaluate to be as framework (transformers) agnostic as possible. The main issue I see is integrating the generate-related functions, as they are not native to Keras and rely heavily on transformers logic; the same goes for padding across batches.

What are your thoughts on this?
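
For illustration, here is a minimal sketch of what such a framework-agnostic callback could look like, assuming only Keras and an evaluate metric; the class name, constructor, and logging are hypothetical, not the PR's actual implementation.

from tensorflow import keras

class EvaluateMetricCallback(keras.callbacks.Callback):
    # Hypothetical sketch: compute an evaluate metric on a held-out set after each epoch.
    def __init__(self, metric, x_eval, references, predictions_processor=None, metric_kwargs=None):
        super().__init__()
        self.metric = metric
        self.x_eval = x_eval
        self.references = references
        self.predictions_processor = predictions_processor
        self.metric_kwargs = metric_kwargs or {}

    def on_epoch_end(self, epoch, logs=None):
        # self.model is attached by Keras when the callback is passed to model.fit.
        predictions = self.model.predict(self.x_eval)
        if self.predictions_processor is not None:
            predictions = self.predictions_processor(predictions)
        result = self.metric.compute(
            predictions=predictions,
            references=self.references,
            **self.metric_kwargs,
        )
        print(f"epoch {epoch}: {result}")

Everything here is plain Keras plus metric.compute, so nothing transformers-specific is required; generate-style decoding and cross-batch padding would have to live in the predictions_processor.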
