# Translation Comparison Demo Using HackMD
- [PyTorch strengthens its governance by joining the Linux Foundation | PyTorch](https://pytorch.org/blog/PyTorchfoundation/)
- [Introduction to PyTorch — PyTorch Tutorials 1.12.1+cu102 documentation](https://pytorch.org/tutorials/beginner/introyt/introyt1_tutorial.html)
- [Retraining an Image Classifier | TensorFlow Hub](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining)
----
Retraining an Image Classifier
==============================
- On this page
- [Introduction](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#introduction)
- [Looking for a tool instead?](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#looking_for_a_tool_instead)
- [Setup](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#setup)
- [Select the TF2 SavedModel module to use](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#select_the_tf2_savedmodel_module_to_use)
- [Set up the Flowers dataset](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#set_up_the_flowers_dataset)
- [Defining the model](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#defining_the_model)
- [Training the model](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#training_the_model)
- [Optional: Deployment to TensorFlow Lite](https://www.tensorflow.org/hub/tutorials/tf2_image_retraining#optional_deployment_to_tensorflow_lite)
| [![](https://www.tensorflow.org/images/colab_logo_32px.png)Run in Google Colab](https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_image_retraining.ipynb) | [![](https://www.tensorflow.org/images/GitHub-Mark-32px.png)View on GitHub](https://github.com/tensorflow/hub/blob/master/examples/colab/tf2_image_retraining.ipynb) | [![](https://www.tensorflow.org/images/download_logo_32px.png)Download notebook](https://storage.googleapis.com/tensorflow_docs/hub/examples/colab/tf2_image_retraining.ipynb) | [![](https://www.tensorflow.org/images/hub_logo_32px.png)See TF Hub models](https://tfhub.dev/google/collections/image/1) |
Introduction
-------------
Image classification models have millions of parameters. Training them from scratch requires a lot of labeled training data and a lot of computing power. Transfer learning is a technique that shortcuts much of this by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
This Colab demonstrates how to build a Keras model for classifying five species of flowers by using a pre-trained TF2 SavedModel from TensorFlow Hub for image feature extraction, trained on the much larger and more general ImageNet dataset. Optionally, the feature extractor can be trained ("fine-tuned") alongside the newly added classifier.
### Looking for a tool instead?
This is a TensorFlow coding tutorial. If you want a tool that just builds the TensorFlow or TFLite model for you, take a look at the [make\_image\_classifier](https://github.com/tensorflow/hub/tree/master/tensorflow_hub/tools/make_image_classifier) command-line tool that gets [installed](https://www.tensorflow.org/hub/installation) by the PIP package `tensorflow-hub[make_image_classifier]`, or at [this](https://colab.sandbox.google.com/github/tensorflow/examples/blob/master/tensorflow_examples/lite/model_maker/demo/image_classification.ipynb) TFLite colab.
Setup
------
```
import itertools
import os

import matplotlib.pylab as plt
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

print("TF version:", tf.__version__)
print("Hub version:", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")
```
TF version: 2.9.1
Hub version: 0.12.0
GPU is available
Select the TF2 SavedModel module to use
----------------------------------------
For starters, use [https://tfhub.dev/google/imagenet/mobilenet\_v2\_100\_224/feature\_vector/4](https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/feature_vector/4). The same URL can be used in code to identify the SavedModel and in your browser to show its documentation. (Note that models in TF1 Hub format won't work here.)
You can find more TF2 models that generate image feature vectors [here](https://tfhub.dev/s?module-type=image-feature-vector&tf-version=tf2).
There are multiple possible models to try. All you need to do is select a different one in the cell below and continue with the notebook.
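The model-selection cell is collapsed in this export. As a stand-in, here is a minimal sketch of the names the later cells rely on (`model_name`, `model_handle`, `IMAGE_SIZE`, `BATCH_SIZE`); the handle and input size match the output shown below, while the batch size of 16 is an assumption (though it is consistent with the 183 steps per epoch, i.e. 2936 // 16, seen during training):

```
# Minimal sketch of the collapsed selection cell; the original notebook
# offers a dropdown over many TF Hub handles.
model_name = "efficientnetv2-xl-21k"
model_handle = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2"
pixels = 512  # native input resolution for this handle

IMAGE_SIZE = (pixels, pixels)
BATCH_SIZE = 16  # assumption: not directly recoverable from the output below

print(f"Selected model: {model_name} : {model_handle}")
print(f"Input size {IMAGE_SIZE}")
```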
Selected model: efficientnetv2-xl-21k : https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2
Input size (512, 512)
Set up the Flowers dataset
---------------------------
Inputs are suitably resized for the selected module. Dataset augmentation (i.e., random distortions of an image each time it is read) improves training, especially when fine-tuning.
```
data_dir = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
```
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228813984/228813984 [==============================] - 1s 0us/step
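The dataset-building cell is also collapsed in this export. Below is a hedged sketch of what it must provide, assuming the `train_ds`, `val_ds`, `class_names`, `train_size`, and `valid_size` names that the training cells use; the 80/20 split matches the 2936/734 file counts in the output, and the one-hot labels match the `CategoricalCrossentropy` loss used later:

```
# Hedged sketch of the collapsed dataset cell (names and defaults assumed).
def build_dataset(subset):
  return tf.keras.utils.image_dataset_from_directory(
      data_dir,
      validation_split=0.20,
      subset=subset,
      label_mode="categorical",  # one-hot labels, matching the loss used below
      seed=123,  # fixed seed so the train/validation split is stable
      image_size=IMAGE_SIZE,
      batch_size=1)

train_ds = build_dataset("training")
class_names = tuple(train_ds.class_names)
train_size = train_ds.cardinality().numpy()
train_ds = train_ds.unbatch().batch(BATCH_SIZE).repeat()

val_ds = build_dataset("validation")
valid_size = val_ds.cardinality().numpy()
val_ds = val_ds.unbatch().batch(BATCH_SIZE)

# Rescale pixels to [0, 1]; random distortions such as RandomFlip or
# RandomRotation layers could be appended here when fine-tuning.
rescale = tf.keras.layers.Rescaling(1. / 255)
train_ds = train_ds.map(lambda images, labels: (rescale(images), labels))
val_ds = val_ds.map(lambda images, labels: (rescale(images), labels))
```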
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Defining the model
-------------------
All it takes is putting a linear classifier on top of the feature extractor layer created from the Hub module.
For speed, we start out with a non-trainable `feature_extractor_layer`, but you can also enable fine-tuning for greater accuracy.
```
do_fine_tuning = False
```
```
print("Building model with", model_handle)model = tf.keras.Sequential([ # Explicitly define the input shape so the model can be properly # loaded by the TFLiteConverter tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE + (3,)), hub.KerasLayer(model_handle, trainable=do_fine_tuning), tf.keras.layers.Dropout(rate=0.2), tf.keras.layers.Dense(len(class_names), kernel_regularizer=tf.keras.regularizers.l2(0.0001))])model.build((None,)+IMAGE_SIZE+(3,))model.summary()
```
Building model with https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/feature_vector/2
Model: "sequential_1"
\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_
Layer (type) Output Shape Param #
=================================================================
keras_layer (KerasLayer) (None, 1280) 207615832
dropout (Dropout) (None, 1280) 0
dense (Dense) (None, 5) 6405
=================================================================
Total params: 207,622,237
Trainable params: 6,405
Non-trainable params: 207,615,832
\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_
Training the model
-------------------
```
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1),
    metrics=['accuracy'])
```
```
steps_per_epoch = train_size // BATCH_SIZE
validation_steps = valid_size // BATCH_SIZE
hist = model.fit(
    train_ds,
    epochs=5, steps_per_epoch=steps_per_epoch,
    validation_data=val_ds,
    validation_steps=validation_steps).history
```
Epoch 1/5
183/183 [==============================] - 188s 852ms/step - loss: 0.8193 - accuracy: 0.9098 - val_loss: 0.6232 - val_accuracy: 0.9569
Epoch 2/5
183/183 [==============================] - 152s 830ms/step - loss: 0.6269 - accuracy: 0.9510 - val_loss: 0.5898 - val_accuracy: 0.9556
Epoch 3/5
183/183 [==============================] - 151s 828ms/step - loss: 0.5668 - accuracy: 0.9634 - val_loss: 0.5841 - val_accuracy: 0.9625
Epoch 4/5
183/183 [==============================] - 151s 828ms/step - loss: 0.5460 - accuracy: 0.9740 - val_loss: 0.5570 - val_accuracy: 0.9681
Epoch 5/5
183/183 [==============================] - 151s 828ms/step - loss: 0.5206 - accuracy: 0.9777 - val_loss: 0.5217 - val_accuracy: 0.9653
```
plt.figure()
plt.ylabel("Loss (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0, 2])
plt.plot(hist["loss"])
plt.plot(hist["val_loss"])

plt.figure()
plt.ylabel("Accuracy (training and validation)")
plt.xlabel("Training Steps")
plt.ylim([0, 1])
plt.plot(hist["accuracy"])
plt.plot(hist["val_accuracy"])
```
[<matplotlib.lines.Line2D at 0x7f81614f1430>]
![png](https://www.tensorflow.org/static/hub/tutorials/tf2_image_retraining_files/output_CYOw0fTO1W4x_1.png)
![png](https://www.tensorflow.org/static/hub/tutorials/tf2_image_retraining_files/output_CYOw0fTO1W4x_2.png)
Try out the model on an image from the validation data:
```
x, y = next(iter(val_ds))
image = x[0, :, :, :]
true_index = np.argmax(y[0])
plt.imshow(image)
plt.axis('off')
plt.show()

# Expand the validation image to (1, 224, 224, 3) before predicting the label
prediction_scores = model.predict(np.expand_dims(image, axis=0))
predicted_index = np.argmax(prediction_scores)
print("True label: " + class_names[true_index])
print("Predicted label: " + class_names[predicted_index])
```
![png](https://www.tensorflow.org/static/hub/tutorials/tf2_image_retraining_files/output_oi1iCNB9K1Ai_0.png)
1/1 [==============================] - 4s 4s/step
True label: sunflowers
Predicted label: sunflowers
Finally, the trained model can be saved for deployment to TF Serving or TFLite (on mobile) as follows.
```
saved_model_path = f"/tmp/saved_flowers_model_{model_name}"
tf.saved_model.save(model, saved_model_path)
```
WARNING:absl:Found untraced functions such as restored_function_body, restored_function_body, restored_function_body, restored_function_body, restored_function_body while saving (showing 5 of 1594). These functions will not be directly callable after loading.
INFO:tensorflow:Assets written to: /tmp/saved_flowers_model_efficientnetv2-xl-21k/assets
Optional: Deployment to TensorFlow Lite
----------------------------------------
[TensorFlow Lite](https://www.tensorflow.org/lite) lets you deploy TensorFlow models to mobile and IoT devices. The code below shows how to convert the trained model to TFLite and apply post-training tools from the [TensorFlow Model Optimization Toolkit](https://www.tensorflow.org/model_optimization). Finally, it runs the converted model in the TFLite Interpreter to examine the resulting quality.
- Converting without optimization provides the same results as before (up to roundoff error).
- Converting with optimization without any data quantizes the model weights to 8 bits, but inference still uses floating-point computation for the neural network activations. This reduces the model size by almost a factor of 4 and improves CPU latency on mobile devices.
- On top of that, computation of the neural network activations can be quantized to 8-bit integers as well if a small reference dataset is provided to calibrate the quantization range. On a mobile device, this accelerates inference further and makes it possible to run on accelerators like Edge TPU.
### Optimization settings
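The conversion cell is collapsed in this export as well. Here is a hedged sketch of the conversion, assuming the flag names `optimize_lite_model` and `num_calibration_examples`; it reuses `saved_model_path`, `train_ds`, and `model_name` from above and defines the `lite_model_content` that the interpreter code below expects:

```
# Hedged sketch of the collapsed conversion cell (flag names assumed).
optimize_lite_model = False    # quantize weights to 8 bits when True
num_calibration_examples = 60  # >0 also calibrates activation quantization

representative_dataset = None
if optimize_lite_model and num_calibration_examples:
  # A bounded stream of unlabeled training examples for calibration;
  # the converter expects a list of input tensors with batch size 1.
  representative_dataset = lambda: itertools.islice(
      ([image[None, ...]] for batch, _ in train_ds for image in batch),
      num_calibration_examples)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_path)
if optimize_lite_model:
  converter.optimizations = [tf.lite.Optimize.DEFAULT]
  if representative_dataset:
    converter.representative_dataset = representative_dataset
lite_model_content = converter.convert()

with open(f"/tmp/lite_flowers_model_{model_name}.tflite", "wb") as f:
  f.write(lite_model_content)
print("Wrote %sTFLite model of %d bytes." %
      ("optimized " if optimize_lite_model else "", len(lite_model_content)))
```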
2022-06-29 11:22:34.450027: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format.
2022-06-29 11:22:34.450101: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency.
Wrote TFLite model of 826249772 bytes.
```
interpreter = tf.lite.Interpreter(model_content=lite_model_content)

# This little helper wraps the TFLite Interpreter as a numpy-to-numpy function.
def lite_model(images):
  interpreter.allocate_tensors()
  interpreter.set_tensor(interpreter.get_input_details()[0]['index'], images)
  interpreter.invoke()
  return interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
```
```
num_eval_examples = 50
eval_dataset = ((image, label)  # TFLite expects batch size 1.
                for batch in train_ds
                for (image, label) in zip(*batch))
count = 0
count_lite_tf_agree = 0
count_lite_correct = 0
for image, label in eval_dataset:
  probs_lite = lite_model(image[None, ...])[0]
  probs_tf = model(image[None, ...]).numpy()[0]
  y_lite = np.argmax(probs_lite)
  y_tf = np.argmax(probs_tf)
  y_true = np.argmax(label)
  count += 1
  if y_lite == y_tf: count_lite_tf_agree += 1
  if y_lite == y_true: count_lite_correct += 1
  if count >= num_eval_examples: break
print("TFLite model agrees with original model on %d of %d examples (%g%%)." %
      (count_lite_tf_agree, count, 100.0 * count_lite_tf_agree / count))
print("TFLite model is accurate on %d of %d examples (%g%%)." %
      (count_lite_correct, count, 100.0 * count_lite_correct / count))
```
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
TFLite model agrees with original model on 50 of 50 examples (100%).
TFLite model is accurate on 50 of 50 examples (100%).
Except as otherwise noted, the content of this page is licensed under the [Creative Commons Attribution 4.0 License](https://creativecommons.org/licenses/by/4.0/), and code samples are licensed under the [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). For details, see the [Google Developers Site Policies](https://developers.google.com/site-policies). Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2022-06-30 UTC.