# Offline FXAA filter based on the original NVidia paper. The algorithm isn't
# completely the same, I just took the core idea.
#
# by drummyfish, 2020, released under CC0 1.0, public domain

import sys
from PIL import Image
import random

# Minimum summed absolute RGB difference (as computed by pixelDist) between
# two neighbouring pixels for them to be treated as separated by an edge.
EDGE_THRESHOLD = 95

# Fail with a usage message instead of an IndexError traceback when the input
# file argument is missing.
if len(sys.argv) < 2:
  sys.exit("usage: " + (sys.argv[0] if sys.argv else "fxaa") + " IMAGE_FILE")

FILENAME = sys.argv[1]

# Input image converted to RGB so every pixel reads as an (R, G, B) tuple.
# convert() returns a separate image, so the source file can be closed here.
with Image.open(FILENAME) as source:
  image = source.convert("RGB")

pixels = image.load()

WIDTH, HEIGHT = image.size

# Edge map, same size as the input. The R component holds the edge marker
# (0 = none, 255 = horizontal, 127 = vertical); the B component is set to 255
# for pixels the smoothing pass has already processed.
image2 = Image.new("RGB",(WIDTH,HEIGHT),color="black")
pixels2 = image2.load()

# Output image.
image3 = Image.new("RGB",(WIDTH,HEIGHT),color="black")
pixels3 = image3.load()

def pixelDist(p1,p2):
  """Return the Manhattan (L1) distance between the colours of two pixels.

  Only the first three components (R, G, B) are compared; any further
  component (e.g. alpha) is ignored.
  """
  return sum(abs(p1[channel] - p2[channel]) for channel in range(3))

"""
  For a given edge end (left/right or top/bottom) check its two diagonal
  neighbours to decide where the edge continues. Returns a position between
  0.0 and 2.0 (the position across the two pixels that form the edge): 0.5,
  1.0 or 1.5.
"""

def edgeEndPos(coords, cAdd, cAddPerp, edgeType, edgeMap=None):
  """
  Decide where an edge continues beyond one of its ends.

  Looks at the two diagonal neighbours just past the given edge end and
  returns a position within <0.0, 2.0> across the two pixels that form the
  edge (the edge pixel and its neighbour in the cAddPerp direction):

    0.5 -- only the neighbour on the +cAddPerp side has the same edge type
    1.5 -- only the neighbour on the -cAddPerp side has the same edge type
    1.0 -- both or neither of them do

  coords   -- (x, y) of the last edge pixel at this end
  cAdd     -- (dx, dy) step leading from that pixel past the end of the edge
  cAddPerp -- (dx, dy) step perpendicular to the edge
  edgeType -- edge marker to look for (255 horizontal, 127 vertical)
  edgeMap  -- indexable by (x, y), returning a pixel tuple whose R component
              holds the edge marker; defaults to the global edge map pixels2
  """
  if edgeMap is None:
    edgeMap = pixels2

  # pixel just past the edge end, from which both diagonals are taken
  pastX = coords[0] + cAdd[0]
  pastY = coords[1] + cAdd[1]

  plusSide = edgeMap[pastX + cAddPerp[0], pastY + cAddPerp[1]][0] == edgeType
  minusSide = edgeMap[pastX - cAddPerp[0], pastY - cAddPerp[1]][0] == edgeType

  if plusSide and not minusSide:
    return 0.5

  if minusSide and not plusSide:
    return 1.5

  return 1.0

# First pass: detect edges. For every interior pixel, v is the colour
# difference to the right neighbour (a large v means a vertical edge) and h is
# the difference to the bottom neighbour (a large h means a horizontal edge).
# The result is stored in the R component of the edge map pixels2:
# 255 = horizontal edge, 127 = vertical edge, 0 = no edge.
for y in range(1,HEIGHT - 1): # detect edges
  for x in range(1,WIDTH - 1):
    p00 = pixels[x,y]       # this pixel
    p10 = pixels[x + 1,y]   # right neighbour
    p01 = pixels[x,y + 1]   # bottom neighbour

    v = pixelDist(p00,p10)
    h = pixelDist(p00,p01)

    if min(h,v) > EDGE_THRESHOLD: # if unclear, look at wider neighbourhood and decide between H and V
      pA = pixels[x + 1,y - 1]   # above the right neighbour
      pB = pixels[x - 1,y + 1]   # left of the bottom neighbour

      if pixelDist(p10,pA) > pixelDist(p01,pB):
        v = 0
      else:
        h = 0

    edge = 0

    if max(h,v) > EDGE_THRESHOLD:
      # Don't mark the pixel if the pixel above already holds a horizontal
      # edge or the pixel to the left a vertical one, to prevent touching
      # edges, which cause trouble.
      if pixels2[x,y - 1][0] != 255 and pixels2[x - 1,y][0] != 127:
        edge = 255 if h > v else 127

    # Write an explicit (R, G, B) tuple; a bare int would be interpreted by
    # Pillow as a packed colour value.
    pixels2[x,y] = (edge,0,0)

# Second pass: relabel isolated edge pixels, i.e. edge pixels none of whose
# 8 neighbours carries the same edge value. Such a pixel becomes 255 if its
# left or right neighbour is 255, else 127 if its top or bottom neighbour is
# 127, else the first non-zero diagonal neighbour (top-left, top-right,
# bottom-left, bottom-right), else 0. The edge map is updated in place during
# the scan, so later pixels already see the relabelled values.
for y in range(1, HEIGHT - 1): # extra pass for filtering some specific edge patterns
  for x in range(1, WIDTH - 1):
    centre = pixels2[x, y][0]

    if centre == 0:
      continue

    # R components of the 3x3 window in row-major order (index 4 = centre)
    window = [pixels2[nx, ny][0] for ny in (y - 1, y, y + 1) for nx in (x - 1, x, x + 1)]

    if centre in window[:4] or centre in window[5:]:
      continue  # not isolated, keep as is

    if 255 in (window[3], window[5]):
      replacement = 255
    elif 127 in (window[1], window[7]):
      replacement = 127
    else:
      corners = (window[0], window[2], window[6], window[8])
      replacement = next((value for value in corners if value != 0), 0)

    pixels2[x, y] = (replacement, 0, 0)

image2.save("out_edges.png")

# Third pass: smooth the edges and write the output image. Non-edge pixels are
# copied unchanged. For an edge pixel that has not been processed yet, the run
# of consecutive pixels with the same edge value is followed along the edge
# direction (cAdd). Each pixel c of the run is then blended with its neighbour
# across the edge (cPerp), using a blend position interpolated linearly
# starting at startPos towards endPos (both from edgeEndPos: 0.5, 1.0 or 1.5).
# Processed pixels are flagged in pixels2 by setting their B component to 255.
for y in range(HEIGHT): # smooth
  for x in range(WIDTH):
    p = pixels2[x,y]

    if p[2] == 255: # marked done?
      continue

    if p[0] == 0:   # is not an edge?
      pixels3[x,y] = pixels[x,y]
      continue

    start = (x,y)   # start coords of the edge
    end = (x,y)
    cAdd = (1,0) if p[0] == 255 else (0,1)       # direction of the edge
    cAddPerp = (0,1) if p[0] == 255 else (1,0)   # perpendicular dir.
    length = 0

    # the step past the start of the edge points against cAdd
    startPos = edgeEndPos(start,(- 1 * cAdd[0], -1 * cAdd[1]),cAddPerp,p[0])
    endPos = 1.0

    while True:     # find the length of the edge to compute the slope
      length += 1

      endPrev = end   # last pixel known to belong to the edge

      end = (end[0] + cAdd[0],end[1] + cAdd[1])

      # The run ends at the image border, at a pixel with a different edge
      # value, or at a pixel already processed by an earlier edge.
      if (end[0] == WIDTH or end[1] == HEIGHT) or (pixels2[end][0] != p[0] or pixels2[end][2] == 255):
        endPos = edgeEndPos(endPrev,cAdd,cAddPerp,p[0])
        break

    c = start

    pos = startPos    # position within the edge (<0,2>), perpendicular to it

    # NOTE(review): pos advances by posStep after each of the `length` pixels,
    # so the last pixel is blended at startPos + (length - 1) * posStep and
    # endPos itself is never reached.
    posStep = (endPos - startPos) / length

    for i in range(length):
      cPerp = (c[0] + cAddPerp[0],c[1] + cAddPerp[1])   # neighbour across the edge

      pix1 = pixels[c]
      pix2 = pixels[cPerp]

      # Edge pixel: unchanged for pos <= 1; for pos > 1 it takes a (pos - 1)
      # share of the neighbour's colour. int() truncates the result.
      t = 1.0 if pos <= 1.0 else (2.0 - pos)
      t2 = 1.0 - t

      pixels3[c] = (
        int(pix1[0] * t + pix2[0] * t2),
        int(pix1[1] * t + pix2[1] * t2),
        int(pix1[2] * t + pix2[2] * t2)
        )

      # Neighbour: unchanged for pos >= 1; for pos < 1 it takes a (1 - pos)
      # share of the edge pixel's colour.
      t = 1.0 if pos >= 1.0 else pos
      t2 = 1.0 - t

      pixels3[cPerp] = (
        int(pix1[0] * t2 + pix2[0] * t),
        int(pix1[1] * t2 + pix2[1] * t),
        int(pix1[2] * t2 + pix2[2] * t)
        )

      # NOTE(review): cPerp is not checked for the done flag before being
      # written, so a pixel already written by an earlier edge can be
      # overwritten here.
      pixels2[c] = (p[0],0,255)     # mark done
      pixels2[cPerp] = (p[0],0,255) # mark done

      c = (c[0] + cAdd[0],c[1] + cAdd[1])

      pos += posStep

image3.save("out.png")
