Generic UNet implementation written in pure Julia, based on Flux.jl



This package provides a generic UNet implemented in Julia.

The package is built on top of Flux.jl, so the model can be extended as needed.

julia> u = Unet()
  ConvDown(64, 64)
  ConvDown(128, 128)
  ConvDown(256, 256)
  ConvDown(512, 512)

  UNetConvBlock(1, 3)
  UNetConvBlock(3, 64)
  UNetConvBlock(64, 128)
  UNetConvBlock(128, 256)
  UNetConvBlock(256, 512)
  UNetConvBlock(512, 1024)
  UNetConvBlock(1024, 1024)

  UNetUpBlock(1024, 512)
  UNetUpBlock(1024, 256)
  UNetUpBlock(512, 128)
  UNetUpBlock(256, 64)

By default, the model expects 1 input channel, i.e. grayscale images. To support images with a different number of channels, pass the channel count to Unet.

julia> u = Unet(3) # for RGB images

The spatial dimensions of the input can be any power of two. Inputs are batches shaped like (256, 256, channels, batch_size).
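If your images are not already a power of two in size, one option is to pad them before passing them to the model. The sketch below is an illustration only: `pad_to_pow2` is a hypothetical helper, not part of UNet.jl, and zero-padding in the top-left-aligned layout shown here is just one possible choice. It uses only `nextpow` from Base.

```julia
# Hypothetical helper (NOT part of UNet.jl): zero-pad the two spatial
# dimensions of a (width, height, channels, batch) array up to the next
# power of two. Channel and batch dimensions are left unchanged.
function pad_to_pow2(x::AbstractArray{T,4}) where {T}
    w, h, c, n = size(x)
    W, H = nextpow(2, w), nextpow(2, h)   # smallest powers of two >= w and h
    out = zeros(T, W, H, c, n)
    out[1:w, 1:h, :, :] .= x              # original data sits in the corner
    return out
end

x  = rand(Float32, 200, 300, 1, 2)
xp = pad_to_pow2(x)
size(xp)  # (256, 512, 1, 2)
```

After running the model on the padded input, you would typically crop the output back to the original `1:w, 1:h` region.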

By default, the number of output channels equals the number of input channels: 1 for Unet() and 3 for Unet(3). To set the number of output channels explicitly, supply a second argument:

julia> u = Unet(3, 5) # 3 input channels, 5 output channels.

GPU Support

To train the model on a GPU, simply call gpu on the model.

julia> u = Unet();

julia> u = gpu(u);

julia> r = gpu(rand(Float32, 256, 256, 1, 1));

julia> size(u(r))
(256, 256, 1, 1)


Training UNet is a breeze too.

You can define your own loss function, or use Flux's binary cross entropy implementation.

using UNet, Flux, Base.Iterators
import Flux.Losses.binarycrossentropy

device = gpu #cpu

function loss(x, y)
    # Clamp predictions away from 0 so the logarithm in the loss stays finite.
    op = clamp.(u(x), 0.001f0, 1.f0)
    binarycrossentropy(op, y)
end

u = Unet() |> device
w = rand(Float32, 256, 256, 1, 1) |> device
w′ = rand(Float32, 256, 256, 1, 1) |> device
rep = Iterators.repeated((w, w′), 10)

opt = ADAM()

Flux.train!(loss, Flux.params(u), rep, opt, cb = () -> @show(loss(w, w′)))
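
As an illustration of defining your own loss function, here is a minimal sketch of a soft Dice loss, which is commonly used for segmentation. This is an assumption about what a custom loss might look like, not something UNet.jl provides: the name `dice_loss` and the smoothing term `ϵ` are hypothetical. It uses only broadcasting and `sum`, so it should work with Flux's automatic differentiation, but that has not been verified against this package.

```julia
# Hypothetical custom loss (NOT part of UNet.jl or Flux): soft Dice loss.
# ŷ holds predicted probabilities in [0, 1], y holds the target mask.
# ϵ is a small smoothing term that avoids division by zero.
function dice_loss(ŷ, y; ϵ = 1f-6)
    intersection = sum(ŷ .* y)
    return 1 - (2 * intersection + ϵ) / (sum(ŷ) + sum(y) + ϵ)
end

y = Float32[1 0; 0 1]
perfect = dice_loss(y, y)        # close to 0 for a perfect prediction
wrong   = dice_loss(1 .- y, y)   # close to 1 for a completely wrong prediction
```

To try it in the training loop above, you could define `loss(x, y) = dice_loss(u(x), y)` in place of the binary cross entropy version.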

Further Reading

The package is an implementation of the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation" (Ronneberger, Fischer, and Brox, 2015), and all credit for the model itself goes to its authors.
