from http.client import FOUND
import numpy as np

import matplotlib.pyplot as plt, numpy as np

import h5py
import random

import matplotlib.pyplot as plt
from os import listdir
from os.path import isfile, join
# Running sample indices; both start at 1.
biker_counter = non_biker_counter = 1

from scipy.spatial.transform import Rotation as R


# ---------------------------------------------------------------------------
# Synthetic biker point-cloud generation.
#
# For each frame, three recorded views (back, front, left) are loaded, rotated
# into a common orientation and shifted; the right side is the left side
# mirrored in x.  The four views are stacked into one cloud, from which
# randomised fixed-size samples are drawn and plotted.
# ---------------------------------------------------------------------------

# Frame selection.  The original note next to the frame index read "14 16"
# (presumably other usable frames -- TODO confirm).
FIRST_FRAME = 13
NUM_FRAMES = 1

# 'biker_<view>_7_auto_reco/frame_<index>.h5'
VIEW_FILE_TEMPLATE = 'biker_{}_7_auto_reco/frame_{}.h5'

LEFT_POINT_COUNT = 4096      # rows kept from the reworked left view
NUM_SAMPLES = 10000          # samples generated per frame
SAMPLE_INDEX_OFFSET = 13000  # added to the loop index to form biker_counter
N = 3000                     # rows in every generated sample
N_f = round(N * 0.85)        # rows taken from the selected view's own points
N_b = N - N_f                # rows taken from the rest of the combined cloud
MIN_SAMPLE_SIZE = 256        # lower bound of the random thinning step


def _load_view(path, yaw_degrees, offset):
    """Load one view from *path* and return its rotated, shifted xyz rows.

    Reads the ``data`` and ``label`` datasets, takes the entry of ``data`` at
    index ``np.argmax(label)``, keeps its first three columns, rotates them by
    *yaw_degrees* about the z axis and adds *offset*.  The file is closed
    before any processing happens.
    """
    with h5py.File(path, 'r') as f:
        data = f['data'][:]
        label = f['label'][:]
    # NOTE(review): assumes data[k] is a (points, >=3) array whose first three
    # columns are x, y, z -- confirm against the recorded files.
    pts = data[np.argmax(label)][:, :3]
    rotation = R.from_euler('xyz', (0, 0, yaw_degrees), degrees=True)
    return rotation.apply(pts) + np.asarray(offset, dtype=float)


def _sample_rows(points, count):
    """Return exactly *count* rows of *points*.

    With enough rows available, *count* distinct rows are drawn.  Otherwise
    every row is kept and the shortfall is filled with randomly repeated rows.

    Raises:
        ValueError: if rows are requested from an empty array.
    """
    available = points.shape[0]
    if available == 0 and count > 0:
        raise ValueError('cannot draw {} rows from an empty point set'.format(count))
    if available >= count:
        picks = np.random.choice(available, size=count, replace=False)
        return points[picks]
    # Keep the real rows and only top up the difference.  Returning just the
    # top-up rows would leave the result shorter than requested.
    extra = np.random.choice(available, size=count - available, replace=True)
    return np.concatenate([points, points[extra]], axis=0)


def _rework_left_view(pts):
    """Crop the left view in x and re-assemble it from shifted sub-regions.

    Region names (head, wheel) follow the original variable names.  Rows that
    satisfy several region masks appear more than once in the result.
    """
    x = pts[:, 0]
    pts = pts[(x > -0.35) & (x < 0.4)]
    x, y, z = pts[:, 0], pts[:, 1], pts[:, 2]

    # Boolean indexing copies, so the shifts below never touch `pts` itself.
    head = pts[(x < -0.05) & (x > -0.4) & (z > 0.7)] + np.array([0.125, 0, 0])
    wheel = pts[(y > 0.275) & (z < -0.25)] + np.array([0.25, 0, 0])
    forward_low = pts[(x > -0.05) & (z < 0.7)]
    inner = pts[(y < 0.275) & (z > -0.25)]
    # A former fourth region (x < -0.4) was dropped: it can never contain
    # rows after the x > -0.35 crop above.

    merged = np.concatenate([forward_low, inner, head, wheel], axis=0)
    mx, mz = merged[:, 0], merged[:, 2]
    return np.concatenate(
        [merged[mx > -0.05], merged[(mx < -0.05) & (mz < 0.7)]], axis=0
    )


def _split_for_view(angle, back_pts, front_pts, left_pts, right_pts, complete_biker):
    """Return ``(view_rows, other_rows)`` for the view selected by *angle*.

    ``view_rows`` come from the selected view on one side of a threshold
    plane; ``other_rows`` are the rows of the combined cloud on the other side.

    Raises:
        ValueError: if *angle* is not 0, 1, 2 or 3.
    """
    if angle == 0:    # right
        return right_pts[right_pts[:, 0] < 0], complete_biker[complete_biker[:, 0] > 0]
    if angle == 1:    # left
        return left_pts[left_pts[:, 0] > 0], complete_biker[complete_biker[:, 0] < 0]
    if angle == 2:    # front
        return front_pts[front_pts[:, 1] < 0.1], complete_biker[complete_biker[:, 1] > 0.1]
    if angle == 3:    # back
        return back_pts[back_pts[:, 1] > 0.1], complete_biker[complete_biker[:, 1] < 0.1]
    raise ValueError('unknown view angle: {}'.format(angle))


for i in range(FIRST_FRAME, FIRST_FRAME + NUM_FRAMES):
    back_pts = _load_view(VIEW_FILE_TEMPLATE.format('back', i), 180, (0, 0.25, 0))
    front_pts = _load_view(VIEW_FILE_TEMPLATE.format('front', i), 0, (0, -0.1, 0))
    left_pts = _load_view(VIEW_FILE_TEMPLATE.format('left', i), 90, (0, 0, 0))

    left_pts = _sample_rows(_rework_left_view(left_pts), LEFT_POINT_COUNT)
    right_pts = left_pts * np.array([-1, 1, 1])  # mirror the left side in x

    # Row-wise stacking, so the views may have different point counts.
    complete_biker = np.concatenate(
        [back_pts, front_pts, left_pts, right_pts], axis=0
    )

    for sample_index in range(NUM_SAMPLES):
        angle = random.randint(0, 3)
        complete_biker_f, complete_biker_b = _split_for_view(
            angle, back_pts, front_pts, left_pts, right_pts, complete_biker
        )

        biker_counter = sample_index + SAMPLE_INDEX_OFFSET
        print(biker_counter)

        # Always N rows here, so the replace=False draw below cannot run short.
        data = np.concatenate(
            [_sample_rows(complete_biker_f, N_f), _sample_rows(complete_biker_b, N_b)],
            axis=0,
        )

        # Thin to a random size, then resample with replacement back to N rows.
        sample_size = random.randint(MIN_SAMPLE_SIZE, N)
        data = data[np.random.choice(data.shape[0], size=sample_size, replace=False)]
        data = data[np.random.choice(data.shape[0], size=N, replace=True)]

        # NOTE(review): writing the sample to disk is currently disabled.
        # np.savetxt("D:/Pointnet2/data/modelnet40_normal_resampled/biker/biker_newleft"+str(biker_counter)+".txt", data)

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.scatter(data[:, 0], data[:, 1], data[:, 2], s=1)
        plt.show()
        # Release the figure so thousands of iterations do not pile up figures.
        plt.close(fig)