I had the same issue as @alex above with the error:
TypeError: __init__() missing 1 required positional argument: 'config'
Below is the code that worked for me to train a SegformerForSemanticSegmentation model and then run prediction with it. Note that for my purposes I needed to set a very large input size on the image processor; this can be greatly reduced depending on your use case:
feature_extractor.size = 2**10  # 1024; the processor default is 512
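As an aside, that TypeError usually means the model class was instantiated directly, which requires a SegformerConfig; here is a minimal sketch of the two working alternatives (checkpoint name taken from the code further down):
from transformers import SegformerConfig, SegformerForSemanticSegmentation
# calling SegformerForSemanticSegmentation() with no arguments raises the
# TypeError above; either build a config explicitly (trains from scratch)...
model = SegformerForSemanticSegmentation(SegformerConfig(num_labels=2))
# ...or load pretrained weights, as the training code below does:
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-cityscapes-1024-1024",
    num_labels=2,
    ignore_mismatched_sizes=True,
)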
TRAINING
import os
from pathlib import Path
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
from datasets import load_metric  # deprecated upstream; newer setups use evaluate.load("mean_iou")
from matplotlib import pyplot as plt
from PIL import Image
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
class SemanticSegmentationDataset(Dataset):
"""Image (semantic) segmentation dataset."""
def __init__(self, root_dir, feature_extractor):
self.root_dir = root_dir
self.feature_extractor = feature_extractor
        # collect full paths to the masks and derive each image path from its mask
        self.masks = sorted(str(x) for x in Path(root_dir).glob("*_mask.png"))
        self.images = [x.replace("_mask.png", ".png") for x in self.masks]
self.id2label = {0: "background", 1: "mask"}
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
        # the glob in __init__ already returns paths that include root_dir
        image = Image.open(self.images[idx])
        segmentation_map = Image.open(self.masks[idx])
encoded_inputs = self.feature_extractor(
image, segmentation_map, return_tensors="pt"
)
        for v in encoded_inputs.values():
            v.squeeze_()  # drop the batch dimension added by the processor
return encoded_inputs
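# optional sanity check (path is a placeholder): each item is a dict whose
# "pixel_values" has shape (3, H, W) and "labels" shape (H, W), with H and W
# determined by the processor's `size` setting
# sample = SemanticSegmentationDataset("<path>", feature_extractor)[0]
# print(sample["pixel_values"].shape, sample["labels"].shape)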
class SegformerFinetuner(pl.LightningModule):
    def __init__(
        self,
        id2label,
        base_model,
        train_dataloader=None,
        val_dataloader=None,
        test_dataloader=None,
        metrics_interval=100,
    ):
        super().__init__()
self.id2label = id2label
self.metrics_interval = metrics_interval
self.train_dl = train_dataloader
self.val_dl = val_dataloader
self.test_dl = test_dataloader
        self.num_classes = len(id2label)
self.label2id = {v: k for k, v in self.id2label.items()}
self.model = SegformerForSemanticSegmentation.from_pretrained(
base_model,
return_dict=False,
num_labels=self.num_classes,
id2label=self.id2label,
label2id=self.label2id,
ignore_mismatched_sizes=True,
)
self.train_mean_iou = load_metric("mean_iou")
self.val_mean_iou = load_metric("mean_iou")
self.test_mean_iou = load_metric("mean_iou")
def forward(self, images, masks):
outputs = self.model(pixel_values=images, labels=masks)
return outputs
def training_step(self, batch, batch_nb):
images, masks = batch["pixel_values"], batch["labels"]
outputs = self(images, masks)
loss, logits = outputs[0], outputs[1]
upsampled_logits = nn.functional.interpolate(
logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
)
predicted = upsampled_logits.argmax(dim=1)
self.train_mean_iou.add_batch(
predictions=predicted.detach().cpu().numpy(),
references=masks.detach().cpu().numpy(),
)
if batch_nb % self.metrics_interval == 0:
metrics = self.train_mean_iou.compute(
num_labels=self.num_classes,
ignore_index=255,
reduce_labels=False,
)
metrics = {
"loss": loss,
"mean_iou": metrics["mean_iou"],
"mean_accuracy": metrics["mean_accuracy"],
}
for k, v in metrics.items():
self.log(k, v)
return metrics
else:
return {"loss": loss}
def validation_step(self, batch, batch_nb):
images, masks = batch["pixel_values"], batch["labels"]
outputs = self(images, masks)
loss, logits = outputs[0], outputs[1]
upsampled_logits = nn.functional.interpolate(
logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
)
predicted = upsampled_logits.argmax(dim=1)
self.val_mean_iou.add_batch(
predictions=predicted.detach().cpu().numpy(),
references=masks.detach().cpu().numpy(),
)
return {"val_loss": loss}
def validation_epoch_end(self, outputs):
metrics = self.val_mean_iou.compute(
num_labels=self.num_classes,
ignore_index=255,
reduce_labels=False,
)
avg_val_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
val_mean_iou = metrics["mean_iou"]
val_mean_accuracy = metrics["mean_accuracy"]
metrics = {
"val_loss": avg_val_loss,
"val_mean_iou": val_mean_iou,
"val_mean_accuracy": val_mean_accuracy,
}
for k, v in metrics.items():
self.log(k, v)
return metrics
def test_step(self, batch, batch_nb):
images, masks = batch["pixel_values"], batch["labels"]
outputs = self(images, masks)
loss, logits = outputs[0], outputs[1]
upsampled_logits = nn.functional.interpolate(
logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
)
predicted = upsampled_logits.argmax(dim=1)
self.test_mean_iou.add_batch(
predictions=predicted.detach().cpu().numpy(),
references=masks.detach().cpu().numpy(),
)
return {"test_loss": loss}
def test_epoch_end(self, outputs):
metrics = self.test_mean_iou.compute(
num_labels=self.num_classes,
ignore_index=255,
reduce_labels=False,
)
avg_test_loss = torch.stack([x["test_loss"] for x in outputs]).mean()
test_mean_iou = metrics["mean_iou"]
test_mean_accuracy = metrics["mean_accuracy"]
metrics = {
"test_loss": avg_test_loss,
"test_mean_iou": test_mean_iou,
"test_mean_accuracy": test_mean_accuracy,
}
for k, v in metrics.items():
self.log(k, v)
return metrics
def configure_optimizers(self):
return torch.optim.Adam(
[p for p in self.parameters() if p.requires_grad], lr=2e-05, eps=1e-08
)
def train_dataloader(self):
return self.train_dl
def val_dataloader(self):
return self.val_dl
def test_dataloader(self):
return self.test_dl
def plot_predictions(model, dataloader):
    """Plot each predicted mask stacked above its ground-truth mask."""
    device = next(model.parameters()).device
    model.eval()
    predicted_lst = []
    truth_lst = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            images = batch["pixel_values"].to(device)
            masks = batch["labels"].to(device)
            outputs = model(pixel_values=images, labels=masks)
            logits = outputs[1]
            upsampled_logits = nn.functional.interpolate(
                logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
            )
            predicted_lst.append(upsampled_logits.argmax(dim=1).cpu().numpy())
            truth_lst.append(masks.cpu().numpy())
    for p, t in zip(predicted_lst, truth_lst):
        # cast to uint8 for cv2 and resize back to my original image dimensions;
        # nearest-neighbour interpolation keeps the label values discrete
        a = cv2.resize(p[0].astype(np.uint8), dsize=(4040, 1533), interpolation=cv2.INTER_NEAREST)
        b = cv2.resize(t[0].astype(np.uint8), dsize=(4040, 1533), interpolation=cv2.INTER_NEAREST)
        plt.imshow(np.vstack([a, b]), cmap="gray")
        plt.show()
for base_model in [
# "nvidia/segformer-b0-finetuned-ade-512-512",
# "nvidia/segformer-b3-finetuned-cityscapes-1024-1024",
"nvidia/segformer-b2-finetuned-cityscapes-1024-1024",
]:
feature_extractor = SegformerImageProcessor.from_pretrained(base_model)
feature_extractor.do_reduce_labels = False
    feature_extractor.size = 2**10  # resize inputs to 1024x1024; the default is 512
    # NOTE: the directory below holds base images named blah.png alongside
    # their corresponding masks named blah_mask.png
pth = "<path to base images and masks>"
    # all three splits point at the same directory here for simplicity;
    # use separate train/val/test splits for a meaningful evaluation
    train_dataset = SemanticSegmentationDataset(pth, feature_extractor)
    val_dataset = SemanticSegmentationDataset(pth, feature_extractor)
    test_dataset = SemanticSegmentationDataset(pth, feature_extractor)
batch_size = 1
train_dataloader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=8,
prefetch_factor=8,
)
val_dataloader = DataLoader(
val_dataset, batch_size=batch_size, num_workers=3, prefetch_factor=8
)
test_dataloader = DataLoader(
test_dataset, batch_size=batch_size, num_workers=3, prefetch_factor=8
)
    segformer_finetuner = SegformerFinetuner(
        train_dataset.id2label,
        base_model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        metrics_interval=10,
    )
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=5,
verbose=False,
mode="min",
)
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")
    # pytorch_lightning 1.x API (gpus=..., *_epoch_end hooks); in 2.x this
    # becomes accelerator="gpu", devices=1 with the on_*_epoch_end hooks
    trainer = pl.Trainer(
        gpus=1,
        # callbacks=[early_stop_callback, checkpoint_callback],
        callbacks=[checkpoint_callback],
        max_epochs=200,
        val_check_interval=len(train_dataloader),
    )
trainer.fit(segformer_finetuner)
res = trainer.test(ckpt_path="best")
print(base_model)
    plot_predictions(segformer_finetuner.model, train_dataloader)
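    # optional: also export in Hugging Face format so inference can use
    # from_pretrained directly instead of remapping the Lightning checkpoint
    # ("<output dir>" is a placeholder):
    # segformer_finetuner.model.save_pretrained("<output dir>")
    # feature_extractor.save_pretrained("<output dir>")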
LOAD FROM CHECKPOINT AND PREDICT
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import torch
from PIL import Image
from torch import nn
from matplotlib import pyplot as plt
from pathlib import Path
import numpy as np
import cv2
from tqdm import tqdm
def load_model_from_trainer_checkpoint(
base_model="nvidia/segformer-b2-finetuned-cityscapes-1024-1024",
checkpoint="<path to checkpoint>.ckpt",
):
id2label = {0: "background", 1: "mask"}
label2id = {v: k for k, v in id2label.items()}
model = SegformerForSemanticSegmentation.from_pretrained(
base_model,
return_dict=False,
num_labels=2,
id2label=id2label,
label2id=label2id,
ignore_mismatched_sizes=True,
)
feature_extractor = SegformerImageProcessor.from_pretrained(base_model)
feature_extractor.do_reduce_labels = False
    feature_extractor.size = 2**10  # must match the size used during training
    checkpoint_dict = torch.load(checkpoint, map_location="cpu")["state_dict"]
    # the Lightning module stored the HF model under the "model." prefix;
    # strip only that leading prefix (str.removeprefix needs Python 3.9+)
    checkpoint_dict = {
        k.removeprefix("model."): v for k, v in checkpoint_dict.items()
    }
    model.load_state_dict(checkpoint_dict)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
return model, feature_extractor
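# if the model was exported with save_pretrained during training instead, the
# whole remapping above reduces to ("<output dir>" is a placeholder):
# model = SegformerForSemanticSegmentation.from_pretrained("<output dir>")
# feature_extractor = SegformerImageProcessor.from_pretrained("<output dir>")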
def predict_mask(model, feature_extractor, img_fname, plot_img=True):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = Image.open(img_fname).convert("RGB")  # force 3 channels so the vstack below works
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
with torch.no_grad():
outputs = model(
pixel_values
) # logits are of shape (batch_size, num_labels, height/4, width/4)
# First, rescale logits to original image size
upsampled_logits = nn.functional.interpolate(
outputs[0],
size=image.size[::-1], # (height, width)
mode="bilinear",
align_corners=False,
)
# Second, apply argmax on the class dimension
pred_seg = upsampled_logits.argmax(dim=1)[0].detach().cpu().numpy()
    mask = cv2.cvtColor(pred_seg.astype(np.uint8) * 255, cv2.COLOR_GRAY2BGR)  # scale the 0/1 mask to 0/255, 3 channels
if plot_img:
plt.imshow(np.vstack([mask, image]))
plt.show()
return image, mask
model, feature_extractor = load_model_from_trainer_checkpoint()
for img_fname in tqdm(
list(Path("<path to dir with images to infer>").glob("*.png"))
):
image, mask = predict_mask(model, feature_extractor, img_fname, plot_img=True)