Text Classification with Transformers

Classify Text into Sustainable Development Goals with DistilBERT

Samuel Ozechi
17 min read · Aug 12, 2024

Text classification is a basic natural language processing task that categorizes text data into predefined classes. The input is a piece of text (such as a sentence, paragraph, or document), and the output is a label or category that best represents the content of the text within the predefined categories. This is widely useful in various applications such as spam detection, sentiment analysis, topic labelling, etc.

Transformers perform well in text classification because they model context well. Their self-attention mechanism lets the model weigh the importance of each word in a sentence relative to every other word, which helps it capture long-range dependencies and leads to more accurate predictions. Models like BERT (Bidirectional Encoder Representations from Transformers) process text bidirectionally, meaning that they consider the context on both the left and right sides of a word, which further improves the model’s ability to grasp the nuanced meaning of words in different contexts. Pre-trained transformers can also handle large vocabularies and be further trained to understand and classify text across diverse domains and languages. Refer to my previous article for more details on the Transformer architecture and why it generally excels at natural language processing tasks.
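To make the idea of weighing each word against every other word concrete, here is a minimal sketch of scaled dot-product self-attention in plain Python. It is only an illustration: the 2-dimensional embeddings are made up, and the learned query, key and value projections that real Transformers use are deliberately left out.

```python
import math

def softmax(scores):
    """Turn raw scores into weights that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(vectors):
    """Scaled dot-product self-attention with identity projections.

    Real Transformers first map each vector to learned query, key and
    value vectors; that step is omitted to keep the sketch short.
    """
    dim = len(vectors[0])
    outputs, all_weights = [], []
    for query in vectors:
        # Similarity of this token to every token in the sequence
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in vectors]
        weights = softmax(scores)
        # Contextual vector: weighted average of all token vectors
        mixed = [sum(w * vec[d] for w, vec in zip(weights, vectors))
                 for d in range(dim)]
        outputs.append(mixed)
        all_weights.append(weights)
    return outputs, all_weights

# Hypothetical embeddings for a three-token sequence (not from any model)
tokens = ["river", "bank", "water"]
embeddings = [[1.0, 0.0], [0.6, 0.8], [0.9, 0.1]]

contextual, attention = self_attention(embeddings)
for token, weights in zip(tokens, attention):
    print(token, [round(w, 2) for w in weights])
```

Each row of weights sums to 1, so every output vector is a blend of the whole sequence. This is why the same word can end up with different representations in different sentences.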

In this article, we’ll tackle text classification using a variant of BERT called DistilBERT. This model performs comparably to BERT while being significantly smaller and more efficient, which allows us to achieve similar performance in a shorter training period.

The Problem

We tackle the challenge of categorizing input texts into one of 15 distinct categories, each corresponding to a specific Sustainable Development Goal (SDG). The aim is to accurately identify which SDG the text is related to or addresses, helping to align the content of the text with the global objectives, as outlined by the United Nations.

The Dataset

The Sustainable Development Goals dataset contains 32,121 unique texts, each with an SDG label (from 1 to 15) that represents the Sustainable Development Goal related to that text.

Loading the Data

Since the dataset is freely hosted on Kaggle, we could easily download it locally and load it into our notebook or use the Kaggle dataset API to download the dataset into our environment directly. We would load the data into a Pandas dataframe using the latter approach, but first, we import all the required libraries for use.

# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import datasets
from datasets import Dataset
from wordcloud import WordCloud, STOPWORDS, ImageColorGenerator
from umap import UMAP
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, accuracy_score, f1_score, classification_report
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
import torch
from torch.nn.functional import cross_entropy
from transformers import pipeline
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer, AutoModelForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
# Download dataset
!kaggle datasets download -d kartikbhatnagar18/sustainable-development-goals -f osdg-community-dataset-v21-09-30.csv

# View downloaded files
!ls -d $PWD/*

# Unzip files and delete initial zip file
!unzip \*.zip && rm *.zip
# Load dataset into pandas dataframe
data = pd.read_csv('osdg-community-dataset-v21-09-30.csv')

# Preview dataset
data.head()
Preview the dataset

Next, we keep only the columns relevant to our task: the input text and the respective SDG labels.

# Keep relevant columns
data = data[['text', 'sdg']]
data.head()
Preview of the dataset

Review Data Information

To have a better understanding of the data, we would look at vital information such as its dimensions, data types, sample input texts and distinct labels.

# Check the dimensions of the data
n_rows = data.shape[0]
print(f'There are {n_rows} samples in the dataset')

# Check vital data information
data.info()

# Show three randomly sampled texts
for i in range(3):
    print(data['text'].sample(1).values[0] + '\n')

We notice that the sample texts relate to different issues. The first text discusses various forms of support for biofuel use across different countries, the second text highlights the complexity of water governance and the third highlights the inadequacy of warehouse storage capacity for rice. Each of the texts relates to an SDG as defined by the United Nations. Let’s review the labels to see the specific SDGs that are captured in our dataset.

# Show the distinct classes
print(f'There are {data["sdg"].nunique()} distinct SDG categories in the data \n {np.sort(data["sdg"].unique())}')

We see that there are 15 distinct SDG categories in our dataset. We can retrieve the specific goals represented by each label as defined by the United Nations.

# Define the 15 SDGs in the data
sdg_goals = {
    1: "No Poverty",
    2: "Zero Hunger",
    3: "Good Health and Well-being",
    4: "Quality Education",
    5: "Gender Equality",
    6: "Clean Water and Sanitation",
    7: "Affordable and Clean Energy",
    8: "Decent Work and Economic Growth",
    9: "Industry, Innovation, and Infrastructure",
    10: "Reduced Inequality",
    11: "Sustainable Cities and Communities",
    12: "Responsible Consumption and Production",
    13: "Climate Action",
    14: "Life Below Water",
    15: "Life on Land"
}

Now that we have the classes for the labels in our data, we can carry out exploratory data analysis to identify patterns, and better understand the distribution and relationships within the data. First, we identify any data cleaning needs for our data.

# check for duplicates in the data
data.duplicated(subset=['text']).any()

# check for missing values in the data
data.isnull().any()

Thankfully, we have no duplicates to remove or missing data to deal with, so we can head right on to EDA.

Exploratory Data Analysis

Word clouds are useful for visualizing the most frequently occurring words in a text, allowing for a quick and intuitive understanding of key themes, topics, or sentiments in the data. Let’s visualize the word cloud for each class in the data to identify distinguishing words across different classes.

# Display wordclouds per category
stopwords = set(STOPWORDS) # Define stopwords

# Get unique categories
categories = data['sdg'].unique()

# Calculate number of rows and columns
rows = 5
cols = 3

# Create subplots
fig, axes = plt.subplots(rows, cols, figsize=(20, 20))

# Flatten axes for easy iteration
axes = axes.flatten()

# Iterate through each category and corresponding subplot axis
for idx, category in enumerate(categories):
    # Get the text for the current category
    category_text = data[data['sdg'] == category]['text']

    # Generate a word cloud for the current category
    wordcloud = WordCloud(stopwords=stopwords, background_color='white').generate(' '.join(category_text))

    # Display the word cloud on the corresponding subplot
    axes[idx].imshow(wordcloud, interpolation='bilinear')
    axes[idx].set_title(sdg_goals[category])
    axes[idx].axis('off')

    # Add a rectangular box around each subplot
    rect = patches.Rectangle((0, 0), 1, 1, linewidth=2, edgecolor='black', facecolor='none', transform=axes[idx].transAxes)
    axes[idx].add_patch(rect)

# Turn off any remaining empty subplots
for i in range(len(categories), len(axes)):
    axes[i].axis('off')

# Adjust layout
plt.tight_layout()
plt.show()
Word Cloud across different classes in the data

These word clouds provide a visual summary of the most significant words related to each SDG category. We see clear distinctions in the distinguishing words for each category which already provides a high level of intuition on the possible words that could define a particular category. We expect our model to identify the importance of these words in the input texts and use that relationship to identify the category that each text should belong to.

We also need to look at the class distribution in our dataset to identify possible data imbalance which can lead to biased models that perform poorly on the minority classes. This would also inform us of the possible training loss and evaluation metrics to use. For example, precision, recall and F1 scores would be preferable to evaluate model performance for imbalanced datasets, as against using just the accuracy score.
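As a small illustration of why accuracy alone can mislead on imbalanced data, the sketch below scores a classifier that always predicts the majority class. The labels are made up for the example and are not drawn from the SDG dataset.

```python
from collections import Counter

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def f1_per_class(y_true, y_pred):
    """F1 score for each class, computed from TP, FP and FN counts."""
    scores = {}
    for cls in sorted(set(y_true)):
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
        fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores[cls] = 2 * tp / denom if denom else 0.0
    return scores

# Made-up labels: 90 samples of class 0 and 10 of class 1
y_true = [0] * 90 + [1] * 10
majority = Counter(y_true).most_common(1)[0][0]
y_pred = [majority] * len(y_true)  # always predict the majority class

per_class = f1_per_class(y_true, y_pred)
macro_f1 = sum(per_class.values()) / len(per_class)
print(f"accuracy = {accuracy(y_true, y_pred):.2f}")
print(f"macro F1 = {macro_f1:.2f}")
```

The classifier never finds a single minority sample, yet its accuracy is 0.90. The macro-averaged F1 of about 0.47 reflects the failure on the minority class.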

# Show data distribution per category

class_counts = data['sdg'].value_counts() # Count the number of samples per class

# Replace the index values with the exact SDG goals
class_counts.index = class_counts.index.map(sdg_goals)

# Create the horizontal bar chart
plt.figure(figsize=(10, 8))
plt.barh(class_counts.index, class_counts.values, color='skyblue')
plt.xlabel('Number of Samples')
plt.ylabel('SDG Goals')
plt.title('Number of Samples per SDG Goal')
plt.gca().invert_yaxis() # Invert y-axis to have the highest count on top
plt.show()
Data Distribution per Class

We see that the dataset is imbalanced; the Gender Equality and Quality Education classes appear more frequently compared to Responsible Consumption and Life on Land which are 5–10 times rarer. We can mitigate this data imbalance through random oversampling of minority classes, random undersampling of majority classes or gathering more labelled data for the underrepresented classes.
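Random oversampling, the first of these options, can be sketched in a few lines. This is a minimal plain-Python illustration on toy data, with a hypothetical helper name (`random_oversample`); it is not the method used in the rest of this article.

```python
import random
from collections import Counter

def random_oversample(texts, labels, seed=42):
    """Duplicate randomly chosen minority samples until every class
    has as many samples as the largest class."""
    rng = random.Random(seed)
    by_class = {}
    for text, label in zip(texts, labels):
        by_class.setdefault(label, []).append(text)

    target = max(len(items) for items in by_class.values())
    out_texts, out_labels = [], []
    for label, items in by_class.items():
        # Draw extra samples (with replacement) for under-represented classes
        extra = [rng.choice(items) for _ in range(target - len(items))]
        for text in items + extra:
            out_texts.append(text)
            out_labels.append(label)
    return out_texts, out_labels

# Toy data: class 5 is frequent, classes 12 and 15 are rare
texts = [f"text {i}" for i in range(8)]
labels = [5, 5, 5, 5, 5, 12, 12, 15]

balanced_texts, balanced_labels = random_oversample(texts, labels)
print(Counter(balanced_labels))
```

If used, oversampling would normally be applied to the training split only, after the train/validation split, so that duplicated texts do not leak into the validation set.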

Since we have at least 500 samples for each class, we would keep things as they are. That should be enough samples for our pre-trained transformer to identify the important features and relationships in the dataset. It is highly recommended to deal with data imbalance issues in the data, especially for real-world applications.

Before we start preprocessing our dataset, let’s explore the number of words per text in the dataset. Perhaps certain categories require more words than others. This also gives us a rough idea of how long the tokenized sequences will be during preprocessing.

# check distribution of text length
df = data.copy()
df["Words Per Text"] = df["text"].str.split().apply(len)
df.boxplot("Words Per Text", by="sdg", grid=False,
           showfliers=False, color="black")
plt.suptitle("")
plt.xlabel("Sdg Goals")
plt.ylabel("Words Per Text")
plt.show()
Distribution of Words per text

We see that most of the texts are between 70 and 120 words long, although some are as short as 18 words and others as long as 170. Transformer models have a maximum input sequence length, referred to as the maximum context size. For the DistilBERT model, the maximum context size is 512 tokens, which seems sufficient for texts of this length. For problems with longer texts, the input would be truncated to the maximum context size, which can lead to information loss. In such cases, models with longer context sizes, such as Pegasus and GPTs, could be preferable.

Data Preprocessing (Tokenization)

Models cannot take raw text as input; the text must first be converted into numbers. The first step is tokenization, the process of splitting text into smaller units known as tokens and mapping each token to an integer ID. A simple word tokenization, for instance, would involve splitting the input text into individual words and mapping each word to an integer, so that the initial text becomes a sequence of numerical values.
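The simple word tokenization described above can be sketched as follows. This is an illustration only: the corpus is made up, the helper names are hypothetical, and DistilBERT itself uses a subword (WordPiece) tokenizer rather than whole words.

```python
def build_vocab(texts):
    """Assign an integer ID to every distinct lower-cased word."""
    vocab = {"[UNK]": 0}  # ID reserved for words not seen when building the vocab
    for text in texts:
        for word in text.lower().split():
            vocab.setdefault(word, len(vocab))
    return vocab

def encode(text, vocab):
    """Convert a text into a list of integer IDs."""
    return [vocab.get(word, vocab["[UNK]"]) for word in text.lower().split()]

corpus = ["clean water for all", "clean energy for all"]
vocab = build_vocab(corpus)
print(vocab)
print(encode("clean water and energy", vocab))
```

The word "and" never appeared in the toy corpus, so it maps to the `[UNK]` ID. Subword tokenizers reduce this problem by breaking unknown words into smaller known pieces.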

The Transformers package provides a convenient AutoTokenizer class that allows us to quickly load the tokenizer associated with a pre-trained model of choice by providing the model ID. Let’s start by splitting the dataset into train and validation sets and loading the tokenizer for DistilBERT.

# Split data into train and validation sets
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42, stratify=data['sdg'])
train_data.shape, val_data.shape

# Convert DataFrame to Dataset type
train_ds = Dataset.from_pandas(train_data)
val_ds = Dataset.from_pandas(val_data)
# Define tokenizer
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

def tokenize(batch):
    return tokenizer(batch["text"], padding=True, truncation=True)

Notice that our tokenize function takes a batch of our dataset and returns the batch with the tokenized text. We set padding to True to pad examples that are shorter than the longest input in the batch with zeros, so that all inputs in a batch have equal length, and truncation to True to truncate inputs to the model’s maximum context size. We can now apply the tokenizer to our train and validation datasets.

# Tokenize text inputs
train_encoded = train_ds.map(tokenize, batched=True, batch_size=None)
val_encoded = val_ds.map(tokenize, batched=True, batch_size=None)

# Show details of tokenized data
train_encoded.features

The tokenized inputs are referenced by the input_ids key. Notice that the tokenizer also returns an attention_mask sequence. The attention mask allows the model to ignore the padded parts of the input and only pay attention to the encoded text part of the inputs.
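To show how padding, truncation and the attention mask fit together, here is a minimal plain-Python sketch. The token IDs and the `pad_and_truncate` helper are hypothetical, and the sketch simply cuts off the end of long sequences, whereas a real tokenizer also takes care of special tokens when truncating.

```python
def pad_and_truncate(sequences, max_length, pad_id=0):
    """Truncate each sequence to max_length, then pad every sequence
    to the length of the longest remaining one."""
    truncated = [seq[:max_length] for seq in sequences]
    longest = max(len(seq) for seq in truncated)
    input_ids, attention_mask = [], []
    for seq in truncated:
        n_pad = longest - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Made-up token IDs for three texts of different lengths
batch = [[101, 7, 8, 102], [101, 7, 8, 9, 10, 11, 102], [101, 102]]
encoded = pad_and_truncate(batch, max_length=5)
for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"]):
    print(ids, mask)
```

Every row ends up with the same length, and the mask records which positions hold real tokens.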

Before we go ahead with model training, we need to shift the class labels to the 0 to n-1 range expected by most machine learning algorithms, as against the 1 to 15 range that we currently have. We also set the inputs to the tensor format that our model expects.

# Create labels in the 0 to n-1 class range
train_encoded = train_encoded.map(lambda x: {'label': [label - 1 for label in x['sdg']]}, batched=True)
val_encoded = val_encoded.map(lambda x: {'label': [label - 1 for label in x['sdg']]}, batched=True)

# Set inputs to tensor format
train_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])
val_encoded.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Remove sdg from both datasets
train_encoded = train_encoded.remove_columns(["sdg"])
val_encoded = val_encoded.remove_columns(["sdg"])

Model Training

To train a text classifier, the Transformer model converts the token encodings into token embeddings which are dense, continuous vectors of fixed dimensions. These embeddings capture the syntactic and semantic meaning of tokens in a way that the Transformer model can process them effectively. The token embeddings are then passed through the encoder block layers of the transformer model to yield a hidden state for each input token. Unlike the initial token embeddings, which are primarily based on the token’s isolated meaning, the hidden state is contextually enriched. It incorporates information from all other tokens in the input sequence to better understand the context, using the self-attention mechanism. For example, in a sentence, the hidden state for the word “bank” will differ depending on whether the context is financial (“the bank approved the loan”) or geographical (“the bank of the river”).

This workflow therefore provides us with two main options to train our classifier:

  1. Use the hidden states as features to train a classifier.
  2. Train the pre-trained transformer model end-to-end.

We would utilize both options and compare the model outputs.

Feature Extraction: Train a classifier on the extracted hidden states

We would use the transformer model to extract the features (hidden states) of the text encodings and use the features to train a shallow learning algorithm. This feature-based approach can be a good option when GPU is not available.

# Load pretrained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_ckpt).to(device)

# Extract hidden state of inputs
def extract_hidden_states(batch):
    # Place model inputs on the same device as the model
    inputs = {k: v.to(device) for k, v in batch.items()
              if k in tokenizer.model_input_names}
    # Extract last hidden states
    with torch.no_grad():
        last_hidden_state = model(**inputs).last_hidden_state
    # Return vector for [CLS] token
    return {"hidden_state": last_hidden_state[:, 0].cpu().numpy()}

We load the pre-trained model with the AutoModel class, which converts the token encodings to embeddings and feeds them through the encoder layers to return the hidden states. We also check whether a GPU is available; if not, the model runs on the CPU. The extract_hidden_states function returns the last hidden state of the [CLS] token, the first token of every sequence, which serves as a representation of the whole text. Now we can use the function to extract the hidden states.

# # Extract hidden state of inputs
# # Use this if you have enough computation resources
# train_hidden = train_encoded.map(extract_hidden_states, batched=True)
# val_hidden = val_encoded.map(extract_hidden_states, batched=True)

Extracting all the hidden states in one go can easily exhaust memory because of their size. Alternatively, we can extract the hidden states in batches and clear the cached data after each batch to make better use of resources.

# Reduce batch size to manage memory
batch_size = 8

# Process the dataset in smaller batches and clear CUDA cache
train_hidden = []
for i in range(0, len(train_encoded), batch_size):
    batch = train_encoded.select(range(i, min(i + batch_size, len(train_encoded))))
    train_hidden.append(batch.map(extract_hidden_states, batched=True))
    torch.cuda.empty_cache()  # Clear cache after each batch

train_hidden = datasets.concatenate_datasets(train_hidden)

# Repeat the process for validation set
val_hidden = []
for i in range(0, len(val_encoded), batch_size):
    batch = val_encoded.select(range(i, min(i + batch_size, len(val_encoded))))
    val_hidden.append(batch.map(extract_hidden_states, batched=True))
    torch.cuda.empty_cache()  # Clear cache after each batch

val_hidden = datasets.concatenate_datasets(val_hidden)

We can then define our train and validation features (hidden state) and labels to train our model.

# Get train and validation inputs and labels
X_train = np.array(train_hidden["hidden_state"])
X_valid = np.array(val_hidden["hidden_state"])
y_train = np.array(train_hidden["label"])
y_valid = np.array(val_hidden["label"])

We can visualize the hidden state per category to see distinctions in the different classes in our dataset that we aim for our model to capture.

# Visualizing the features in 2D
X_scaled = MinMaxScaler().fit_transform(X_train) # Scale features to [0,1] range
mapper = UMAP(n_components=2, metric="cosine").fit(X_scaled) # Initialize and fit UMAP
df_emb = pd.DataFrame(mapper.embedding_, columns=["X1", "X2"]) # Create a DataFrame of 2D embeddings
df_emb["label"] = y_train

fig, axes = plt.subplots(5, 3, figsize=(12,9))
axes = axes.flatten()
cmaps = ["Blues", "Oranges", "Reds", "Purples", "Greens"] * 3
labels = sdg_goals.values()
for i, (label, cmap) in enumerate(zip(labels, cmaps)):
    df_emb_sub = df_emb.query(f"label == {i}")
    axes[i].hexbin(df_emb_sub["X1"], df_emb_sub["X2"], cmap=cmap,
                   gridsize=20, linewidths=(0,))
    axes[i].set_title(label)
    axes[i].set_xticks([]), axes[i].set_yticks([])
plt.tight_layout()
plt.show()
Visualizing the hidden state per category

The visualization shows the different patterns for each category. Notice that similar themes occupy similar positions. For example, No Poverty and Zero Hunger occupy the left region, Sustainable Cities and Climate Action occupy the centre while Affordable and Clean Energy and Clean Water and Sanitation occupy a similar region. We might expect the model to confuse these similar SDGs with one another.

Now to the model training proper, let’s first develop a dummy model to establish a baseline for performance.

# Evaluate a baseline model
dummy_clf = DummyClassifier(strategy="prior")
dummy_clf.fit(X_train, y_train)
baseline_score = dummy_clf.score(X_valid, y_valid)
print(f"Baseline Score: {baseline_score}")

Then we train a Logistic Regression model on the features.

# Train and evaluate a logistic model
lr_clf = LogisticRegression(max_iter=3000) # We increase `max_iter` to give the solver enough iterations to converge
lr_clf.fit(X_train, y_train) # Train the classifier
lr_score = lr_clf.score(X_valid, y_valid) # Evaluate the classifier

print(f"Logistic Regression Score: {lr_score}")

Using a basic Logistic Regression model already gives us an 80% accuracy. Since accuracy wouldn’t be the most appropriate metric for evaluating the model’s performance due to the class imbalance, let’s look at the confusion matrix to identify the level of misclassification by the model.

# Function to plot confusion matrix
def plot_confusion_matrix(y_valid, y_predicted):
    plt.figure(figsize=(15, 6))
    cm = confusion_matrix(y_valid, y_predicted)
    sns.heatmap(cm, annot=True, fmt='d')
    plt.xlabel('Predicted')
    plt.ylabel('Truth')
    plt.title('Confusion Matrix')
    sdg_labels = np.sort(data["sdg"].unique())    # SDG numbers 1-15 shown on the axes
    positions = np.arange(len(sdg_labels)) + 0.5  # centre of each heatmap cell
    plt.xticks(positions, sdg_labels)
    plt.yticks(positions, sdg_labels)
    plt.show()
# display confusion matrix
y_predicted = lr_clf.predict(X_valid)
plot_confusion_matrix(y_valid, y_predicted)

Despite the model’s decent performance, some classes are frequently confused with others. For example, class 6 has 1,849 correct predictions but many instances are misclassified as class 7 (94 times). This suggests that these two classes might share similar features, making them harder to distinguish.

Also, classes with fewer samples (for example, classes 7–10) have fewer correct predictions and more spread in misclassifications, indicating that the model struggles more with these classes. Let’s use alternative metrics such as precision, recall, and the F1-score to ascertain the model’s performance across the classes.

# display classification report
print('Classification Report of Logistic Regression Model \n',
classification_report(y_valid, y_predicted, target_names=[str(num) for num in range(1,16)]
))

While the model demonstrates strong overall accuracy, it has some challenges with class-specific performance, particularly in distinguishing between similar classes or those with fewer examples. This is a common issue in multi-class classification and may require hyperparameter optimization, balancing the dataset, or using more advanced algorithms such as ensemble methods or neural networks to improve performance.
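One way to pin down which similar classes are being mixed up is to list the most frequent wrong (true, predicted) pairs, which are the largest off-diagonal cells of the confusion matrix. The sketch below does this in plain Python; the labels are made up and `top_confusions` is a hypothetical helper, not part of scikit-learn.

```python
from collections import Counter

def top_confusions(y_true, y_pred, k=3):
    """Return the k most frequent (true, predicted) pairs where the
    prediction was wrong."""
    errors = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return errors.most_common(k)

# Made-up labels for illustration
y_true = [6, 6, 6, 6, 7, 7, 7, 1, 1, 2]
y_pred = [6, 7, 7, 6, 7, 6, 7, 1, 2, 2]

for (true_cls, pred_cls), count in top_confusions(y_true, y_pred):
    print(f"true {true_cls} -> predicted {pred_cls}: {count}")
```

Applied to the arrays in this article, a call such as `top_confusions(y_valid, y_predicted)` would return 0-based class indices, so 1 would need to be added to recover the SDG numbers.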

Now, let’s explore fine-tuning the pre-trained model, which generally leads to superior classification performance.

Fine-tuning: Train the Pre-trained Transformer

Fine-tuning the pre-trained transformer for text classification requires a classification head on top of the pre-trained model outputs. The AutoModelForSequenceClassification class adds this head for us; we only need to specify the number of labels the model has to predict.

# Define pretrained model to fine-tune
num_labels = data["sdg"].nunique()  # number of distinct SDG classes
model = (AutoModelForSequenceClassification
         .from_pretrained(model_ckpt, num_labels=num_labels)
         .to(device))

To monitor the model’s performance during training, we need to define a function to compute the performance metrics of choice. For our problem, we’ll compute the F1-score and the accuracy of the model at the end of each epoch.

# Define Performance metrics
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1}

Next, we define the training arguments, which control how the model is trained and evaluated. We can easily define these using the TrainingArguments class.

# Define Training Arguments
batch_size = 8
logging_steps = len(train_encoded) // batch_size
model_name = f"{model_ckpt}-sdg-finetuned"
training_args = TrainingArguments(output_dir=model_name,
                                  num_train_epochs=5,
                                  learning_rate=2e-5,
                                  per_device_train_batch_size=batch_size,
                                  per_device_eval_batch_size=batch_size,
                                  weight_decay=0.01,
                                  evaluation_strategy="epoch",
                                  disable_tqdm=False,
                                  logging_steps=logging_steps,
                                  push_to_hub=False,
                                  log_level="error",
                                  report_to=[])

Notice that the push_to_hub parameter is set to False because we plan to save our model locally for inference. Alternatively, we can push it to the Hugging Face model hub to easily call the model as an API for inference.
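As a rough back-of-the-envelope check of what these arguments imply, the sketch below estimates the number of optimization steps. It assumes the 32,121 samples quoted earlier, a validation split that rounds up, a single device and no gradient accumulation; actual numbers may differ in your environment.

```python
import math

n_samples = 32_121   # dataset size quoted earlier in the article
test_size = 0.2      # same fraction as the train/validation split above
batch_size = 8       # per-device batch size from the training arguments
num_epochs = 5

# Assumption: the validation size is rounded up and the rest goes to training
n_val = math.ceil(n_samples * test_size)
n_train = n_samples - n_val

steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * num_epochs

print(f"training samples: {n_train}")
print(f"steps per epoch:  {steps_per_epoch}")
print(f"total steps:      {total_steps}")
```

Under these assumptions, `logging_steps = len(train_encoded) // batch_size` works out to the number of steps in one epoch, so the training loss would be logged once per epoch, alongside the per-epoch evaluation.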

Finally, we are ready to define and train our pre-trained model end-to-end.

# Define Trainer 
trainer = Trainer(model=model, args=training_args,
                  compute_metrics=compute_metrics,
                  train_dataset=train_encoded,
                  eval_dataset=val_encoded,
                  tokenizer=tokenizer)

# Train finetuned model
trainer.train();

The model shows consistent improvement across epochs, with validation loss decreasing and both accuracy and F1-score increasing. The performance also stabilizes by epoch 4, with higher accuracy and F1-scores than the Logistic Regression model, indicating that fine-tuning yields a better classifier than the feature-based approach. Again, we look at the confusion matrix to see the level of misclassification.

# Get predicted outputs 
preds_output = trainer.predict(val_encoded)

# Show performance on validation set
preds_output.metrics

# Get labels of predicted outputs
y_preds = np.argmax(preds_output.predictions, axis=1)

# Plot confusion matrix of finetuned model
plot_confusion_matrix(y_valid, y_preds)

This shows a much-improved model, especially for the classes with fewer samples. We can even train the model for more epochs to see how improved the model could get. As previously, let’s also look at the model’s performance across the classes.

# display classification report
print('Classification Report of finetuned Distilbert model \n',
classification_report(y_valid, y_preds, target_names=[str(num) for num in range(1,16)]
))
Classification report for Finetuned DistilBERT model

The model generally improved performance across the classes. For instance, classes 8 and 10 have improved from 0.56 F1 scores to 0.87.

Finally, let’s see how we can save the model and use it for inference.

Saving the Model and Inference

After training, we can go ahead and save the model locally as intended.

# Save the model
trainer.save_model("distilbert-sdg-finetuned")

# Save the tokenizer
tokenizer.save_pretrained("distilbert-sdg-finetuned")

To use the model, we simply load the model and tokenizer using the DistilBertForSequenceClassification and DistilBertTokenizer classes respectively and initialize a pipeline for predictions.

# Load the model and tokenizer from the local directory
model_path = "distilbert-sdg-finetuned"
model = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)

# Initialize the pipeline
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

Now let’s test our model on a sample text.

# Alternatively, load the pipeline directly from the locally saved directory
# pipe = pipeline("text-classification", model=model_path, tokenizer=model_path, device=device)

sample_text = ''' In many parts of the world, millions of people still lack access to basic healthcare services.
Children and adults suffer from preventable diseases due to insufficient medical resources and
inadequate infrastructure. By focusing on improving healthcare systems, we can significantly reduce
mortality rates and enhance the quality of life for countless individuals. Investing in health
education and promoting healthy lifestyles are crucial steps towards achieving a healthier future.
Collaborative efforts among governments, NGOs, and communities are essential to ensure that everyone,
regardless of their socio-economic status, has the opportunity to live a long and healthy life.
'''
print("Input text:")
print(sample_text)
print("\nModel Prediction:")
prediction = pipe(sample_text,)[0]
label_ = prediction['label']
score = prediction['score']
label_ = int(label_.split('_')[-1])
sdg_goal = list(sdg_goals.values())[label_]
print(sdg_goal)

print("\nScore: ")
print(score)
Sample Prediction using our trained DistilBERT model

The model accurately predicts the category for our sample text with very high confidence. Let’s also try the model on a sample text that shares similarities with other categories.

sample_text = ''' Millions in third-world countries face extreme poverty, lacking access to basic necessities like food,
clean water, and education. Urgent efforts are needed to address these challenges and improve their
living conditions.
'''
print("Input text:")
print(sample_text)
print("\nModel Prediction:")
prediction = pipe(sample_text,)[0]
label_ = prediction['label']
score = prediction['score']
label_ = int(label_.split('_')[-1])
sdg_goal = list(sdg_goals.values())[label_]
print(sdg_goal)

print("\nScore: ")
print(score)

The model accurately identifies that the text relates more to the No Poverty category even though it contains terms that relate to the Zero Hunger, Quality Education and Clean Water and Sanitation categories.
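The label-parsing steps repeated in both examples could be wrapped in a small helper. The sketch below is one possible way to do it, assuming the pipeline returns default label names of the form `LABEL_<index>`; the `label_to_sdg` name and the example prediction are hypothetical and not actual model output.

```python
SDG_GOALS = {
    1: "No Poverty", 2: "Zero Hunger", 3: "Good Health and Well-being",
    4: "Quality Education", 5: "Gender Equality",
    6: "Clean Water and Sanitation", 7: "Affordable and Clean Energy",
    8: "Decent Work and Economic Growth",
    9: "Industry, Innovation, and Infrastructure",
    10: "Reduced Inequality", 11: "Sustainable Cities and Communities",
    12: "Responsible Consumption and Production", 13: "Climate Action",
    14: "Life Below Water", 15: "Life on Land",
}

def label_to_sdg(prediction, goals=SDG_GOALS):
    """Map a pipeline result such as {'label': 'LABEL_2', 'score': 0.98}
    to (SDG number, SDG name, score)."""
    class_index = int(prediction["label"].split("_")[-1])  # 0-based model class
    sdg_number = class_index + 1                            # back to the 1-15 range
    return sdg_number, goals[sdg_number], prediction["score"]

# Hypothetical pipeline output, written by hand for illustration
example = {"label": "LABEL_2", "score": 0.98}
print(label_to_sdg(example))
```

With the pipeline defined above, this could be called as `label_to_sdg(pipe(sample_text)[0])`.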

Conclusion

We have seen how to use transformer models for text classification through two approaches: feature extraction and fine-tuning a pre-trained model. While fine-tuning provides better performance, the feature extraction approach is useful when fewer computational resources are available. You can find the entire code in this GitHub repository or run the Kaggle notebook.
