# -*- coding: utf-8 -*-
# ELEKTRONN - Neural Network Toolkit
#
# Copyright (c) 2014 - now
# Max-Planck-Institute for Medical Research, Heidelberg, Germany
# Authors: Marius Killinger, Gregor Urban
import time
import matplotlib.pyplot as plt
import numpy as np
# try:
#     from ._warping import warp2dFast, warp3dFast, _warp2dFastLab, _warp3dFastLab
# except ImportError:
#     raise RuntimeError('_warping.so Cython extension not found.\n'
#                        'Please run setup.py or manually cythonize _warping.pyx.')
from _warping import warp2dFast, warp3dFast, _warp2dFastLab, _warp3dFastLab
[docs]def warp2dJoint(img, lab, patch_size, rot, shear, scale, stretch):
    """
    Warp image and label data jointly. Non-image labels are ignored i.e. lab must be 3d to be warped
    Parameters
    ----------
    img: array
      Image data
      The array must be 3-dimensional (ch,x,y) and larger/equal the patch size
    lab: array
      Label data (with offsets subtracted)
    patch_size: 2-tuple
      Patch size *excluding* channel for the image: (px, py).
      The warping result of the input image is cropped to this size
    rot: float
      Rotation angle in deg for rotation around z-axis
    shear: float
      Shear angle in deg for shear w.r.t xy-diagonal
    scale: 3-tuple of float
      Scale per axis
    stretch: 4-tuple of float
      Fraction of perspective stretching from the center (where stretching is always 1)
      to the outer border of image per axis. The 4 entry correspond to:
      - X stretching depending on Y
      - Y stretching depending on X
    Returns
    -------
    img, lab: np.ndarrays
      Warped image and labels (cropped to patch_size)
    """
    if len(lab.shape) == 2:
        lab = _warp2dFastLab(lab, patch_size, img.shape[1:], rot, shear, scale, stretch)
    img = warp2dFast(img, patch_size, rot, shear, scale, stretch)
    return img, lab 
[docs]def warp3dJoint(img, lab, patch_size, rot=0, shear=0, scale=(1, 1, 1), stretch=(0, 0, 0, 0), twist=0):
    """
    Warp image and label data jointly. Non-image labels are ignored i.e. lab must be 3d to be warped
    Parameters
    ----------
    img: array
      Image data
      The array must be 4-dimensional (z,ch,x,y) and larger/equal the patch size
    lab: array
      Label data (with offsets subtracted)
    patch_size: 3-tuple
      Patch size *excluding* channel for the image: (pz, px, py).
      The warping result of the input image is cropped to this size
    rot: float
      Rotation angle in deg for rotation around z-axis
    shear: float
      Shear angle in deg for shear w.r.t xy-diagonal
    scale: 3-tuple of float
      Scale per axis
    stretch: 4-tuple of float
      Fraction of perspective stretching from the center (where stretching is always 1)
      to the outer border of image per axis. The 4 entry correspond to:
      - X stretching depending on Y
      - Y stretching depending on X
      - X stretching depending on Z
      - Y stretching depending on Z
    twist: float
      Dependence of the rotation angle on z in deg from center to outer border
    Returns
    -------
    img, lab: np.ndarrays
      Warped image and labels (cropped to patch_size)
    """
    if len(lab.shape) == 3:
        lab = _warp3dFastLab(lab, patch_size, np.array(img.shape)[[0, 2, 3]], rot, shear, scale, stretch, twist)
    img = warp3dFast(img, patch_size, rot, shear, scale, stretch, twist)
    return img, lab 
### Utilities #################################################################
###############################################################################
[docs]def getCornerIx(sh):
    """Returns array-indices of corner elements for n-dim shape"""
    def getGrayCode(n, n_dim):
        if n == 0:
            return np.zeros(n_dim, dtype=np.int)
        return np.array([(n // 2**i) % 2 for i in range(max(n_dim, int(np.ceil(np.log2(n)))))])
    sh = np.array(sh) - 1  ###TODO
    n_dim = len(sh)
    ix = []
    for i in xrange(2**n_dim):
        ix.append(getGrayCode(i, n_dim))
    ix = np.array(ix)
    corners = ix * sh
    return corners 
def _warpCorners2d(sh, corners, rot=0, shear=0, scale=(1, 1), stretch=(0, 0), plot=False):
    """
    Create warped coordinates of corners
    """
    rot = rot * np.pi / 180
    shear = shear * np.pi / 180
    scale = np.array(scale)
    scale = 1.0 / scale
    stretch = np.array(stretch)
    corners = corners.astype(np.float).copy()
    x_center_off = float(sh[0]) / 2 - 0.5
    y_center_off = float(sh[1]) / 2 - 0.5
    stretch[0] /= x_center_off
    stretch[1] /= y_center_off
    x = corners[:, 0] - x_center_off
    y = corners[:, 1] - y_center_off
    xt = x * (scale[0] + stretch[0] * y)
    yt = y * (scale[1] + stretch[1] * x)
    u = xt * np.cos(rot - shear) - yt * np.sin(rot + shear) + x_center_off
    v = yt * np.cos(rot + shear) + xt * np.sin(rot - shear) + y_center_off
    if plot:
        coords = np.array([u, v]).T
        coords = coords[[3, 2, 0, 1, 3]]
        plt.figure(figsize=(5, 5))
        plt.scatter(corners[:, 1], corners[:, 0], c='b')
        plt.plot(corners[:, 1], corners[:, 0], 'b:')
        plt.scatter(coords[:, 1], coords[:, 0], c='r', marker='x')
        plt.plot(coords[:, 1], coords[:, 0], c='r')
        plt.axes().set_aspect('equal')
        plt.gca().invert_yaxis()
        plt.grid()
    return np.array([u, v]).T
def _warpCorners3d(sh, corners, rot=0, shear=0, scale=(1, 1, 1), stretch=(0, 0, 0, 0), twist=0):
    """
    Create warped coordinates of corners
    """
    rot = rot * np.pi / 180
    shear = shear * np.pi / 180
    twist = twist * np.pi / 180
    scale = np.array(scale)
    scale = 1.0 / scale
    stretch = np.array(stretch)
    corners = corners.astype(np.float).copy()
    z_center_off = float(sh[0]) / 2 - 0.5
    x_center_off = float(sh[1]) / 2 - 0.5
    y_center_off = float(sh[2]) / 2 - 0.5
    stretch[0] /= x_center_off
    stretch[1] /= y_center_off
    stretch[2] /= z_center_off
    stretch[3] /= z_center_off
    twist /= z_center_off
    z = corners[:, 0] - z_center_off
    x = corners[:, 1] - x_center_off
    y = corners[:, 2] - y_center_off
    w = z * scale[2] + z_center_off
    rot = rot + (z * twist)
    xt = x * (scale[0] + stretch[0] * y + stretch[2] * z)
    yt = y * (scale[1] + stretch[1] * x + stretch[3] * z)
    u = xt * np.cos(rot - shear) - yt * np.sin(rot + shear) + x_center_off
    v = yt * np.cos(rot + shear) + xt * np.sin(rot - shear) + y_center_off
    return np.array((w, u, v)).T
[docs]def getRequiredPatchSize(patch_size, rot, shear, scale, stretch, twist=None):
    """
    Given desired patch size and warping parameters:
    return required size for warping input patch
    """
    patch_size = np.array(patch_size)
    corners = getCornerIx(patch_size)
    if len(patch_size) == 2:
        coords = _warpCorners2d(patch_size, corners, rot, shear, scale, stretch)
    elif len(patch_size) == 3:
        coords = _warpCorners3d(patch_size, corners, rot, shear, scale, stretch, twist)
    eff_size = np.ceil(coords.max(axis=0) - coords.min(axis=0))  # effective range
    left_exc = np.floor(np.abs(np.minimum(coords.min(axis=0), 0)))  # how much image needs to be added left
    right_exc = np.ceil(np.maximum(coords.max(axis=0) - patch_size + 1, 0))
    total_exc = np.maximum(left_exc, right_exc)  # how much image must be added centrally
    req_size = patch_size + 2 * total_exc
    return req_size.astype(np.int), eff_size.astype(np.int), left_exc.astype(np.int) 
[docs]def getWarpParams(patch_size, amount=1.0):
    """
    To be called from CNNData. Get warping parameters + required warping input patch size.
    """
    if amount > 1:
        print 'WARNING: warpAugment amount > 1 this requires more than 1.4 bigger patches before warping'
    rot_max = 15 * amount
    shear_max = 3 * amount
    scale_max = 1.1 * amount
    stretch_max = 0.1 * amount
    n_dim = len(patch_size)
    shear = shear_max * 2 * (np.random.rand() - 0.5)
    if n_dim == 3:
        twist = rot_max * 2 * (np.random.rand() - 0.5)
        rot = min(rot_max - abs(twist), rot_max * (np.random.rand()))
        scale = 1 + (scale_max - 1) * np.random.rand(3)
        stretch = stretch_max * 2 * (np.random.rand(4) - 0.5)
    elif n_dim == 2:
        rot = rot_max * 2 * (np.random.rand() - 0.5)
        scale = 1 + (scale_max - 1) * np.random.rand(2)
        scale[0] = 1  # do not change along z!
        stretch = stretch_max * 2 * (np.random.rand(2) - 0.5)
        twist = None
    req_size, _, _ = getRequiredPatchSize(patch_size, rot, shear, scale,
                                          stretch, twist)
    return req_size, rot, shear, scale, stretch, twist 
[docs]def test():
    try:
        img_s = np.random.rand(11, 11)
        img_s = np.concatenate((img_s[None], np.exp(img_s[None])), axis=0)
        out = warp2dFast(img_s, (11, 11), 0, 0, (1, 1), (0.0, 0.0))
    except Exception as e:
        print """%s
        Warping is broken. Most likeley the distributed _warping.so is not binary compatible to your system.""" % (e, ) 
test()
##############################################################################################################
#def paddImage(img,  ext_size, left_exc):
#  new_img = np.ones(ext_size, dtype=img.dtype) 
#  xs, ys = img.shape
#  xo, yo = left_exc
#  new_img[xo:xo+xs, yo:yo+ys] = img
#
#  return new_img
[docs]def maketestimage(sh):
    img = np.ones(sh) * 0.5
    xs, ys = sh
    try:
        d = np.diag(np.ones(xs))
        img[:xs, :xs] += d
        img[:xs, :xs] += d[::-1]
        img[-xs:, -xs:] += d
        img[-xs:, -xs:] += d[::-1]
    except:
        d = np.diag(np.ones(ys))
        img[:ys, :ys] += d
        img[:ys, :ys] += d[::-1]
        img[-ys:, -ys:] += d
        img[-ys:, -ys:] += d[::-1]
    img[0, :] += 1
    img[:, 0] += 1
    img[-1, :] += 1
    img[:, -1] += 1
    if sh[0] > 80:
        img[30, :] += 1
        img[:, 30] += 1
        img[-31, :] += 1
        img[:, -31] += 1
    return img / img.max() 
if __name__ == "__main__":
    #  test_img = io.imread('Lichtenstein.png')
    #  test_img = test_img.mean(axis=2)
    #  s1 = test_img.shape[0]
    #  s2 = test_img.shape[1]
    ps = (200, 200)
    if True:
        ext_size, rot, shear, scale, stretch, twist = getWarpParams(ps,
                                                                    amount=1.0)
        t = []
        for i in xrange(10000):
            ext_size, rot, shear, scale, stretch, twist = getWarpParams(ps, amount=1.0); t.append(ext_size)
    #  img_in = maketestimage(eff_size)
    #  img_in = paddImage(img_in, ext_size, left_exc)[None]
        img_in = maketestimage(ext_size)[None]
        out = warp2dFast(img_in, ps, rot, shear, scale, stretch)
        plt.figure()
        plt.subplot(121)
        plt.imshow(img_in[0], interpolation='none', cmap='gray')
        plt.hlines(ext_size[0] / 2 - 0.5, 0, ext_size[1] - 1, color='r')
        plt.vlines(ext_size[1] / 2 - 0.5, 0, ext_size[0] - 1, color='r')
        plt.subplot(122)
        plt.imshow(out[0], interpolation='none', cmap='gray')
        plt.hlines(ps[0] / 2 - 0.5, 0, ps[1] - 1, color='r')
        plt.vlines(ps[1] / 2 - 0.5, 0, ps[0] - 1, color='r')
    if False:  # visual 2d
        #out  = _warp2d_c(test_img, 20, 10, (1,1.1), (0.1, 0))
        test_img = np.concatenate((test_img[None], np.exp(test_img[None])), axis=0)
        out2 = warp2dFast(test_img, (512, 512), 20, 10, (1, 1.1), (0.1, 0))
        plt.figure()
        plt.subplot(121)
        plt.imshow(test_img, interpolation='none', cmap='gray')
        plt.subplot(122)
        plt.imshow(out2[0], interpolation='none', cmap='gray')
    if False:
        img_s = maketestimage((11, 11))
        img_s = np.concatenate((img_s[None], np.exp(img_s[None])), axis=0)
        out = warp2dFast(img_s, (11, 11), 0, 0, (1, 1), (0.0, 0.0))
    if False:
        img_s = maketestimage((110, 110))
        img_s = np.concatenate((img_s[None], ) * 4, axis=0)
        img_s = np.concatenate((img_s[None], np.exp(img_s[None])), axis=0)
        out = warp3dFast(img_s, (4, 110, 110), 0, 0, (1, 1, 1),
                         (0.0, 0.0, 0.0, 0.0), 10)
    if False:  # visual 3d
        n = 100
        img_s = np.tile(test_img, n)
        img_s = img_s.reshape((s1, n, s2))
        img_s = np.swapaxes(img_s, 1, 0)
        patch_size = img_s.shape
        off = 0
        lab = img_s[off:-off, off:-off, off:-off]
        img = np.concatenate((test_img[None], np.exp(test_img[None])), axis=0)
        img1, lab1 = warpAugment(img_s[None], lab, patch_size=patch_size)
        for i in xrange(n):
            plt.imsave('/tmp/%i-img.png' % i, img1[0, i, :, :] / 255)
        for i in xrange(lab1.shape[0]):
            plt.imsave('/tmp/%i-lab.png' % (i + off), lab1[i, :, :])
    if False:  # visual 3d
        n = 40
        img_s = np.tile(test_img, n)
        img_s = img_s.reshape((s1, n, s2))
        img_s = np.swapaxes(img_s, 1, 2)
        wow1 = warp3dFast(img_s[None], (s1, s2, n), 0, 0, (1, 1, 1),
                          (0.1, 0.1, 0.1, -0.1), 10)
        for i in xrange(n):
            plt.imsave('/tmp/%i-ref.png' % i, wow1[:, :, i] / 255)
        wow2 = _warp3dFastLab(img_s[20:-20, 20:-20], (s1 - 40, s2 - 40, n),
                              (s1, s2, n), 0, 0, (1, 1, 1), (0.1, 0.1, 0.1, -0.1), 10)
        for i in xrange(wow2.shape[2]):
            plt.imsave('/tmp/%i.png' % i, wow2[:, :, i] / 255)
    if False:  # 3d timing
        s = 400
        test = np.random.rand(s, s, s).astype(np.float32)
        test2 = np.random.rand(s * 2, s * 2, s * 2).astype(np.float32)
        t0 = time.time()
        wow1 = warp3dFast(test[None], (s, s, s), 20, 5, (1, 1, 1), (0.1, 0.1, 0.1, 0.1), 10)
        #wow1 = warp3dFast(test[None], (s,s,s))
        print time.time() - t0