RGB Segmentation Masks to Classes in TensorFlow

So I started to play around with TensorFlow and looked into segmenting image pixels into different classes.
The dataset I am working with contains full RGB dashcam images like this one:

Sample RGB dashcam image from the comma10k dataset (comma.ai)

and another RGB image in which every pixel is colored in one of five different colors, each representing a class, like this one:

Sample RGB mask image from the comma10k dataset (comma.ai)

The so-called “mask” is a full-color RGB image (3 channels) as well.

To train my neural network, I needed to transform these RGB values into class values in the interval [0…4] instead of 3-value color tuples like (64, 32, 32).

It would be easy to just load the image in Python, convert it to a NumPy array and do the replacements there. However, this is quite slow for a large dataset and needs a lot of RAM to keep all the converted “masks” in memory.
I wasn't able to find a good solution to do this in TensorFlow itself, so I came up with the following after a lot of trial and error:

```python
import tensorflow as tf

imagePayload = tf.io.read_file("masks/0000_0085e9e41513078a_2018-08-19--13-26-08_11_864.png")
mask = tf.io.decode_png(imagePayload, channels=3)  # force RGB: shape (height, width, 3), dtype uint8

colors = [
  (64, 32, 32), # road (all parts, anywhere nobody would look at you funny for driving)
  (255, 0, 0), # lane markings (don't include non lane markings like turn arrows and crosswalks)
  (128, 128, 96), # undrivable
  (0, 255, 102), # movable (vehicles and people/animals)
  (204, 0, 255) # my car (and anything inside it, including wires, mounts, etc. No reflections)
]

# One boolean (height, width) map per class color: True where all 3 channels match
one_hot_map = []
for color in colors:
    class_map = tf.reduce_all(tf.equal(mask, color), axis=-1)
    one_hot_map.append(class_map)

# Stack into (height, width, 5) and turn True/False into 1.0/0.0
one_hot_map = tf.stack(one_hot_map, axis=-1)
one_hot_map = tf.cast(one_hot_map, tf.float32)

# Index of the "1" along the class axis = class value 0..4, shape (height, width)
mask = tf.argmax(one_hot_map, axis=-1)

tf.print(mask)
```

In lines 3 and 4 we load the mask image into a tensor called mask.
This tensor has the shape (image_height, image_width, 3), so the last axis contains the RGB values.

We want to convert it into a tensor of shape (image_height, image_width) in which each element contains the class value 0 – 4 corresponding to the pixel's color.

First, we define the colors that encode our classes (lines 6 to 12). The position of a color in this list is its class value, so road is 0 and my car is 4.

I was unable to find a way to directly replace the RGB values in the tensor with the class value, so we do the following:

  1. For each color, we compare the mask with that color using tf.equal(…), which checks every channel separately, and then reduce (tf.reduce_all(…)) the result over the last axis, which is our RGB axis. A pixel becomes True only if all three of its channels match the color. This yields 5 new tensors of shape (image_height, image_width), one per class: True (1) where the pixel color matches the class and False (0) where it does not.
  2. We stack (tf.stack(…)) our 5 tensors on axis -1, which creates a new last axis, and cast the booleans to tf.float32. The result has the shape (image_height, image_width, 5), and its last axis holds the “one-hot” encoded class, for example [0 0 0 1 0] for the “movable” class (the 4th class, index 3). If you want one-hot encoded classes, you could stop here.
  3. To convert [0 0 0 1 0] to 3 or [0 1 0 0 0] to 1, we apply tf.argmax to our “one-hot” axis. For every pixel, this function returns the index of the largest value along that axis, which in our case is the index of the one-hot encoded “1”. That is exactly the class value we want.
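To see the three steps and the intermediate shapes without needing a dataset file, here is a minimal sketch that runs them on a tiny hand-made 2×3 mask. The pixel values are made up for illustration; only the five class colors come from the post.

```python
import tensorflow as tf

# The five class colors from the post, in class order 0..4
colors = [(64, 32, 32), (255, 0, 0), (128, 128, 96), (0, 255, 102), (204, 0, 255)]

# Made-up 2x3 mask (height=2, width=3, 3 channels), uint8 like decode_png output
mask = tf.constant(
    [[[64, 32, 32], [255, 0, 0], [128, 128, 96]],
     [[0, 255, 102], [204, 0, 255], [64, 32, 32]]],
    dtype=tf.uint8,
)

# Step 1: one boolean (2, 3) map per color, True where all three channels match
class_maps = [tf.reduce_all(tf.equal(mask, color), axis=-1) for color in colors]
print(class_maps[0].shape)  # (2, 3)

# Step 2: stack along a new last axis -> one-hot tensor of shape (2, 3, 5)
one_hot = tf.cast(tf.stack(class_maps, axis=-1), tf.float32)
print(one_hot.shape)  # (2, 3, 5)

# Step 3: argmax over the class axis -> class value per pixel, shape (2, 3)
classes = tf.argmax(one_hot, axis=-1)
print(classes.numpy())
# [[0 1 2]
#  [3 4 0]]
```

Each pixel's output is simply the position of its color in the colors list.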

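We now have a tensor of shape (image_height, image_width) that encodes each pixel's class in the interval [0…4]. One caveat: a pixel whose color matches none of the five colors gets an all-zero one-hot vector, and tf.argmax still returns one of the indices for it, so such pixels are silently assigned a class instead of being flagged.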
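A possible variation, offered as a sketch rather than a drop-in replacement: broadcasting lets a single tf.equal call compare every pixel against all five colors at once, so neither the Python loop nor tf.stack is needed, and pixels that match no color can be marked explicitly. The names CLASS_COLORS, rgb_mask_to_classes and load_classes, as well as using -1 as the “no match” value, are hypothetical and not part of the original code.

```python
import tensorflow as tf

# Class colors from the post; row i is the RGB color of class i
CLASS_COLORS = tf.constant(
    [(64, 32, 32), (255, 0, 0), (128, 128, 96), (0, 255, 102), (204, 0, 255)],
    dtype=tf.uint8,
)


def rgb_mask_to_classes(mask):
    """Hypothetical helper: (H, W, 3) uint8 RGB mask -> (H, W) int64 class values.

    Pixels matching none of CLASS_COLORS get -1 (an assumption of this sketch).
    """
    # (H, W, 1, 3) vs (5, 3) broadcasts to (H, W, 5, 3); all() over RGB -> (H, W, 5)
    matches = tf.reduce_all(tf.equal(mask[..., tf.newaxis, :], CLASS_COLORS), axis=-1)
    # Index of the matching color along the class axis
    classes = tf.argmax(tf.cast(matches, tf.int32), axis=-1)
    # Replace the class of pixels with no matching color by -1
    matched_any = tf.reduce_any(matches, axis=-1)
    return tf.where(matched_any, classes, tf.constant(-1, dtype=tf.int64))


def load_classes(path):
    """Hypothetical loader: PNG file path -> (H, W) class values."""
    mask = tf.io.decode_png(tf.io.read_file(path), channels=3)
    return rgb_mask_to_classes(mask)
```

Because load_classes uses only TensorFlow ops, it could be mapped over a tf.data.Dataset of mask file paths (for example dataset.map(load_classes)), so the masks are converted on the fly during training instead of all being kept in RAM. If the camera images are loaded in the same pipeline, keep each image and its mask paired in one dataset element so they stay aligned.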

Hope you found this helpful. If you have a better solution, please contact me!
