DEEPLAB_V3
Return a segmentation mask from an input image. The input image is expected to be a DataContainer of an 'image' type.
The output is a DataContainer of an 'image' type with the same dimensions as the input image, but with the red, green, and blue channels replaced with the segmentation mask. Params: default (Image): the input image to be segmented. Returns: out (Image): the segmented image.
Python Code
from flojoy import Image, flojoy
@flojoy(deps={"torch": "2.0.1", "torchvision": "0.15.2"})
def DEEPLAB_V3(default: Image) -> Image:
    """Return a segmentation mask from an input image.

    The input image is expected to be a DataContainer of an 'image' type.

    The output is a DataContainer of an 'image' type with the same dimensions
    as the input image, but with the red, green, and blue channels replaced
    with the segmentation mask. The alpha channel of the output is always
    ``None``.

    Parameters
    ----------
    default : Image
        The input image to be segmented.

    Returns
    -------
    Image
        The segmented image.
    """
    import os

    import numpy as np
    import PIL.Image
    import torch
    import torchvision.transforms.functional as TF
    from flojoy import Image
    from flojoy.utils import FLOJOY_CACHE_DIR
    from torchvision import transforms

    # Gather the colour planes; the alpha plane is only included when present.
    planes = [default.r, default.g, default.b]
    alpha = default.a
    if alpha is not None:
        planes.append(alpha)
    pixels = np.stack(planes, axis=2)

    # The model input is always a three-channel RGB PIL image, so any alpha
    # plane is dropped by the conversion.
    rgb_image = TF.to_pil_image(pixels).convert("RGB")

    # Keep downloaded hub artefacts inside the Flojoy cache directory.
    hub_dir = os.path.join(FLOJOY_CACHE_DIR, "cache", "torch_hub")
    torch.hub.set_dir(hub_dir)

    # NOTE(review): `pretrained=True` looks like the legacy torchvision flag;
    # newer releases favour `weights=` -- confirm before changing the pin.
    segmenter = torch.hub.load(
        "pytorch/vision:v0.15.2",
        "deeplabv3_resnet50",
        pretrained=True,
        skip_validation=True,
    )
    segmenter.eval()

    # Preprocess: convert to a tensor, normalise with fixed per-channel
    # mean/std, then add a leading batch dimension of size one.
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    batch = normalize(to_tensor(rgb_image)).unsqueeze(0)

    # Run the network without autograd bookkeeping and keep the single
    # result that corresponds to our one-image batch.
    with torch.inference_mode():
        scores = segmenter(batch)["out"][0]

    # Index of the highest-scoring class along dimension 0 of the scores.
    class_map = scores.argmax(0)

    # Build one colour per class index (21 classes) by scaling a base palette.
    base_palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
    class_ids = torch.as_tensor(list(range(21)))
    class_colors = class_ids[:, None] * base_palette
    class_colors = (class_colors % 255).numpy().astype("uint8")

    # Paint the class indices with the palette at the input image's size.
    label_array = class_map.byte().cpu().numpy()
    mask = PIL.Image.fromarray(label_array).resize(rgb_image.size)
    mask.putpalette(class_colors)
    rgb_mask = np.array(mask.convert("RGB"))

    # Split the coloured mask back into separate planes for the DataContainer.
    red, green, blue = (rgb_mask[:, :, index] for index in range(3))
    return Image(r=red, g=green, b=blue, a=None)
Example
Having problems with this example app? Join our Discord community and we will help you out!
In this example, the DEEPLAB_V3 node produces a segmentation mask
from an input image loaded by the LOCAL_FILE
node.