"""Výpočet vzdálenosti start-konec z DNA sekvence."""

import json
import numpy as np
import pandas as pd
from scipy.spatial.transform import Rotation

# Načtení dat
with open("instructions-solution.json", "r") as f:
    CODONS = json.load(f)

with open("DNA.in", "r") as f:
    seq = f.read().strip().upper()
    seq = "".join(ch for ch in seq if ch in "ATGC")

# Dekódování DNA
out = ""
pos = 0
stack = []

while 0 <= pos < len(seq):
    codon = seq[pos:pos+3]
    if len(codon) < 3:
        break
    
    instr = CODONS.get(codon)
    if instr is None:
        pos += 3
        continue
    
    if instr.startswith("ROT"):
        rot_num = int(instr[-1])
        stack.append((instr, rot_num))
        if rot_num == 4:
            values = list(CODONS.values())
            keys = list(CODONS.keys())
            CODONS = dict(zip(keys, [values[-1]] + values[:-1]))
        else:
            layer = rot_num - 1
            CODONS = {"".join(list(c)[:layer] + [{"T":"C","C":"G","G":"A","A":"T"}[c[layer]]] + list(c)[layer+1:]): v for c, v in CODONS.items()}
    elif instr == "SKIP":
        skip_count = 1
        next_pos = pos + 3
        while seq[next_pos:next_pos+3] and CODONS.get(seq[next_pos:next_pos+3]) == "SKIP":
            skip_count += 1
            next_pos += 3
        # Posun: za všechny SKIP kodóny (next_pos) + přeskočené kodóny (skip_count * 3)
        pos = next_pos + skip_count * 3
        stack.append(("SKIP", skip_count))
        continue
    elif instr == "STOP":
        break
    elif instr == "UNDO":
        if stack:
            last = stack.pop()
            if last[0].startswith("ROT"):
                if last[1] == 4:
                    values = list(CODONS.values())
                    keys = list(CODONS.keys())
                    CODONS = dict(zip(keys, values[1:] + [values[0]]))
                else:
                    layer = last[1] - 1
                    CODONS = {"".join(list(c)[:layer] + [{"T":"A","A":"G","G":"C","C":"T"}[c[layer]]] + list(c)[layer+1:]): v for c, v in CODONS.items()}
            elif last[0] == "LETTER":
                out = out[:-1]  # Odeber poslední písmeno
    else:
        out += instr
        stack.append(("LETTER", instr))
    pos += 3

print(f"Aminokyseliny: {out}")
print(f"Délka: {len(out)}")

# Pomocná funkce pro normalizaci vektoru
def normalize(v):
    return v / np.linalg.norm(v)

# Vytvoření 3D struktury
table = {str(row["pair"]): (float(row["length"]), float(row["angle_deg"]), float(row["dihedral_deg"])) 
         for _, row in pd.read_csv("table.in", keep_default_na=False, na_values=[]).iterrows()}

# Počáteční tři body (a je fiktivní bod pro definici první roviny)
a = np.array([0.0, 1.0, 0.0])
b = np.array([0.0, 0.0, 0.0])
length, angle, dihedral = table[out[0:2]]
c = np.array([length, 0.0, 0.0])

for i in range(2, len(out)):
    length, angle, dihedral = table[out[i-1:i+1]]
    
    v1 = normalize(b - a)  # směr předchozího spoje
    v2 = normalize(c - b)  # směr aktuálního spoje
    
    # Normála roviny definovaná skutečnými předchozími body
    normal = normalize(np.cross(v1, v2))
    
    # Rotace normály o dihedral úhel kolem v2
    tiltedNormal = Rotation.from_rotvec(v2 * dihedral, degrees=True).apply(normal)
    
    # Rotace -v2 o angle kolem rotované normály
    v = Rotation.from_rotvec(-tiltedNormal * angle, degrees=True).apply(-v2)
    
    # Nový bod
    d = c + length * v
    a, b, c = b, c, d

# Výsledek - vzdálenost od počátku (bod b byl na [0,0,0])
distance = np.linalg.norm(c)
print(f"Vzdálenost: {distance:.3f}")
