# Author: Hongzhu Zhao

import pymel.core as pm
import numpy as np
import math
import maya.api.OpenMaya as om2


def getMObj(name):
    """Look up a Maya node by name and return its OpenMaya 2.0 MObject.

    ``name`` is converted with ``str()`` first, so PyNodes are accepted too.
    """
    selection = om2.MSelectionList()
    selection.add(str(name))
    first_item = 0
    return selection.getDependNode(first_item)


def getPlug(mObj, plugName):
    """Return the MPlug called ``plugName`` on the dependency node ``mObj``.

    The plug is looked up without requesting a networked plug.
    """
    want_networked_plug = False
    dep_node_fn = om2.MFnDependencyNode(mObj)
    plug = dep_node_fn.findPlug(plugName, want_networked_plug)
    return plug


class InverseKinematicsHM:
    """Jacobian-based IK solver that rotates pymel joint chains toward targets.

    Use ``get_instance()`` to obtain the shared instance; constructing a second
    instance directly raises. Every end effector is driven by a list of joints
    (``e_j_dict``); each iteration computes per-joint rotation angles using the
    Jacobian "inverse" strategy chosen by ``set_up_inverse``.
    """
    __instance = None

    @staticmethod
    def get_instance():
        """Return the singleton, creating it on first call."""
        if InverseKinematicsHM.__instance is None:
            InverseKinematicsHM()
        return InverseKinematicsHM.__instance

    def __init__(self):
        # Fail before building any state if the singleton already exists.
        if InverseKinematicsHM.__instance is not None:
            raise Exception("This class is a singleton!")
        self.e_j_dict = {}  # end-effector node -> list of joint nodes driving it
        self.script_job = []  # scriptJob ids registered in interactive mode
        self.om_callback = []  # OpenMaya callback ids registered in interactive mode
        self.max_iteration = 200  # last frame usable by the non-interactive solve
        self.max_interactive_iteration = 5  # iterations per call with refresh=False
        self.max_after_interaction_iteration = 10  # iterations per call with refresh=True
        self.end_effector_nodes_i = None  # effectors used by interactive mode
        self.target_nodes_i = None  # targets used by interactive mode
        self.method_i = None
        self.optimize_i = False
        self.inverse_method = None  # callable: Jacobian -> its transpose/(pseudo-)inverse
        self.time_step = 0.01
        self.base_time_step = 3
        self.clamp_length = 0.2  # max step length toward the target when optimize=True
        self.check_diag_when_inverse = False  # drop huge reciprocal singular values
        self.damp_constant = 0.1  # added to the diagonal of J J^T in the damped inverse
        self.close_threshold = 0.01  # effector-target distance regarded as converged
        InverseKinematicsHM.__instance = self

    def normalize(self, vec):
        """Return ``vec`` scaled to unit length.

        A zero-length vector (e.g. the cross product of collinear vectors) is
        returned as a float zero vector instead of dividing by zero, which
        would otherwise propagate NaNs into the Jacobian.
        """
        length = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1] + vec[2] * vec[2])
        if length == 0.0:
            return np.zeros_like(vec, dtype=float)
        return vec / length

    def matrix_inverse(self, matrix):
        """Return the inverse of ``matrix`` computed via SVD as ``V diag(1/s) U^T``.

        When ``check_diag_when_inverse`` is set, reciprocals of singular values
        whose magnitude would reach 1e3 (including zero singular values) are
        replaced by 0, so near-singular matrices do not produce huge steps.
        """
        u, s, vh = np.linalg.svd(matrix)
        if self.check_diag_when_inverse:
            # Guard x == 0 explicitly instead of evaluating 0 ** -1 (inf + warning).
            s_inv = [(1.0 / x if x != 0 and math.fabs(1.0 / x) < 1e3 else 0.0) for x in s]
        else:
            s_inv = s ** -1
        return np.dot(vh.T, np.dot(np.diag(s_inv), u.T))

    def calc_delta_theta(self, joints, target, effector, optimize=False):
        """Compute one step of joint angles moving ``effector`` toward ``target``.

        joints: sequence of joint world positions.
        target, effector: world positions.
        optimize: if True, clamp the target direction to ``clamp_length``.

        Returns ``(d_theta, axes)``: a ``(len(joints), 1)`` array of angles and
        the per-joint world rotation axes used to build the Jacobian.
        """
        axes = self.calc_axes(joints, target, effector)
        jacobian = self.calc_jacobian(joints, effector, axes=axes)
        t_jacobian = self.inverse_method(jacobian)
        direction = target - effector
        if optimize and np.linalg.norm(direction) > self.clamp_length:
            direction = self.normalize(direction) * self.clamp_length

        d_theta = np.dot(t_jacobian, np.array([direction]).T)
        d_theta *= self.time_step
        return d_theta, axes

    def set_up_inverse(self, method):
        """Select the Jacobian inverse strategy and its matching time step.

        method: 'Transpose' (J^T), 'Pseudo' (J^T (J J^T)^-1) or
            'Damped' (J^T (J J^T + damp_constant * I)^-1).

        Raises ValueError for any other method name.
        """
        switcher = {
            'Transpose': (lambda mat: mat.T, 0.02),
            'Pseudo': (self.calc_inverse_pseudo, 0.1),
            'Damped': (self.calc_inverse_damped, 0.2),
        }
        if method not in switcher:
            raise ValueError("Unknown inverse method %r; expected one of: %s"
                             % (method, ', '.join(sorted(switcher))))
        self.inverse_method, step_scale = switcher[method]
        self.time_step = step_scale * self.base_time_step
        # Reset explicitly: the singleton must not keep 'Pseudo' settings
        # when a later call selects another method.
        self.check_diag_when_inverse = method == 'Pseudo'

    def calc_inverse_pseudo(self, jacobian):
        """Return the right pseudo-inverse ``J^T (J J^T)^-1``."""
        pre_inv = np.dot(jacobian, jacobian.T)
        inv = self.matrix_inverse(pre_inv)
        return np.dot(jacobian.T, inv)

    def calc_inverse_damped(self, jacobian):
        """Return the damped inverse ``J^T (J J^T + damp_constant * I)^-1``."""
        pre_inv = np.dot(jacobian, jacobian.T)
        pre_inv += self.damp_constant * np.eye(len(pre_inv))
        inv = self.matrix_inverse(pre_inv)
        return np.dot(jacobian.T, inv)

    def calc_axes(self, joints, target, effector):
        """Per joint, the unit normal of the plane through joint, effector and target.

        The axis is a zero vector when the three points are collinear.
        """
        return np.array([self.normalize(np.cross(effector - x, target - x)) for x in joints])

    def calc_jacobian(self, joints, effector, axes=None):
        """Build the 3 x len(joints) positional Jacobian.

        Column i is ``axis_i x (effector - joint_i)``; the axis defaults to +Z
        when ``axes`` is None.
        """
        joints_count = len(joints)
        J = np.zeros((3, joints_count))
        for i in range(joints_count):
            now_axis = np.array([0., 0., 1.]) if axes is None else axes[i]
            J[:, i] = np.cross(now_axis, (effector - joints[i]))
        return J

    def calc_topology(self, joints, end_effectors):
        """Fill ``e_j_dict`` with each effector's ancestors that are in ``joints``.

        Ancestors are listed from the nearest parent upward.
        """
        for effector in end_effectors:
            chain = []
            parent = effector.getParent()
            while parent is not None:
                if parent in joints:
                    chain.append(parent)
                parent = parent.getParent()
            self.e_j_dict[effector] = chain

    def is_close_enough(self, end_effector_nodes, target_nodes):
        """True if every effector is within ``close_threshold`` of its target."""
        end_effectors = np.array([x.getTranslation(space='world') for x in end_effector_nodes])
        targets = np.array([x.getTranslation(space='world') for x in target_nodes])
        return all(np.linalg.norm(x) < self.close_threshold for x in (end_effectors - targets))

    def iterate_once(self, end_effector_nodes, target_nodes, optimize, target=None):
        """Run one solver step for each (effector, target) pair.

        If ``target`` is given, only the pair whose target node equals it is
        updated. Each joint is rotated by its angle about its world axis.
        """
        for effector, target_node in zip(end_effector_nodes, target_nodes):
            if target is not None and target_node != target:
                continue
            joints = self.e_j_dict[effector]
            d_theta, axes = self.calc_delta_theta([x.getTranslation(space='world') for x in joints],
                                                  target_node.getTranslation(space='world'),
                                                  effector.getTranslation(space='world'),
                                                  optimize=optimize)
            for now_joint, world_axis, angle in zip(joints, axes, d_theta[:, 0]):
                if not np.any(world_axis):
                    # Collinear joint/effector/target: the axis is undefined.
                    continue
                # Re-express the world-space axis relative to the joint's
                # current world rotation, then compose with its local rotation.
                w_rot = now_joint.getRotation(space='world')
                axis = pm.datatypes.Vector(world_axis) * w_rot.asMatrix().inverse()
                delta_rot = pm.datatypes.Quaternion(angle, axis)
                now_joint.setRotation(delta_rot * now_joint.getRotation(quaternion=True))

    def callback_handler(self, msg, node, data):
        """OpenMaya dirty-plug callback registered by ``inverse_kinematics``.

        Maya passes (node, plug, clientData) positionally; only the client
        data, a ``(refresh, target)`` tuple given at registration, is used.
        """
        refresh, target = data
        self.interactive_iterate(refresh=refresh, target=target)

    def interactive_iterate(self, refresh=True, target=None):
        """Iterate the interactive setup until converged or out of budget.

        refresh: if True, use ``max_after_interaction_iteration`` and refresh
            the viewport after every step; otherwise use
            ``max_interactive_iteration`` without refreshing.
        target: restrict the update to this target node (None = all).

        Does nothing if interactive mode has not been set up.
        """
        if self.end_effector_nodes_i is None or self.target_nodes_i is None:
            return
        if refresh:
            limit = self.max_after_interaction_iteration
        else:
            limit = self.max_interactive_iteration
        for _ in range(limit):
            if self.is_close_enough(self.end_effector_nodes_i, self.target_nodes_i):
                break
            self.iterate_once(self.end_effector_nodes_i, self.target_nodes_i,
                              optimize=self.optimize_i, target=target)
            if refresh:
                pm.refresh()

    def reset_everything(self):
        """Remove every scriptJob and OpenMaya callback registered earlier."""
        for job_id in self.script_job:
            # The job may already be gone (e.g. after a new scene); killing a
            # missing job raises, so check first.
            if pm.scriptJob(exists=job_id):
                pm.scriptJob(kill=job_id)
        self.script_job = []
        for callback_id in self.om_callback:
            try:
                om2.MMessage.removeCallback(callback_id)
            except RuntimeError:
                # Best effort: the id may no longer be valid.
                pass
        self.om_callback = []

    def _key_rotations(self, joint_nodes):
        """Key rotateX/Y/Z on every joint at the current time."""
        for now_joint in joint_nodes:
            for attr in ('rotateX', 'rotateY', 'rotateZ'):
                pm.setKeyframe(now_joint, at=attr)

    def inverse_kinematics(self, data, method='Transpose', optimize=False,
                           interactive=False):
        """Set up and run the solver.

        data: iterable of ``(joints, end_effector, target)`` triples, where
            ``joints`` is the list of joint nodes driven for that effector.
        method: 'Transpose', 'Pseudo' or 'Damped' (see ``set_up_inverse``).
        optimize: clamp each step's target direction to ``clamp_length``.
        interactive: if True, register callbacks on the targets so the
            chains re-solve when a target changes. Otherwise solve now,
            keying joint rotations on one frame per iteration starting from
            the rest pose on frame 1.
        """
        self.reset_everything()
        self.set_up_inverse(method)
        end_effector_nodes = []
        target_nodes = []
        joint_nodes_all = []
        for joints, effector, target in data:
            self.e_j_dict[effector] = joints
            end_effector_nodes.append(effector)
            target_nodes.append(target)
            joint_nodes_all += joints

        if interactive:
            self.end_effector_nodes_i = end_effector_nodes
            self.target_nodes_i = target_nodes
            self.method_i = method
            self.optimize_i = optimize
            for target in target_nodes:
                self.om_callback.append(om2.MNodeMessage.addNodeDirtyPlugCallback(
                    getMObj(target.name()), self.callback_handler, (False, target)
                ))
                # Pass a bound method, not a command string: a string is run in
                # Maya's __main__ namespace, where this class is not defined
                # when the module is imported.
                self.script_job.append(pm.scriptJob(ac=[target.longName() + '.translate',
                                                        self.interactive_iterate]))
            return

        pm.playbackOptions(maxTime=self.max_iteration)
        pm.currentTime(1)
        # Key the starting pose so playback begins from it rather than from
        # the first solved step.
        self._key_rotations(joint_nodes_all)
        now_frame = 2
        while not self.is_close_enough(end_effector_nodes, target_nodes) and now_frame <= self.max_iteration:
            pm.currentTime(now_frame, update=False)
            self.iterate_once(end_effector_nodes, target_nodes, optimize=optimize)
            self._key_rotations(joint_nodes_all)
            now_frame += 1

        pm.playbackOptions(animationEndTime=now_frame)

    def create_object(self, pos, function_name='polySphere'):
        """Create a small primitive via ``pm.<function_name>(radius=0.1)`` and move it to ``pos``."""
        func = getattr(pm, function_name)
        ret = func(radius=0.1)[0]
        # Move the created node explicitly instead of relying on the selection.
        pm.move(ret, pos)
        return ret

    def test(self):
        """Build a three-joint demo chain in a new scene and solve it interactively."""
        pm.newFile(f=1)
        j = np.array([[1, 0, 0], [3, 2, 0.0], [4, 4, 0.0]])
        e = np.array([[5, 1, 0.0]])
        target = np.array([[-4, 0.5, 0.0]])
        joints = []
        target_nodes = [pm.polyCube(h=0.1, w=0.1, d=0.1)[0]]
        pm.move(target_nodes[0], target[0])
        for i in range(len(j)):
            now_joint = self.create_object(j[i])
            if i > 0:
                now_joint.setParent(joints[i - 1])
            joints.append(now_joint)

        end_effectors = [self.create_object(e[0])]
        end_effectors[0].setParent(joints[-1])

        self.inverse_kinematics([(joints, end_effectors[0], target_nodes[0])],
                                method='Transpose', interactive=True, optimize=True)


# InverseKinematicsHM.get_instance().test()
