package main

import (
	"bufio"
	"context"
	"flag"
	"fmt"
	"log"
	"mp1/ml"
	"mp1/peer"
	pb "mp1/protos"
	"mp1/sdfs"
	"os"
	"strconv"
	"strings"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

// Join adds this node to the distributed membership list.
//
// When args[0] == "i" the node acts as the introducer: it starts the
// introducer server (peer.StartIntroducerServer) and registers itself via
// peer.AddMemberToList. Otherwise it tries each address in peer.Addrs
// (skipping its own ":50053" endpoint) until one accepts a connection, asks
// that introducer for the current peers, and installs the reply as
// peer.MemberList. If no introducer could be reached, the node registers
// itself instead.
//
// In every case the UDP server and client goroutines are started before
// returning.
func Join(args []string) {
	fmt.Println("Joining the distributed cluster! ", peer.MyHostName)
	// NOTE(review): Nanosecond() is only the sub-second offset
	// (0..999999999), not a full timestamp. Confirm what the membership code
	// expects before switching to UnixNano().
	timestamp := time.Now().Nanosecond()
	if len(args) >= 1 && args[0] == "i" {
		go peer.StartIntroducerServer()
		peer.AddMemberToList(peer.MyHostName,
			timestamp,
			2,
			1)
	} else {
		// The previous `i == len(peer.Addrs)` check could never be true for a
		// range index, so the fallback below never ran. Track success explicitly.
		joined := false
		for _, addr := range peer.Addrs {
			// Never try to join through our own introducer endpoint.
			if addr == peer.MyHostName+":50053" {
				continue
			}
			if requestIntroFrom(addr, timestamp) {
				joined = true
				break
			}
		}
		if !joined {
			log.Printf("did not connect, setting self as leader")
			peer.AddMemberToList(peer.MyHostName,
				timestamp,
				2,
				1)
		}
	}
	go peer.StartPeerUDPServer()
	go peer.StartPeerUDPClient()
	//go sdfs.StartSDFSServer()
	//go sdfs.FixReplicas()
}

// requestIntroFrom dials the introducer at addr and, once connected, asks it
// for the current peers. On a successful reply it installs them as
// peer.MemberList.
//
// It reports whether the dial succeeded. An RPC failure after a successful
// dial is only logged and still returns true, so Join stops at the first
// reachable introducer exactly as before. Both contexts are cancelled and the
// connection is closed before returning.
func requestIntroFrom(addr string, timestamp int) bool {
	fmt.Println("trying to join the distributed cluster! ", addr)
	dialCtx, cancelDial := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancelDial()
	conn, err := grpc.DialContext(dialCtx, addr,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock())
	if err != nil {
		return false
	}
	defer conn.Close()
	fmt.Println("Connected to introducer", addr)

	c := pb.NewIntroInterfaceClient(conn)
	rpcCtx, cancelRPC := context.WithTimeout(context.Background(), time.Second)
	defer cancelRPC()
	r, err := c.RequestIntro(rpcCtx, &pb.IntroRequest{
		HostName:    peer.MyHostName,
		Incarnation: 1,
		Timestamp:   int64(timestamp),
	})
	if err != nil {
		log.Printf("could not send grpc req response!: %v\n", err)
		return true
	}
	// Build the list outside the lock, then publish it under peer.Lock.
	// main guards every other MemberList write with this lock.
	members := peer.ToMemberList(r.GetPeers())
	peer.Lock.Lock()
	peer.MemberList = members
	peer.Lock.Unlock()
	log.Printf("%v %v\n", r.GetPeers(), members)
	return true
}

// main boots this node and runs the interactive command shell.
//
// It reads the node's host name from /mp1/hostname and starts
// ml.StartIdunno(0) and ml.StartMLServer in background goroutines. It then
// reads commands from stdin line by line and dispatches each one through
// runShellCommand. When stdin is exhausted, the shell stops and main blocks
// forever so the background goroutines keep running.
func main() {
	flag.Parse()
	hostname, err := os.ReadFile("/mp1/hostname")
	if err != nil {
		panic(err)
	}
	log.Println(string(hostname))
	// TrimSpace instead of slicing off the last byte. It tolerates a missing
	// trailing newline or CRLF and does not panic on an empty file.
	peer.MyHostName = strings.TrimSpace(string(hostname))
	//fileName := "dist_log.log"
	//f, err := os.OpenFile(fileName,
	//	os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	//if err != nil {
	//	log.Println(err)
	//}
	//defer f.Close()
	//log.SetOutput(f)

	// Raw CLI arguments, forwarded to Join ("i" makes this node the introducer).
	args := os.Args[1:]
	go ml.StartIdunno(0)
	go ml.StartMLServer()

	// One scanner for the whole session. Recreating a bufio reader each
	// iteration would discard any input it had already buffered.
	scanner := bufio.NewScanner(os.Stdin)
	for {
		fmt.Print("> ")
		if !scanner.Scan() {
			break
		}
		line := scanner.Text()
		log.Println(line)
		runShellCommand(strings.Fields(line), args)
	}
	if err := scanner.Err(); err != nil {
		log.Println(err)
	}
	// Stdin is closed. Block without spinning a CPU core so the servers
	// started above keep running.
	select {}
}

// runShellCommand executes one shell command. fields is the whitespace-split
// input line with the command name first. joinArgs is forwarded to Join.
// Commands are matched exactly on the first token, so "get_results" can no
// longer be captured by the "get" command. Unknown commands are ignored.
func runShellCommand(fields []string, joinArgs []string) {
	if len(fields) == 0 {
		return
	}
	switch fields[0] {
	case "list_mem":
		peer.List_mem()
	case "list_self":
		peer.List_self()
	case "join":
		peer.Lock.Lock()
		peer.MemberList = peer.MemberList[:0]
		peer.Lock.Unlock()
		go Join(joinArgs)
	case "leave":
		peer.Lock.Lock()
		peer.MemberList = peer.MemberList[:0]
		peer.Lock.Unlock()
		go peer.Leave()
	case "put":
		if !requireShellArgs(fields, 3, "put <localfilename> <sdfsfilename>") {
			return
		}
		go sdfs.PutClient(fields[1], fields[2])
	case "get-versions":
		if !requireShellArgs(fields, 4, "get-versions <sdfsfilename> <num-versions> <localfilename>") {
			return
		}
		numVersions, ok := parseShellInt(fields[2], "num-versions")
		if !ok {
			return
		}
		sdfsFileName, localFileName := fields[1], fields[3]
		fmt.Printf("%s NUMVERSION %d %s \n", sdfsFileName, numVersions, localFileName)
		go sdfs.GetClient(localFileName, sdfsFileName, int32(numVersions))
	case "get":
		if !requireShellArgs(fields, 3, "get <sdfsfilename> <localfilename>") {
			return
		}
		go sdfs.GetClient(fields[2], fields[1], 1)
	case "delete":
		if !requireShellArgs(fields, 2, "delete <sdfsfilename>") {
			return
		}
		go sdfs.DelClient(fields[1])
	case "ls":
		if !requireShellArgs(fields, 2, "ls <sdfsfilename>") {
			return
		}
		go sdfs.LsClient(fields[1])
	case "store":
		go sdfs.StoreClient()
	case "runjob":
		if !requireShellArgs(fields, 4, "runjob <job_file> <model_name> <batch_size>") {
			return
		}
		batchSize, ok := parseShellInt(fields[3], "batch_size")
		if !ok {
			return
		}
		go ml.RequestJob(fields[1], fields[2], batchSize)
	case "query_rate":
		if !requireShellArgs(fields, 2, "query_rate <job-id>") {
			return
		}
		go ml.QueryRate(fields[1])
	case "query_processing":
		if !requireShellArgs(fields, 2, "query_processing <job-id>") {
			return
		}
		go ml.QueryProcessing(fields[1])
	case "set_batch_size":
		if !requireShellArgs(fields, 3, "set_batch_size <batch-size> <job-id>") {
			return
		}
		batchSize, ok := parseShellInt(fields[1], "batch-size")
		if !ok {
			return
		}
		jobID, ok := parseShellInt(fields[2], "job-id")
		if !ok {
			return
		}
		go ml.SetBatchSize(batchSize, jobID)
	case "get_results":
		if !requireShellArgs(fields, 2, "get_results <job-id>") {
			return
		}
		go ml.GetResults(fields[1])
	case "show_vms":
		if !requireShellArgs(fields, 2, "show_vms <job-id>") {
			return
		}
		jobID, ok := parseShellInt(fields[1], "job-id")
		if !ok {
			return
		}
		go ml.ShowVMS(jobID)
	}
}

// requireShellArgs reports whether fields has at least want tokens, counting
// the command name. If it does not, it prints the usage text.
func requireShellArgs(fields []string, want int, usage string) bool {
	if len(fields) >= want {
		return true
	}
	fmt.Println("Invalid Args")
	fmt.Println("\tUsage : " + usage)
	return false
}

// parseShellInt parses s as a decimal int. On failure it prints an error
// naming the argument and returns ok == false instead of silently using 0.
func parseShellInt(s, name string) (int, bool) {
	v, err := strconv.Atoi(s)
	if err != nil {
		fmt.Printf("Invalid %s %q: must be an integer\n", name, s)
		return 0, false
	}
	return v, true
}