In [ ]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
In [ ]:
#@title MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

Text classification with TensorFlow Hub: Movie reviews

This notebook classifies movie reviews as positive or negative using the text of the review. This is an example of binary—or two-class—classification, an important and widely applicable kind of machine learning problem.

The tutorial demonstrates the basic application of transfer learning with TensorFlow Hub and Keras.

It uses the IMDB dataset that contains the text of 50,000 movie reviews from the Internet Movie Database. These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are balanced, meaning they contain an equal number of positive and negative reviews.

This notebook uses tf.keras, a high-level API to build and train models in TensorFlow, and tensorflow_hub, a library for loading trained models from TFHub in a single line of code. For a more advanced text classification tutorial using tf.keras, see the MLCC Text Classification Guide.

In [1]:
!pip install tensorflow-hub==0.14.0
!pip install tensorflow-datasets==4.9.4
In [2]:
import os
import numpy as np

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print("Hub version: ", hub.__version__)
print("GPU is", "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.15.0
Eager mode:  True
Hub version:  0.14.0
GPU is available

Download the IMDB dataset

The IMDB dataset is available on imdb reviews or on TensorFlow Datasets. The following code downloads the IMDB dataset to your machine (or the Colab runtime):

In [3]:
# Split the training set into 60% and 40% to end up with 15,000 examples
# for training, 10,000 examples for validation and 25,000 examples for testing.
train_data, validation_data, test_data = tfds.load(
    name="imdb_reviews", 
    split=('train[:60%]', 'train[60%:]', 'test'),
    as_supervised=True)
2024-07-16 00:14:36.831474: W external/local_tsl/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
Downloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /home/sciencedata/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...
Dataset imdb_reviews downloaded and prepared to /home/sciencedata/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.
2024-07-16 00:15:39.583046: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 10525 MB memory:  -> device: 0, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:06:00.0, compute capability: 6.1
2024-07-16 00:15:39.588015: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 10525 MB memory:  -> device: 1, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:07:00.0, compute capability: 6.1
2024-07-16 00:15:39.588930: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 10525 MB memory:  -> device: 2, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:08:00.0, compute capability: 6.1
2024-07-16 00:15:39.589915: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 10525 MB memory:  -> device: 3, name: NVIDIA GeForce GTX 1080 Ti, pci bus id: 0000:0b:00.0, compute capability: 6.1
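
As a quick check of the split arithmetic described in the code comment, the sketch below computes the example counts for a 60%/40% split of the 25,000 training reviews. It is plain Python and does not call TensorFlow Datasets. The helper name `split_sizes` is made up for illustration, and it assumes the percentage boundary falls on a whole example, which holds for these numbers.

```python
TRAIN_EXAMPLES = 25_000  # size of the IMDB 'train' split, per the tutorial text


def split_sizes(total, percent):
    """Return (head, tail) sizes when a split is sliced at `percent` of `total`."""
    boundary = total * percent // 100
    return boundary, total - boundary


# 'train[:60%]' becomes the training set, 'train[60%:]' the validation set.
train_size, validation_size = split_sizes(TRAIN_EXAMPLES, 60)
print(train_size, validation_size)
```

The two sizes add up to the full training split, so no review is used for both training and validation.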

Explore the data

Let's take a moment to understand the format of the data. Each example is a sentence representing the movie review and a corresponding label. The sentence is not preprocessed in any way. The label is an integer value of either 0 or 1, where 0 is a negative review, and 1 is a positive review.

Let's print the first 10 examples.

In [4]:
train_examples_batch, train_labels_batch = next(iter(train_data.batch(10)))
train_examples_batch
2024-07-16 09:53:13.153430: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Out[4]:
<tf.Tensor: shape=(10,), dtype=string, numpy=
array([b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.",
       b'I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However on this occasion I fell asleep because the film was rubbish. The plot development was constant. Constantly slow and boring. Things seemed to happen, but with no explanation of what was causing them or why. I admit, I may have missed part of the film, but i watched the majority of it and everything just seemed to happen of its own accord without any real concern for anything else. I cant recommend this film at all.',
       b'Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Mountie telling the people of Dawson City, Yukon to elect themselves a marshal (yes a marshal!) and to enforce the law themselves, then gunfighters battling it out on the streets for control of the town? <br /><br />Nothing even remotely resembling that happened on the Canadian side of the border during the Klondike gold rush. Mr. Mann and company appear to have mistaken Dawson City for Deadwood, the Canadian North for the American Wild West.<br /><br />Canadian viewers be prepared for a Reefer Madness type of enjoyable howl with this ludicrous plot, or, to shake your head in disgust.',
       b'This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful performances from Cher and Nicolas Cage (as always) gently row the plot along. There are no rapids to cross, no dangerous waters, just a warm and witty paddle through New York life at its best. A family film in every sense and one that deserves the praise it received.',
       b'As others have mentioned, all the women that go nude in this film are mostly absolutely gorgeous. The plot very ably shows the hypocrisy of the female libido. When men are around they want to be pursued, but when no "men" are around, they become the pursuers of a 14 year old boy. And the boy becomes a man really fast (we should all be so lucky at this age!). He then gets up the courage to pursue his true love.',
       b"This is a film which should be seen by anybody interested in, effected by, or suffering from an eating disorder. It is an amazingly accurate and sensitive portrayal of bulimia in a teenage girl, its causes and its symptoms. The girl is played by one of the most brilliant young actresses working in cinema today, Alison Lohman, who was later so spectacular in 'Where the Truth Lies'. I would recommend that this film be shown in all schools, as you will never see a better on this subject. Alison Lohman is absolutely outstanding, and one marvels at her ability to convey the anguish of a girl suffering from this compulsive disorder. If barometers tell us the air pressure, Alison Lohman tells us the emotional pressure with the same degree of accuracy. Her emotional range is so precise, each scene could be measured microscopically for its gradations of trauma, on a scale of rising hysteria and desperation which reaches unbearable intensity. Mare Winningham is the perfect choice to play her mother, and does so with immense sympathy and a range of emotions just as finely tuned as Lohman's. Together, they make a pair of sensitive emotional oscillators vibrating in resonance with one another. This film is really an astonishing achievement, and director Katt Shea should be proud of it. The only reason for not seeing it is if you are not interested in people. But even if you like nature films best, this is after all animal behaviour at the sharp edge. Bulimia is an extreme version of how a tormented soul can destroy her own body in a frenzy of despair. And if we don't sympathise with people suffering from the depths of despair, then we are dead inside.",
       b'Okay, you have:<br /><br />Penelope Keith as Miss Herringbone-Tweed, B.B.E. (Backbone of England.) She\'s killed off in the first scene - that\'s right, folks; this show has no backbone!<br /><br />Peter O\'Toole as Ol\' Colonel Cricket from The First War and now the emblazered Lord of the Manor.<br /><br />Joanna Lumley as the ensweatered Lady of the Manor, 20 years younger than the colonel and 20 years past her own prime but still glamourous (Brit spelling, not mine) enough to have a toy-boy on the side. It\'s alright, they have Col. Cricket\'s full knowledge and consent (they guy even comes \'round for Christmas!) Still, she\'s considerate of the colonel enough to have said toy-boy her own age (what a gal!)<br /><br />David McCallum as said toy-boy, equally as pointlessly glamourous as his squeeze. Pilcher couldn\'t come up with any cover for him within the story, so she gave him a hush-hush job at the Circus.<br /><br />and finally:<br /><br />Susan Hampshire as Miss Polonia Teacups, Venerable Headmistress of the Venerable Girls\' Boarding-School, serving tea in her office with a dash of deep, poignant advice for life in the outside world just before graduation. Her best bit of advice: "I\'ve only been to Nancherrow (the local Stately Home of England) once. I thought it was very beautiful but, somehow, not part of the real world." Well, we can\'t say they didn\'t warn us.<br /><br />Ah, Susan - time was, your character would have been running the whole show. They don\'t write \'em like that any more. Our loss, not yours.<br /><br />So - with a cast and setting like this, you have the re-makings of "Brideshead Revisited," right?<br /><br />Wrong! They took these 1-dimensional supporting roles because they paid so well. After all, acting is one of the oldest temp-jobs there is (YOU name another!)<br /><br />First warning sign: lots and lots of backlighting. 
They get around it by shooting outdoors - "hey, it\'s just the sunlight!"<br /><br />Second warning sign: Leading Lady cries a lot. When not crying, her eyes are moist. That\'s the law of romance novels: Leading Lady is "dewy-eyed."<br /><br />Henceforth, Leading Lady shall be known as L.L.<br /><br />Third warning sign: L.L. actually has stars in her eyes when she\'s in love. Still, I\'ll give Emily Mortimer an award just for having to act with that spotlight in her eyes (I wonder . did they use contacts?)<br /><br />And lastly, fourth warning sign: no on-screen female character is "Mrs." She\'s either "Miss" or "Lady."<br /><br />When all was said and done, I still couldn\'t tell you who was pursuing whom and why. I couldn\'t even tell you what was said and done.<br /><br />To sum up: they all live through World War II without anything happening to them at all.<br /><br />OK, at the end, L.L. finds she\'s lost her parents to the Japanese prison camps and baby sis comes home catatonic. Meanwhile (there\'s always a "meanwhile,") some young guy L.L. had a crush on (when, I don\'t know) comes home from some wartime tough spot and is found living on the street by Lady of the Manor (must be some street if SHE\'s going to find him there.) Both war casualties are whisked away to recover at Nancherrow (SOMEBODY has to be "whisked away" SOMEWHERE in these romance stories!)<br /><br />Great drama.',
       b'The film is based on a genuine 1950s novel.<br /><br />Journalist Colin McInnes wrote a set of three "London novels": "Absolute Beginners", "City of Spades" and "Mr Love and Justice". I have read all three. The first two are excellent. The last, perhaps an experiment that did not come off. But McInnes\'s work is highly acclaimed; and rightly so. This musical is the novelist\'s ultimate nightmare - to see the fruits of one\'s mind being turned into a glitzy, badly-acted, soporific one-dimensional apology of a film that says it captures the spirit of 1950s London, and does nothing of the sort.<br /><br />Thank goodness Colin McInnes wasn\'t alive to witness it.',
       b'I really love the sexy action and sci-fi films of the sixties and its because of the actress\'s that appeared in them. They found the sexiest women to be in these films and it didn\'t matter if they could act (Remember "Candy"?). The reason I was disappointed by this film was because it wasn\'t nostalgic enough. The story here has a European sci-fi film called "Dragonfly" being made and the director is fired. So the producers decide to let a young aspiring filmmaker (Jeremy Davies) to complete the picture. They\'re is one real beautiful woman in the film who plays Dragonfly but she\'s barely in it. Film is written and directed by Roman Coppola who uses some of his fathers exploits from his early days and puts it into the script. I wish the film could have been an homage to those early films. They could have lots of cameos by actors who appeared in them. There is one actor in this film who was popular from the sixties and its John Phillip Law (Barbarella). Gerard Depardieu, Giancarlo Giannini and Dean Stockwell appear as well. I guess I\'m going to have to continue waiting for a director to make a good homage to the films of the sixties. If any are reading this, "Make it as sexy as you can"! I\'ll be waiting!',
       b'Sure, this one isn\'t really a blockbuster, nor does it target such a position. "Dieter" is the first name of a quite popular German musician, who is either loved or hated for his kind of acting and thats exactly what this movie is about. It is based on the autobiography "Dieter Bohlen" wrote a few years ago but isn\'t meant to be accurate on that. The movie is filled with some sexual offensive content (at least for American standard) which is either amusing (not for the other "actors" of course) or dumb - it depends on your individual kind of humor or on you being a "Bohlen"-Fan or not. Technically speaking there isn\'t much to criticize. Speaking of me I find this movie to be an OK-movie.'],
      dtype=object)>

Let's also print the first 10 labels.

In [5]:
train_labels_batch
Out[5]:
<tf.Tensor: shape=(10,), dtype=int64, numpy=array([0, 0, 0, 1, 1, 1, 0, 0, 0, 0])>
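
To make the label convention concrete, here is a small sketch in plain Python that maps the ten labels printed above to names and counts them. The `LABEL_NAMES` dictionary is an illustrative helper, not part of the dataset API, and the short byte string is copied from the end of the second review shown earlier.

```python
from collections import Counter

# Label convention stated in the text: 0 is negative, 1 is positive.
LABEL_NAMES = {0: "negative", 1: "positive"}

# The ten labels shown in the output above.
first_labels = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]

label_counts = Counter(LABEL_NAMES[label] for label in first_labels)
print(label_counts["negative"], "negative,", label_counts["positive"], "positive")

# Reviews arrive as byte strings; decode them before treating them as text.
raw_review = b"I cant recommend this film at all."
print(raw_review.decode("utf-8"))
```

A batch this small need not be balanced, even though the full training and testing sets are.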

Build the model

The neural network is created by stacking layers—this requires three main architectural decisions:

  • How to represent the text?
  • How many layers to use in the model?
  • How many hidden units to use for each layer?

In this example, the input data consists of sentences. The labels to predict are either 0 or 1.

One way to represent the text is to convert sentences into embedding vectors. Use a pre-trained text embedding as the first layer, which has three advantages:

  • You don't have to worry about text preprocessing.
  • You benefit from transfer learning.
  • The embedding has a fixed size, so it's simpler to process.

For this example you use a pre-trained text embedding model from TensorFlow Hub called google/nnlm-en-dim50/2.

Many other pre-trained text embeddings from TFHub can also be used in this tutorial. Find more text embedding models on TFHub.

Let's first create a Keras layer that uses a TensorFlow Hub model to embed the sentences, and try it out on a couple of input examples. Note that no matter the length of the input text, the output shape of the embeddings is: (num_examples, embedding_dimension).

In [6]:
embedding = "https://tfhub.dev/google/nnlm-en-dim50/2"
hub_layer = hub.KerasLayer(embedding, input_shape=[], 
                           dtype=tf.string, trainable=True)
hub_layer(train_examples_batch[:3])
2024-07-16 09:55:15.564708: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
Out[6]:
<tf.Tensor: shape=(3, 50), dtype=float32, numpy=
array([[ 0.5423194 , -0.01190171,  0.06337537,  0.0686297 , -0.16776839,
        -0.10581177,  0.168653  , -0.04998823, -0.31148052,  0.07910344,
         0.15442258,  0.01488661,  0.03930155,  0.19772716, -0.12215477,
        -0.04120982, -0.27041087, -0.21922147,  0.26517656, -0.80739075,
         0.25833526, -0.31004202,  0.2868321 ,  0.19433866, -0.29036498,
         0.0386285 , -0.78444123, -0.04793238,  0.41102988, -0.36388886,
        -0.58034706,  0.30269453,  0.36308962, -0.15227163, -0.4439151 ,
         0.19462997,  0.19528405,  0.05666233,  0.2890704 , -0.28468323,
        -0.00531206,  0.0571938 , -0.3201319 , -0.04418665, -0.08550781,
        -0.55847436, -0.2333639 , -0.20782956, -0.03543065, -0.17533456],
       [ 0.56338924, -0.12339553, -0.10862677,  0.7753425 , -0.07667087,
        -0.15752274,  0.01872334, -0.08169781, -0.3521876 ,  0.46373403,
        -0.08492758,  0.07166861, -0.00670818,  0.12686071, -0.19326551,
        -0.5262643 , -0.32958236,  0.14394784,  0.09043556, -0.54175544,
         0.02468163, -0.15456744,  0.68333143,  0.09068333, -0.45327246,
         0.23180094, -0.8615696 ,  0.3448039 ,  0.12838459, -0.58759046,
        -0.40712303,  0.23061076,  0.48426905, -0.2712814 , -0.5380918 ,
         0.47016335,  0.2257274 , -0.00830665,  0.28462422, -0.30498496,
         0.04400366,  0.25025868,  0.14867125,  0.4071703 , -0.15422425,
        -0.06878027, -0.40825695, -0.31492147,  0.09283663, -0.20183429],
       [ 0.7456156 ,  0.21256858,  0.1440033 ,  0.52338624,  0.11032254,
         0.00902788, -0.36678016, -0.08938274, -0.24165548,  0.33384597,
        -0.111946  , -0.01460045, -0.00716449,  0.19562715,  0.00685217,
        -0.24886714, -0.42796353,  0.1862    , -0.05241097, -0.664625  ,
         0.13449019, -0.22205493,  0.08633009,  0.43685383,  0.2972681 ,
         0.36140728, -0.71968895,  0.05291242, -0.1431612 , -0.15733941,
        -0.15056324, -0.05988007, -0.08178931, -0.15569413, -0.09303784,
        -0.18971168,  0.0762079 , -0.02541647, -0.27134502, -0.3392682 ,
        -0.10296471, -0.27275252, -0.34078008,  0.20083308, -0.26644838,
         0.00655449, -0.05141485, -0.04261916, -0.4541363 ,  0.20023566]],
      dtype=float32)>

Let's now build the full model:

In [7]:
model = tf.keras.Sequential()
model.add(hub_layer)
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Dense(1))

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 keras_layer (KerasLayer)    (None, 50)                48190600  
                                                                 
 dense (Dense)               (None, 16)                816       
                                                                 
 dense_1 (Dense)             (None, 1)                 17        
                                                                 
=================================================================
Total params: 48191433 (183.84 MB)
Trainable params: 48191433 (183.84 MB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
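
The parameter counts in the summary can be reproduced by hand. The sketch below is plain Python; `dense_params` is an illustrative helper, and the embedding count is copied from the summary rather than derived.

```python
EMBEDDING_PARAMS = 48_190_600  # keras_layer row of the summary above
EMBEDDING_DIM = 50             # output width of the hub layer


def dense_params(inputs, units):
    """A Dense layer has inputs * units weights plus one bias per unit."""
    return inputs * units + units


hidden = dense_params(EMBEDDING_DIM, 16)  # dense
output = dense_params(16, 1)              # dense_1
total = EMBEDDING_PARAMS + hidden + output
print(hidden, output, total)
```

Nearly all of the parameters sit in the hub layer. Its count divides evenly by 50 (giving 963,812), which would be consistent with one 50-dimensional vector per vocabulary entry, though the summary itself does not state this. Because the layer was created with trainable=True, these parameters are counted as trainable.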

The layers are stacked sequentially to build the classifier:

  1. The first layer is a TensorFlow Hub layer. This layer uses a pre-trained SavedModel to map a sentence into its embedding vector. The pre-trained text embedding model that you are using (google/nnlm-en-dim50/2) splits the sentence into tokens, embeds each token, and then combines the embeddings. The resulting dimensions are: (num_examples, embedding_dimension). For this NNLM model, the embedding_dimension is 50.
  2. This fixed-length output vector is piped through a fully-connected (Dense) layer with 16 hidden units.
  3. The last layer is densely connected with a single output node.
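
The tokenize, embed, and combine idea from the first step can be sketched as a toy in plain Python. This is not the NNLM model: the hash-based `toy_token_embedding`, the whitespace tokenizer, the 4-dimensional vectors, and the use of a mean as the combiner are all assumptions chosen for illustration. The point is only that sentences of any length map to vectors of the same size.

```python
import hashlib

EMBEDDING_DIM = 4  # toy size; the tutorial's NNLM model uses 50


def toy_token_embedding(token):
    """Deterministic pseudo-embedding derived from a hash (illustration only)."""
    digest = hashlib.sha256(token.encode("utf-8")).digest()
    # Map the first EMBEDDING_DIM bytes to floats in [-1, 1].
    return [byte / 127.5 - 1.0 for byte in digest[:EMBEDDING_DIM]]


def embed_sentence(sentence):
    """Split on whitespace, embed each token, and average the vectors."""
    tokens = sentence.lower().split()
    if not tokens:
        return [0.0] * EMBEDDING_DIM  # an empty sentence still yields a vector
    vectors = [toy_token_embedding(token) for token in tokens]
    return [sum(column) / len(vectors) for column in zip(*vectors)]


batch = ["great movie", "this was an absolutely terrible movie", ""]
embedded = [embed_sentence(sentence) for sentence in batch]
shape = (len(embedded), len(embedded[0]))
print(shape)  # (num_examples, embedding_dimension)
```

Because every sentence ends up as a fixed-length vector, the Dense layers that follow can use a fixed input width.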

Let's compile the model.

Loss function and optimizer

A model needs a loss function and an optimizer for training. Since this is a binary classification problem and the model outputs logits (a single-unit layer with a linear activation), you'll use the binary_crossentropy loss function.

This isn't the only choice for a loss function; you could, for instance, choose mean_squared_error. But, generally, binary_crossentropy is better for dealing with probabilities: it measures the "distance" between probability distributions, or in our case, between the ground-truth distribution and the predictions.
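
For intuition, binary cross-entropy can be computed for a single example directly from a logit. The sketch below is plain Python using a common numerically stable formulation; it is meant as an illustration of the quantity being minimized, not as the exact code Keras runs. The function names are made up for this example.

```python
import math


def bce_from_logit(logit, label):
    """Binary cross-entropy for one example, computed from a raw logit.

    Uses the stable form max(x, 0) - x * z + log(1 + exp(-|x|)),
    where x is the logit and z is the 0/1 label.
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def sigmoid(logit):
    """Convert a logit to a probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-logit))


# A logit of 0 corresponds to probability 0.5, so the loss is log(2) for either label.
print(bce_from_logit(0.0, 1), bce_from_logit(0.0, 0))

# A confident correct prediction gives a small loss; a confident wrong one gives a large loss.
print(bce_from_logit(4.0, 1), bce_from_logit(4.0, 0))
```

For moderate logits this agrees with the textbook expression -[z log p + (1 - z) log(1 - p)] with p = sigmoid(x), while avoiding the logarithm of a value that rounds to zero.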

Later, when you are exploring regression problems (say, to predict the price of a house), you'll see how to use another loss function called mean squared error.

Now, configure the model to use an optimizer and a loss function:

In [8]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

Train the model

Train the model for 10 epochs in mini-batches of 512 samples. This is 10 iterations over all samples in the train_data dataset. While training, monitor the model's loss and accuracy on the 10,000 samples from the validation set:

In [9]:
history = model.fit(train_data.shuffle(10000).batch(512),
                    epochs=10,
                    validation_data=validation_data.batch(512),
                    verbose=1)
Epoch 1/10
30/30 [==============================] - 10s 269ms/step - loss: 0.6507 - accuracy: 0.5580 - val_loss: 0.5876 - val_accuracy: 0.6587
Epoch 2/10
30/30 [==============================] - 7s 245ms/step - loss: 0.5165 - accuracy: 0.7282 - val_loss: 0.4772 - val_accuracy: 0.7757
Epoch 3/10
30/30 [==============================] - 7s 238ms/step - loss: 0.3840 - accuracy: 0.8373 - val_loss: 0.3908 - val_accuracy: 0.8278
Epoch 4/10
30/30 [==============================] - 7s 241ms/step - loss: 0.2784 - accuracy: 0.8949 - val_loss: 0.3411 - val_accuracy: 0.8506
Epoch 5/10
30/30 [==============================] - 7s 232ms/step - loss: 0.2025 - accuracy: 0.9307 - val_loss: 0.3171 - val_accuracy: 0.8595
Epoch 6/10
30/30 [==============================] - 7s 243ms/step - loss: 0.1475 - accuracy: 0.9554 - val_loss: 0.3085 - val_accuracy: 0.8628
Epoch 7/10
30/30 [==============================] - 7s 228ms/step - loss: 0.1073 - accuracy: 0.9711 - val_loss: 0.3094 - val_accuracy: 0.8671
Epoch 8/10
30/30 [==============================] - 6s 201ms/step - loss: 0.0781 - accuracy: 0.9821 - val_loss: 0.3156 - val_accuracy: 0.8672
Epoch 9/10
30/30 [==============================] - 7s 223ms/step - loss: 0.0566 - accuracy: 0.9901 - val_loss: 0.3236 - val_accuracy: 0.8648
Epoch 10/10
30/30 [==============================] - 6s 193ms/step - loss: 0.0411 - accuracy: 0.9948 - val_loss: 0.3349 - val_accuracy: 0.8636
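
The 30/30 in the progress bar is the number of batches per epoch. A short sketch in plain Python reproduces it, assuming the final, smaller batch is kept (the default behaviour of `batch`); `batches_per_epoch` is an illustrative helper name.

```python
import math

BATCH_SIZE = 512


def batches_per_epoch(num_examples, batch_size=BATCH_SIZE):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


print(batches_per_epoch(15_000))  # training split
print(batches_per_epoch(10_000))  # validation split
```

With 15,000 training examples this gives 30 steps, so each epoch performs 30 optimizer updates.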

Evaluate the model

Now see how the model performs on the test set. Two values are returned: loss (a number that represents the error; lower values are better) and accuracy.

In [10]:
results = model.evaluate(test_data.batch(512), verbose=2)

for name, value in zip(model.metrics_names, results):
  print("%s: %.3f" % (name, value))
49/49 - 1s - loss: 0.3629 - accuracy: 0.8525 - 1s/epoch - 30ms/step
loss: 0.363
accuracy: 0.852

This fairly naive approach achieves an accuracy of about 85%. With more advanced approaches, the model should get closer to 95%.
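
Since the model outputs logits rather than probabilities, turning an output into a class takes one extra step. A minimal sketch in plain Python: apply a sigmoid and threshold the probability at 0.5, which is equivalent to thresholding the logit at 0. The logits and labels below are hypothetical values, and this sketch does not claim to reproduce how the `accuracy` metric above is computed internally.

```python
import math


def sigmoid(logit):
    """Convert a logit to a probability of the positive class."""
    return 1.0 / (1.0 + math.exp(-logit))


def predict_label(logit):
    """Return 1 (positive) when the predicted probability exceeds 0.5, else 0."""
    return 1 if sigmoid(logit) > 0.5 else 0


def accuracy(logits, labels):
    """Fraction of examples whose predicted label matches the true label."""
    correct = sum(predict_label(x) == y for x, y in zip(logits, labels))
    return correct / len(labels)


toy_logits = [2.3, -1.1, 0.4, -0.2]  # hypothetical model outputs
toy_labels = [1, 0, 0, 0]            # hypothetical ground truth
print(accuracy(toy_logits, toy_labels))
```

Here the third example is misclassified, because its logit is positive while its label is 0.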

Further reading