from __future__ import annotations
from os import PathLike
import json
from pathlib import Path
import numpy as np
import numpy.typing as npt


class Atom:
    """Class containing the information of a single atom.

    Attributes
    ----------
    element : str
        The atomic symbol of the element.
    xyz : NDArray
        The x-, y-, and z-coordinates of the atom, stored as a float64 copy
        of the values passed to the constructor.
    """

    def __init__(
        self,
        element: str,
        xyz: npt.NDArray,
    ):
        self.element = element
        # np.array (not asarray) copies, so later edits to the caller's
        # sequence do not silently move the atom.
        self.xyz = np.array(xyz, dtype=np.float64)

    def __repr__(self) -> str:
        """Return a two-line table: a column header and this atom's row."""
        # Single quotes inside the f-strings keep the module importable on
        # Python < 3.12, where reusing the outer quote character inside a
        # replacement field is a SyntaxError.
        header = f"{'Element':12}{'X':11}{'Y':11}{'Z':11}\n"
        x, y, z = self.xyz[0], self.xyz[1], self.xyz[2]
        row = f"{self.element:9}{x:11.6f}{y:11.6f}{z:11.6f}\n"
        return header + row


class Crystal:
    """Class storing the geometric parameters of a crystal structure.

    Attributes
    ----------
    lat_vec : npt.NDArray
        The primitive lattice vectors of the crystal (one vector per row) in
        units of alat.
    atoms : list[Atom]
        The atoms in the crystal. Each atom's ``xyz`` is treated as fractional
        coordinates with respect to the lattice vectors, which is what
        ``from_xsf`` stores.
    alat : float
        The lattice parameter.

    Notes
    -----
    The lattice vectors should be provided in units of alat here. ``from_xsf``
    takes alat to be the length of the first lattice vector, i.e. the square
    root of the sum of the squares of the first row of the lattice vector
    matrix.
    """

    def __init__(
        self,
        lat_vec: npt.NDArray,
        atoms: list[Atom],
        alat: float,
    ):
        # Accept any array-like so the arithmetic in the methods below works
        # even when a plain nested list is passed; an existing float64 array
        # is stored without copying.
        self.lat_vec = np.asarray(lat_vec, dtype=np.float64)
        self.atoms = atoms
        self.alat = alat

    def get_num_atoms(self) -> int:
        """Return the number of atoms in the crystal."""
        return len(self.atoms)

    def get_coords(self, fractional: bool = True) -> npt.NDArray:
        """Return the atomic positions as an ``(n_atoms, 3)`` array.

        Parameters
        ----------
        fractional : bool
            If True (default) return the stored fractional coordinates.
            If False return Cartesian coordinates, obtained by multiplying the
            fractional coordinates with the lattice vectors scaled by alat.
        """
        # reshape keeps the result two-dimensional even for an empty crystal.
        frac = np.array(
            [atom.xyz for atom in self.atoms], dtype=np.float64
        ).reshape(-1, 3)
        if fractional:
            return frac
        # Row-vector convention: cart = f1*a1 + f2*a2 + f3*a3 = frac @ cell.
        # Scaling by alat alone is only correct for an axis-aligned cubic cell.
        return frac @ (self.lat_vec * self.alat)

    def get_elements(self) -> list[str]:
        """Return the element label of every atom, in storage order."""
        return [atom.element for atom in self.atoms]

    def remove_atom(self, index: int):
        """Delete the atom at ``index`` without returning it."""
        del self.atoms[index]

    def pop_atom(self, index: int) -> Atom:
        """Remove and return the atom at ``index``."""
        return self.atoms.pop(index)

    def add_atom(self, atom: Atom, index: int | None = None):
        """Insert ``atom`` at ``index``, or append it when ``index`` is None."""
        if index is None:
            self.atoms.append(atom)
        else:
            self.atoms.insert(index, atom)

    @classmethod
    def from_xsf(cls, xsf_file: PathLike) -> Crystal:
        """Read in only the crystal structure information from an XSF file.

        The first seven lines are read as a fixed header: lines 3-5 hold the
        lattice vectors and the first field of line 7 is the number of atoms.
        The atom lines follow immediately, each as ``label x y z``.
        """
        with open(xsf_file, "r") as xsf:
            header = [next(xsf) for _ in range(7)]
            cell = np.array(
                [line.split() for line in header[2:5]], dtype=np.float64
            )
            num_atoms = int(header[6].split()[0])
            atom_rows = [next(xsf).split() for _ in range(num_atoms)]

        # The inverse of the unscaled cell converts the file's coordinates
        # into fractional coordinates.
        cell_inv = np.linalg.inv(cell)

        # alat is the length of the first lattice vector.
        alat = np.sqrt(np.sum(cell[0, :] ** 2))

        atoms = [
            Atom(
                element=row[0],
                xyz=np.dot(np.array([float(v) for v in row[1:4]]), cell_inv),
            )
            for row in atom_rows
        ]

        # cls (not Crystal) so subclasses get instances of their own type.
        return cls(cell / alat, atoms, alat)

    def __repr__(self) -> str:
        """Return a text table of the scaled lattice vectors and atom positions."""
        cell = self.lat_vec * self.alat
        cart = self.get_coords(fractional=False)

        # Single quotes inside the f-strings keep this valid before Python 3.12.
        parts = [f"{'Lattice':12}{'X':11}{'Y':11}{'Z':11}\n{'Vectors':11}\n"]
        for vec in cell[:3]:
            parts.append(f"{'':9}{vec[0]:11.6f}{vec[1]:11.6f}{vec[2]:11.6f}\n")
        parts.append("\n")

        parts.append(f"{'Element':12}{'X':11}{'Y':11}{'Z':11}\n\n")
        for atom, pos in zip(self.atoms, cart):
            parts.append(
                f"{atom.element:9}{pos[0]:11.6f}{pos[1]:11.6f}{pos[2]:11.6f}\n"
            )
        return "".join(parts)

    def __iter__(self):
        """Iterate over the atoms in storage order."""
        yield from self.atoms


# Element symbols ordered by atomic number: the symbol at position i has
# atomic number i + 1.
_ELEMENT_SYMBOLS = (
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr",
    "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th",
    "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
)

# Built once at import time instead of on every lookup.
_ATOMIC_NUMBERS = {
    symbol: number for number, symbol in enumerate(_ELEMENT_SYMBOLS, start=1)
}


def get_atomic_number(element_symbol: str) -> int:
    """Return the atomic number for the provided element label.

    Parameters
    ----------
    element_symbol : str
        An element symbol (case insensitive, e.g. ``"fe"``, ``"Fe"``, ``"FE"``)
        or an atomic number written as a decimal string (e.g. ``"26"``).
        Surrounding whitespace is ignored.

    Returns
    -------
    int
        The atomic number, between 1 and 118.

    Raises
    ------
    KeyError
        If the label is neither a known symbol nor a number in 1..118.
    """
    label = element_symbol.strip()

    # Numeric labels are passed through after a range check.
    if label.isdecimal():
        number = int(label)
        if 1 <= number <= len(_ELEMENT_SYMBOLS):
            return number
        raise KeyError(element_symbol)

    # Normalise capitalisation so the lookup is case insensitive.
    return _ATOMIC_NUMBERS[label.title()]


def read_xsf_chg_dens(xsf_file: PathLike, crystal: Crystal) -> tuple[npt.NDArray, list[int], list[float]]:
    """Read the 3D data grid that follows the structure section of an XSF file.

    Parameters
    ----------
    xsf_file : PathLike
        Path of the XSF file.
    crystal : Crystal
        Structure read from the same file; only its atom count is used, to
        know how many lines to skip before the grid section.

    Returns
    -------
    tuple[npt.NDArray, list[int], list[float]]
        The grid values with the last point along every axis removed,
        flattened in C order (third grid index varying fastest); the grid
        dimensions as written in the file; and the grid origin.

    Raises
    ------
    StopIteration
        If the file ends inside the header.
    ValueError
        If fewer values than the grid dimensions require can be parsed.
    """
    # 7 structure-header lines + one line per atom + 3 lines that open the
    # data-grid section.
    lines_before_grid = 7 + crystal.get_num_atoms() + 3

    with open(xsf_file, "r") as xsf:
        for _ in range(lines_before_grid):
            next(xsf)

        # Number of grid points along each axis, then the grid origin.
        num_grid = [int(tok) for tok in next(xsf).split()]
        origin = [float(tok) for tok in next(xsf).split()]

        # Three lines that are not parsed here (presumably the grid spanning
        # vectors of the XSF DATAGRID_3D layout).
        for _ in range(3):
            next(xsf)

        shape = tuple(num_grid[:3])
        n_points = shape[0] * shape[1] * shape[2]

        # Split on whitespace instead of reading fixed-width fields so the
        # parser does not depend on the column width, the number of values per
        # line, or the line-ending style. fromiter stops after n_points values
        # and raises ValueError if the data runs out first.
        tokens = (float(tok) for line in xsf for tok in line.split())
        values = np.fromiter(tokens, dtype=np.float64, count=n_points)

    # Values are stored with the first grid index varying fastest, which is
    # Fortran (column-major) order.
    charge_density = values.reshape(shape, order="F")

    # Convert to periodic array: drop the last point along each axis.
    periodic = charge_density[:-1, :-1, :-1]

    return periodic.flatten(), num_grid, origin


# Courtesy of https://github.com/matterhorn103/easyxtb/blob/main/src/easyxtb/format.py
def _flatten_arrays(data: dict) -> dict:
    """Replace every list of simple items (no dicts or lists) with its JSON string.

    Dicts are rebuilt with each value processed recursively. A list that
    contains at least one dict or list is rebuilt item by item. Any other
    value is returned unchanged.
    """
    if isinstance(data, dict):
        # Walk every entry of the mapping.
        return {key: _flatten_arrays(value) for key, value in data.items()}

    if not isinstance(data, list):
        # Scalars and other objects pass through untouched.
        return data

    contains_container = any(isinstance(item, (dict, list)) for item in data)
    if contains_container:
        # Descend into the nested items.
        return [_flatten_arrays(item) for item in data]

    # A list of plain values collapses into one flat string.
    return json.dumps(data)
    

def xsf_to_cjson(xsf: Path, output_name: str):
    """Convert an XSF file (structure plus 3D data grid) to a Chemical JSON file.

    Parameters
    ----------
    xsf : Path
        Path of the XSF file to read. A plain string path is accepted as well.
    output_name : str
        Name of the output file without extension; ``<output_name>.cjson`` is
        written in the current working directory.
    """
    # Coerce so that ``.name`` / ``.stem`` below work for str input too.
    xsf = Path(xsf)

    crystal = Crystal.from_xsf(xsf_file=xsf)
    chg_dens, num_grid, origin = read_xsf_chg_dens(xsf, crystal)

    elements = [get_atomic_number(element) for element in crystal.get_elements()]

    # Lattice vectors (one per row) scaled back from alat units.
    cell_dim = crystal.lat_vec * crystal.alat

    # from_xsf stores fractional coordinates, so Cartesian positions are
    # frac @ cell. Computing them here keeps the result correct for
    # non-cubic cells as well.
    fractional_coords = np.asarray(
        crystal.get_coords(fractional=True), dtype=np.float64
    ).reshape(-1, 3)
    coords = fractional_coords @ cell_dim

    # NOTE(review): only the diagonal of the cell matrix is used, which
    # assumes an orthogonal cell aligned with the Cartesian axes; confirm
    # before using this with other cells.
    spacing = [float(cell_dim[i][i] / (num_grid[i] - 1)) for i in range(3)]

    cjson = {
        "atoms": {
            "coords": {
                "3d": coords.flatten().tolist(),
                "3dFractional": fractional_coords.flatten().tolist(),
            },
            "elements": {
                "number": elements,
            },
        },
        "chemicalJson": 1,
        "cube": {
            # read_xsf_chg_dens drops the last point along every axis, so the
            # stored grid has one point fewer per axis than the file declares.
            "dimensions": [n - 1 for n in num_grid],
            "name": xsf.name,
            "origin": origin,
            "scalars": chg_dens.tolist(),
            "spacing": spacing,
            "type": "fromFile",
        },
        # stem removes the real suffix, whatever its length.
        "name": xsf.stem,
        "properties": {
            "fileName": "",
            "totalCharge": 0,
            "totalSpinMultiplicity": 1,
        },
        "unitCell": {
            "cellVectors": cell_dim.flatten().tolist(),
            "hallNumber": 0,
            "spaceGroup": "",
        },
    }

    # Lists of plain values become strings so that json.dumps keeps each of
    # them on a single line; the replacements below strip the quoting again.
    cjson = _flatten_arrays(cjson)
    cjson_string = (
        json.dumps(cjson, indent=4).replace('"[', "[").replace(']"', "]")
    )
    # NOTE: these textual replacements also apply to any other string value
    # that happens to contain the same character sequences.
    cjson_string = cjson_string.replace(r"\"", '"')

    output_path = Path.cwd() / f"{output_name}.cjson"
    with open(output_path, "w", encoding="utf-8") as output:
        output.write(cjson_string)