# myat.py - Implementation of an Ambrosio-Tortorelli segmentation
#
# Author: Stefan Fuertinger [stefan.fuertinger@gmx.at]
# Created: August 22 2012
# Last modified: <2017-09-14 11:08:26>
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import norm
from scipy.sparse import spdiags, linalg
from difftools import fidop2d
##########################################################################################
def myat(f,ep,nu=1,de=1,la=1,tol=1e-4,itmax=100,iplot=False,Dx=None,Dy=None,Lh=None):
    """
    Solve the Ambrosio--Tortorelli approximation of the Mumford--Shah functional

    Parameters
    ----------
    f : NumPy 2darray
        Raw (noisy) input image to be segmented. Note that `f` has to be square!
    ep : float
        Positive edge-"thickness" parameter in the approximation functional. For `ep -> 0`
        the Ambrosio--Tortorelli approximation Gamma-converges to the Mumford--Shah
        functional (see Notes below).
    nu : float
        Positive parameter determining the influence of the Ambrosio--Tortorelli terms in the
        functional.
    de : float
        Positive parameter influencing the smoothness regularization term for `u` in the functional.
    la : float
        Positive parameter weighing the data fidelity term in the functional.
    tol : float
        Error tolerance for the stopping criterion satisfying `0 < tol << 1`. The iteration
        stops as soon as the relative changes of both `u` and `v` (measured as
        `||new - old|| / (||new|| + 1e-6)`) are at most `tol`.
    itmax : int
        Integer, the maximal number of iterations. Non-integer values are rounded
        (a warning is printed).
    iplot : bool
        Switch to turn interactive plotting on (`iplot=True`) or off (`iplot=False`)
    Dx : NumPy/SciPy matrix
        Discrete derivative operator in direction `x` (forward differences are recommended).
        Note that if `f` is `N`-by-`N` then `Dx` has to be `N**2`-by-`N**2`!
    Dy : NumPy/SciPy matrix
        Discrete derivative operator in direction `y` (forward differences are recommended).
        Note that if `f` is `N`-by-`N` then `Dy` has to be `N**2`-by-`N**2`!
    Lh : NumPy/SciPy matrix
        Discrete Laplace operator (central differences are recommended).
        Note that if `f` is `N`-by-`N` then `Lh` has to be `N**2`-by-`N**2`!
        If `Lh` is not provided, `Lh = -(Dx.T*Dx + Dy.T*Dy)` is used.

        If neither `Dx` nor `Dy` is given, or only one of them is given, or both are
        given without `Lh`, the default operators `fidop2d(N,'xy','f')` are used for
        `Dx` and `Dy` (in the latter two cases a warning is printed and any supplied
        `Dx`/`Dy` is ignored). All operators act on the column-major (Fortran-order)
        vectorization of the image, i.e., on `f.flatten(order="F")`.

    Returns
    -------
    u : NumPy 2darray
        The smoothed version of `f` (`N`-by-`N`).
    v : NumPy 2darray
        The fuzzy edge map of `f` (`N`-by-`N`).

    Raises
    ------
    TypeError
        If `f` is not a NumPy ndarray, a scalar parameter is not numeric, `iplot` is not
        Boolean, or a supplied operator is not a (sparse or dense) matrix.
    ValueError
        If `f` is not a non-empty, square, real 2darray without NaNs/Infs, if `ep`, `nu`,
        `de` or `la` is not > 0, if `tol > 1`, if `itmax < 1`, or if a supplied operator
        has the wrong dimension.

    Notes
    -----
    The Ambrosio--Tortorelli functional [1]_ is given by

    .. math::

       J_{\\varepsilon}[u,v] = \\int_{\\Omega}\\frac{\\nu\\varepsilon}{2}|\\nabla v|^2 + \\frac{\\nu}{2\\varepsilon}(1-v)^2 + \\frac{\\delta}{2} v^2|\\nabla u|^2 + \\frac{\\lambda}{2} (u-f)^2 dx

    The associated Euler--Lagrange equations are

    .. math::

       \\begin{align}
       -\\delta \\mathrm{div}(v^2 \\nabla u) + \\lambda u &= \\lambda f \\\\
       \\frac{\\partial u}{\\partial n} &= 0
       \\end{align}

    and

    .. math::

       \\begin{align}
       -\\nu\\varepsilon\\Delta v + \\frac{\\nu}{\\varepsilon}v + \\delta v|\\nabla u|^2 &= \\frac{\\nu}{\\varepsilon} \\\\
       \\frac{\\partial v}{\\partial n} &= 0
       \\end{align}

    The alternate optimization algorithm is initialized using

    .. math::

       \\begin{align}
       u_0 &= f \\\\
       v_0 &= 1/(1 + \\delta \\frac{\\varepsilon}{\\nu} |\\nabla f|^2)
       \\end{align}

    It can be shown that :math:`J_{\\varepsilon}` Gamma-converges to the Mumford--Shah functional
    for :math:`\\varepsilon \\rightarrow 0` (see, e.g., [2]_).

    References
    ----------
    .. [1] L. Ambrosio and V.M. Tortorelli. Approximation of functionals depending on
           jumps by elliptic functionals via Gamma-convergence. Communications on Pure and
           Applied Mathematics, 43:999-1036, 1990.
    .. [2] G. Aubert and P. Kornprobst: "Mathematical Problems in Image Processing: Partial Differential
           Equations and the Calculus of Variations", Springer 2006.
    """
    # Sanity checks for the image
    if type(f).__name__ != "ndarray":
        raise TypeError("f has to be a (square) NumPy 2darray!")
    if f.ndim > 2:
        raise ValueError("f has to be 2-dimensional!")
    if f.ndim < 2 or f.shape[0] != f.shape[1]:
        raise ValueError("f has to be square!")
    if f.size == 0:
        raise ValueError("f must not be empty!")
    # any()/all() instead of max()/min(): well-defined and without comparisons to booleans
    if np.isnan(f).any() or np.isinf(f).any() or not np.isreal(f).all():
        raise ValueError("f must be real and must not contain NaNs or Infs!")
    N = f.shape[0]

    # Model parameters must be strictly positive: ep and nu appear in denominators
    for name, val in (("ep", ep), ("nu", nu), ("de", de), ("la", la)):
        _check_positive(name, val)

    # Iteration controls
    try:
        tol/2.0
    except TypeError:
        raise TypeError("tol has to be a positive float!")
    if tol > 1:
        raise ValueError("tol has to be << 1!")
    try:
        itmax/2.0
    except TypeError:
        raise TypeError("itmax has to be a positive integer!")
    if itmax < 1:
        raise ValueError("itmax has to be >= 1!")
    if np.round(itmax) != itmax:
        itmax = np.round(itmax)
        print("WARNING: itmax has to be an integer - using round(itmax) = %s now..." % itmax)
    msg = "The switch `iplot` has to be Boolean!"
    try:
        bad = iplot not in (True, False)
    except Exception:
        # e.g. arrays, whose comparison result has no single truth value
        raise TypeError(msg)
    if bad:
        raise TypeError(msg)

    # Discrete differential operators (defaults and/or validation of user input)
    Dx, Dy, Lh = _setup_operators(N, Dx, Dy, Lh)

    # Get squared image dimension
    NN = N**2

    # Vectorize column-major (Fortran order) - matches the reshape(order="F") at the end
    f = f.flatten(order="F")

    # Initial guesses: u_0 = f, v_0 = 1/(1 + de*ep/nu*|grad f|^2)
    u = f.copy()
    v = 1.0/(1 + de*ep/nu*((Dx*f)**2 + (Dy*f)**2))

    # Set up plotting stuff if necessary and show initial guesses
    if iplot:
        fig = getfig(f,de,la,nu,ep)
        showit(fig,u,v)

    # Initialize iteration parameters
    it = 0
    rerr = 2*tol
    Jold = 0.0
    ep1 = 1.0e-6        # keeps the relative-error denominators away from zero

    # Quantities that do not change between iterations
    rhsv = nu/ep*np.ones((NN,))
    rhsu = la*f
    Dla = spdiags(la*np.ones((NN,)),0,NN,NN)
    Lv = nu*ep*Lh

    # The fpi-loop: at most itmax sweeps
    while rerr > tol and it < itmax:

        # Update iteration counter
        it += 1

        # Squared gradient magnitude of the current u (used by the v-system and the cost)
        gradu2 = (Dx*u)**2 + (Dy*u)**2

        # Compute new v: (de*|grad u|^2 + nu/ep) v - nu*ep*Lh v = nu/ep
        Av = spdiags(de*gradu2 + nu/ep,0,NN,NN) - Lv
        vn = linalg.spsolve(Av.tocsr(),rhsv)

        # Compute new u: -de*div(v^2 grad u) + la*u = la*f
        # NOTE: uses v from the previous sweep (not the freshly computed vn)
        Dv = spdiags(v**2,0,NN,NN)
        Au = de*Dx.T*Dv*Dx + de*Dy.T*Dv*Dy + Dla
        un = linalg.spsolve(Au.tocsr(),rhsu)

        # Compute value of cost for the current (pre-update) iterates u and v
        J = np.sum(nu*ep/2*((Dx*v)**2 + (Dy*v)**2) + nu/(2*ep)*(1-v)**2
                   + de/2*v**2*gradu2 + la/2*(u-f)**2)

        # Compute relative errors; ep1 is added to the norm, not to the vector
        rerru = norm(un - u)/(norm(un) + ep1)
        rerrv = norm(vn - v)/(norm(vn) + ep1)
        rerr = max(rerru,rerrv)

        # Show info in prompt (plain floats give a version-independent repr)
        trend = 'dec' if Jold > J else 'inc'
        print("it = %s, rerr = %s, J = %s, %s"%(repr(it),repr(float(rerr)),repr(float(J)),trend))

        # Update iterates and cost (spsolve returns fresh arrays, no copy needed)
        u = un
        v = vn
        Jold = J

        # Plot intermediate results after every iteration
        if iplot:
            showit(fig,u,v)

    # Show final results
    if iplot:
        showit(fig,u,v)

    # Convert u and v back to images
    return u.reshape(N,N,order="F"), v.reshape(N,N,order="F")


def _check_positive(name, val):
    """Raise TypeError unless `val` is numeric, ValueError unless it is > 0."""
    try:
        val/2.0
    except TypeError:
        raise TypeError("%s has to be a positive float!" % name)
    if val <= 0:
        raise ValueError("%s has to be > 0!" % name)


def _check_operator(name, A, N):
    """Raise unless `A` is a (sparse or dense) `N**2`-by-`N**2` matrix."""
    if type(A).__name__.rfind("matrix") == -1:
        raise TypeError("%s has to be a SciPy/NumPy matrix!" % name)
    NN = A.shape[0]
    if NN != A.shape[1]:
        raise ValueError("%s has to be a square matrix!" % name)
    if NN != N**2:
        raise ValueError("%s has to be of dimension %s**2 = %s"%(name,repr(N),repr(N**2)))


def _setup_operators(N, Dx, Dy, Lh):
    """Return validated (Dx, Dy, Lh), substituting defaults as documented in `myat`."""
    # Identity checks: `!=`/`==` on matrices compare elementwise
    if (Dx is None) != (Dy is None):
        print("WARNING: Dx or Dy not provided, switching to default Dx and Dy")
        Dx,Dy = fidop2d(N,'xy','f')
    elif Dx is None:
        # neither Dx nor Dy provided
        Dx,Dy = fidop2d(N,'xy','f')
    elif Lh is None:
        # both provided, but no Laplacian
        print("WARNING: Dx and Dy given but not Lh - using defaults for Dx and Dy. Lh will be computed as -(Dx.T*Dx + Dy.T*Dy)")
        Dx,Dy = fidop2d(N,'xy','f')
    else:
        _check_operator("Dx", Dx, N)
        _check_operator("Dy", Dy, N)
    if Lh is not None:
        _check_operator("Lh", Lh, N)
    else:
        Lh = -(Dx.T*Dx + Dy.T*Dy)
    return Dx, Dy, Lh
##########################################################################################
def getfig(f,de,la,nu,ep):
    """
    Set up Figure for interactive plotting

    Creates a new pyplot figure, titles it with the model parameters and shows the
    input image in the first slot of a 1-by-3 subplot grid.

    Parameters
    ----------
    f : NumPy 1darray
        Column-major (Fortran-order) vectorization of the square input image, i.e.,
        a vector of length `N**2`.
    de, la, nu, ep : float
        Model parameters; only used in the figure's super-title.

    Returns
    -------
    fig : matplotlib Figure
        The newly created figure.
    """
    # Set up figure and assign window- and sup-title
    fig = plt.figure()
    # FigureCanvas.set_window_title was removed in matplotlib 3.6; the figure manager
    # offers the same method in old and new versions (it is absent for non-pyplot figures)
    manager = getattr(fig.canvas, "manager", None)
    if manager is not None:
        manager.set_window_title("Ambrosio-Tortorelli")
    fig.suptitle(r'$\delta = %s,\quad \lambda = %s,\quad \nu = %s,\quad \varepsilon = %s$'\
                 %(repr(de),repr(la),repr(nu),repr(ep)), fontsize=14)

    # Get image dimension (reshape needs an int, np.sqrt returns a float)
    N = int(round(np.sqrt(f.shape[0])))

    # Plot original image f
    ax = fig.add_subplot(1,3,1)
    plt.sca(ax)
    ax.set_title(r"$f$")
    plt.imshow(f.reshape(N,N,order="F"),interpolation='nearest',cmap='gray')
    plt.draw()

    return fig
##########################################################################################
def showit(fig,u,v):
    """
    Show iteration process

    Displays the current iterates `u` and `v` in the second and third slot of a
    1-by-3 subplot grid of `fig` (the first slot holds the input image, cf. `getfig`).
    Repeated calls reuse the same two axes and replace the previously shown images.

    Parameters
    ----------
    fig : matplotlib Figure
        Figure to draw into.
    u, v : NumPy 1darray
        Column-major (Fortran-order) vectorizations of square images, i.e.,
        vectors of length `N**2`.
    """
    # Get image dimension (reshape needs an int, np.sqrt returns a float)
    N = int(round(np.sqrt(u.shape[0])))

    for pos, vec, name in ((2, u, "u"), (3, v, "v")):
        ax = _get_panel(fig, pos, name)
        # Drop the image of the previous iterate so images do not pile up
        for img in list(ax.images):
            img.remove()
        ax.set_title(r"$%s$" % name)
        ax.imshow(vec.reshape(N,N,order="F"),interpolation='nearest',cmap='gray')

    fig.canvas.draw_idle()


def _get_panel(fig, pos, name):
    """Return the axes of `fig` labelled for `name`, creating it in slot `pos` of a 1x3 grid."""
    # Newer matplotlib creates a fresh Axes on every add_subplot call, so look it up first
    label = "myat_" + name
    for ax in fig.axes:
        if ax.get_label() == label:
            return ax
    return fig.add_subplot(1, 3, pos, label=label)