diff --git a/raft.go b/raft.go
new file mode 100644
index 0000000000000000000000000000000000000000..c110a381b20b210edfac28fdf7db5a227c742cd6
--- /dev/null
+++ b/raft.go
@@ -0,0 +1,714 @@
+package raft
+
+//
+// this is an outline of the API that raft must expose to
+// the service (or tester). see comments below for
+// each of these functions for more details.
+//
+// rf = Make(...)
+//   create a new Raft server.
+// rf.Start(command interface{}) (index, term, isleader)
+//   start agreement on a new log entry
+// rf.GetState() (term, isLeader)
+//   ask a Raft for its current term, and whether it thinks it is leader
+// ApplyMsg
+//   each time a new entry is committed to the log, each Raft peer
+//   should send an ApplyMsg to the service (or tester)
+//   in the same server.
+//
+
+import "sync"
+import "sync/atomic"
+import "raft/labrpc"
+
+import "time"
+import "math/rand"
+
+// import "fmt"
+
+//
+// as each Raft peer becomes aware that successive log entries are
+// committed, the peer should send an ApplyMsg to the service (or
+// tester) on the same server, via the applyCh passed to Make(). set
+// CommandValid to true to indicate that the ApplyMsg contains a newly
+// committed log entry.
+//
+type ApplyMsg struct {
+	CommandValid bool
+	Command      interface{}
+	CommandIndex int
+}
+
+// LogEntry holds one log entry: a command and the term it was received in
+type LogEntry struct {
+	Command interface{} // command for state machine
+	Term    int         // term when entry was received by leader
+}
+
+//
+// A Go object implementing a single Raft peer.
+//
+type Raft struct {
+	mu    sync.Mutex          // Lock to protect shared access to this peer's state
+	peers []*labrpc.ClientEnd // RPC end points of all peers
+	me    int                 // this peer's index into peers[]
+	dead  int32               // set by Kill()
+
+	// Your data here (2A, 2B).
+	// Look at the paper's Figure 2 for a description of what
+	// state a Raft server must maintain.
+	// You may also need to add other state, as per your implementation.
+
+	// persistent state on all servers
+	currentTerm int        // latest term server has seen
+	votedFor    int        // candidateId that received vote in current term
+	log         []LogEntry // (command, term) of each log entry, first index is 1
+
+	CurrentState string // server state (F: follower, C: candidate, L: leader)
+
+	// volatile state on all servers
+	commitIndex int // index of highest log entry known to be committed
+	lastApplied int // index of highest log entry applied to state machine
+
+	// volatile state on leaders, reinitialized after election
+	nextIndex  []int // index of next entry to send to each server
+	matchIndex []int // index of highest entry known to be replicated on server
+
+	// channel to send ApplyMsg to service
+	applyCh chan ApplyMsg
+
+	// election timers
+	electionTimeout  *time.Timer // timeout for election
+	heartbeatTimeout *time.Timer // timeout for heartbeat
+}
+
+// return currentTerm and whether this server
+// believes it is the leader.
+func (rf *Raft) GetState() (int, bool) {
+	// Your code here (2A).
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	term := rf.currentTerm
+	isleader := rf.CurrentState == "L"
+
+	return term, isleader
+}
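+
+// exampleServiceLoop is a hypothetical sketch (not part of the lab API, and not
+// called anywhere) of how a service layered on top of this Raft peer might combine
+// GetState, Start, and the applyCh passed to Make: submit a command on the peer
+// that believes it is leader, then apply committed entries in index order as they
+// arrive on applyCh.
+func exampleServiceLoop(rf *Raft, applyCh chan ApplyMsg, command interface{}) {
+	if _, isLeader := rf.GetState(); !isLeader {
+		return // a real service would retry the command on another peer
+	}
+	index, term, isLeader := rf.Start(command)
+	_, _, _ = index, term, isLeader // Start gives no commitment guarantee; wait for applyCh
+	for msg := range applyCh {
+		if msg.CommandValid {
+			// apply msg.Command to the service's state machine at msg.CommandIndex
+			_ = msg.Command
+		}
+	}
+}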
+
+//
+// example RequestVote RPC arguments structure.
+// field names must start with capital letters!
+//
+type RequestVoteArgs struct {
+	// Your data here (2A, 2B).
+	Term         int // term of candidate
+	CandidateId  int // candidate requesting vote
+	LastLogIndex int // index of candidate's last log entry
+	LastLogTerm  int // term of candidate's last log entry
+}
+
+//
+// example RequestVote RPC reply structure.
+// field names must start with capital letters!
+//
+type RequestVoteReply struct {
+	// Your data here (2A).
+	Term        int  // current term (for candidate to update itself)
+	VoteGranted bool // true if candidate received vote
+}
+
+// check if a candidate's log is at least as up-to-date as the receiver's:
+// the candidate wins on a higher last-log term, and ties are broken by log length
+func (rf *Raft) checkCandidateUpToDate(args *RequestVoteArgs) bool {
+	lastLogIndex := len(rf.log) - 1
+	if args.LastLogTerm > rf.log[lastLogIndex].Term {
+		return true
+	} else if args.LastLogTerm == rf.log[lastLogIndex].Term {
+		if args.LastLogIndex >= lastLogIndex {
+			return true
+		}
+	}
+	return false
+}
+
+//
+// example RequestVote RPC handler.
+//
+func (rf *Raft) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) {
+	// Your code here (2A, 2B).
+	// Read the fields in "args",
+	// and accordingly assign the values for fields in "reply".
+
+	if rf.killed() {
+		return
+	}
+
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	// candidate asking for vote has lower term
+	if args.Term < rf.currentTerm {
+		reply.VoteGranted = false
+		reply.Term = rf.currentTerm
+		return
+	}
+
+	// candidate asking for vote has higher term
+	if args.Term > rf.currentTerm {
+		rf.currentTerm = args.Term
+		rf.votedFor = -1
+		rf.CurrentState = "F"
+		rf.resetElectionTimer()
+	}
+
+	candidateUpToDate := rf.checkCandidateUpToDate(args)
+	if (rf.votedFor == -1 || rf.votedFor == args.CandidateId) && candidateUpToDate {
+		rf.votedFor = args.CandidateId
+		reply.VoteGranted = true
+		rf.resetElectionTimer()
+	} else {
+		reply.VoteGranted = false
+	}
+	reply.Term = rf.currentTerm
+}
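+
+// upToDateExample is a hypothetical illustration (not called by the protocol code)
+// of the election restriction implemented by checkCandidateUpToDate, using a small
+// assumed log whose entry terms are [0, 1, 1, 2] (last entry: index 3, term 2).
+func upToDateExample() {
+	rf := &Raft{log: []LogEntry{{nil, 0}, {nil, 1}, {nil, 1}, {nil, 2}}}
+
+	// a candidate with a higher last-log term is up-to-date even with a shorter log
+	higherTerm := rf.checkCandidateUpToDate(&RequestVoteArgs{LastLogIndex: 1, LastLogTerm: 3})
+
+	// with the same last-log term, a shorter log is not up-to-date
+	shorterLog := rf.checkCandidateUpToDate(&RequestVoteArgs{LastLogIndex: 2, LastLogTerm: 2})
+
+	_, _ = higherTerm, shorterLog // true, false
+}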
+
+//
+// example code to send a RequestVote RPC to a server.
+// server is the index of the target server in rf.peers[].
+// expects RPC arguments in args.
+// fills in *reply with RPC reply, so caller should
+// pass &reply.
+// the types of the args and reply passed to Call() must be
+// the same as the types of the arguments declared in the
+// handler function (including whether they are pointers).
+//
+// The labrpc package simulates a lossy network, in which servers
+// may be unreachable, and in which requests and replies may be lost.
+// Call() sends a request and waits for a reply. If a reply arrives
+// within a timeout interval, Call() returns true; otherwise
+// Call() returns false. Thus Call() may not return for a while.
+// A false return can be caused by a dead server, a live server that
+// can't be reached, a lost request, or a lost reply.
+//
+// Call() is guaranteed to return (perhaps after a delay) except if the
+// handler function on the server side does not return. Thus there
+// is no need to implement your own timeouts around Call().
+//
+// look at the comments in ../labrpc/labrpc.go for more details.
+//
+// if you're having trouble getting RPC to work, check that you've
+// capitalized all field names in structs passed over RPC, and
+// that the caller passes the address of the reply struct with &, not
+// the struct itself.
+//
+func (rf *Raft) sendRequestVote(server int, args *RequestVoteArgs, reply *RequestVoteReply) bool {
+	ok := rf.peers[server].Call("Raft.RequestVote", args, reply)
+
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	if ok {
+		// reply term greater than current term: step down to follower
+		if reply.Term > rf.currentTerm {
+			rf.currentTerm = reply.Term
+			rf.votedFor = -1
+			rf.CurrentState = "F"
+			rf.resetElectionTimer()
+		}
+	}
+
+	return ok
+}
+
+// AppendEntriesArgs contains the arguments for the AppendEntries RPC
+type AppendEntriesArgs struct {
+	Term         int        // term of leader
+	LeaderId     int        // so follower can redirect clients
+	PrevLogIndex int        // index of log entry immediately preceding new ones
+	PrevLogTerm  int        // term of prevLogIndex entry
+	Entries      []LogEntry // log entries to store (empty for heartbeat)
+	LeaderCommit int        // commitIndex of leader
+}
+
+// AppendEntriesReply contains the reply for the AppendEntries RPC
+type AppendEntriesReply struct {
+	Term    int  // current term (for leader to update itself)
+	Success bool // true if follower contained entry matching prevLogIndex and prevLogTerm
+
+	ConflictTerm  int // term of the conflicting entry, -1 if none
+	ConflictIndex int // first index holding ConflictTerm, or the follower's log length
+}
+
+// AppendEntries RPC handler
+func (rf *Raft) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesReply) {
+	if rf.killed() {
+		return
+	}
+
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	reply.Term = rf.currentTerm
+	reply.Success = false
+	reply.ConflictTerm = -1
+	reply.ConflictIndex = -1
+
+	// leader's term is less than current term: reject without resetting the election timer
+	if args.Term < rf.currentTerm {
+		return
+	}
+
+	// leader's term is greater than current term: revert to follower state
+	if args.Term > rf.currentTerm {
+		rf.currentTerm = args.Term
+		rf.votedFor = -1
+		rf.CurrentState = "F"
+	}
+
+	// the RPC comes from the current leader, so reset the election timer
+	rf.resetElectionTimer()
+	reply.Term = rf.currentTerm
+
+	// log is too short to contain an entry at prevLogIndex
+	if args.PrevLogIndex >= len(rf.log) {
+		reply.ConflictIndex = len(rf.log)
+		return
+	}
+
+	// log entry at prevLogIndex does not match prevLogTerm:
+	// report the conflicting term and the first index it occupies
+	if rf.log[args.PrevLogIndex].Term != args.PrevLogTerm {
+		reply.ConflictTerm = rf.log[args.PrevLogIndex].Term
+		conflictIndex := args.PrevLogIndex
+		for conflictIndex >= 0 && rf.log[conflictIndex].Term == reply.ConflictTerm {
+			conflictIndex--
+		}
+		reply.ConflictIndex = conflictIndex + 1
+		return
+	}
+
+	// if an existing entry conflicts with a new one, delete the existing entry
+	// and all that follow it, then append any new entries not already in the log
+	conflictIndex := args.PrevLogIndex + 1
+	for i, newEntry := range args.Entries {
+		if conflictIndex < len(rf.log) {
+			// conflict found within existing log
+			if rf.log[conflictIndex].Term != newEntry.Term {
+				rf.log = rf.log[:conflictIndex]
+				rf.log = append(rf.log, args.Entries[i:]...)
+				break
+			}
+		} else {
+			// no conflict found, append new entries after existing log
+			rf.log = append(rf.log, args.Entries[i:]...)
+			break
+		}
+		conflictIndex++
+	}
+
+	// update commit index
+	if args.LeaderCommit > rf.commitIndex {
+		if args.LeaderCommit < len(rf.log)-1 {
+			rf.commitIndex = args.LeaderCommit
+		} else {
+			rf.commitIndex = len(rf.log) - 1
+		}
+	}
+
+	reply.Success = true
+
+	// fmt.Println(rf.me, rf.log)
+}
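+
+// conflictReplyExample is a hypothetical illustration (not called by the protocol
+// code) of the accelerated-backtracking fields computed above. For an assumed
+// follower log with terms [0, 1, 1, 4, 4] and a leader probe at PrevLogIndex=4,
+// PrevLogTerm=5, the follower reports ConflictTerm=4 and ConflictIndex=3 (the
+// first index holding term 4), letting the leader skip the whole term at once.
+func conflictReplyExample() (conflictTerm, conflictIndex int) {
+	followerLog := []LogEntry{{nil, 0}, {nil, 1}, {nil, 1}, {nil, 4}, {nil, 4}}
+	prevLogIndex, prevLogTerm := 4, 5
+
+	if followerLog[prevLogIndex].Term != prevLogTerm {
+		conflictTerm = followerLog[prevLogIndex].Term
+		i := prevLogIndex
+		for i >= 0 && followerLog[i].Term == conflictTerm {
+			i--
+		}
+		conflictIndex = i + 1
+	}
+	return conflictTerm, conflictIndex // 4, 3
+}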
+
+func (rf *Raft) sendAppendEntries(server int, args *AppendEntriesArgs, reply *AppendEntriesReply) bool {
+	ok := rf.peers[server].Call("Raft.AppendEntries", args, reply)
+
+	rf.mu.Lock()
+	defer rf.mu.Unlock()
+
+	if ok {
+		// reply term greater than current term: step down to follower
+		if reply.Term > rf.currentTerm {
+			rf.currentTerm = reply.Term
+			rf.votedFor = -1
+			rf.CurrentState = "F"
+			rf.resetElectionTimer()
+		}
+	}
+
+	return ok
+}
+
+func (rf *Raft) sendHeartbeats() {
+	for !rf.killed() {
+		rf.mu.Lock()
+
+		// must be a leader to send heartbeats
+		if rf.CurrentState != "L" {
+			rf.mu.Unlock()
+			// wait a little before checking again if leader
+			time.Sleep(5 * time.Millisecond)
+			continue
+		}
+
+		peersCopy := make([]int, 0)
+		for i := 0; i < len(rf.peers); i++ {
+			peersCopy = append(peersCopy, i)
+		}
+
+		// create AppendEntriesArgs for each peer while holding the lock;
+		// send new entries if there are any, otherwise an empty heartbeat
+		argsServer := make([]AppendEntriesArgs, len(peersCopy))
+		for _, server := range peersCopy {
+			if server == rf.me {
+				continue
+			}
+			entriesCopy := []LogEntry{}
+			if rf.nextIndex[server] < len(rf.log) {
+				entriesCopy = make([]LogEntry, len(rf.log[rf.nextIndex[server]:]))
+				copy(entriesCopy, rf.log[rf.nextIndex[server]:])
+			}
+			argsServer[server] = AppendEntriesArgs{
+				Term:         rf.currentTerm,
+				LeaderId:     rf.me,
+				PrevLogIndex: rf.nextIndex[server] - 1,
+				PrevLogTerm:  rf.log[rf.nextIndex[server]-1].Term,
+				Entries:      entriesCopy,
+				LeaderCommit: rf.commitIndex,
+			}
+		}
+
+		// unlock to send AppendEntries RPCs
+		rf.mu.Unlock()
+
+		// send heartbeat to all peers
+		for _, server := range peersCopy {
+			if server == rf.me {
+				continue
+			}
+			// one goroutine per peer, so a slow peer does not delay the others
+			go func(server int, args AppendEntriesArgs) {
+				reply := AppendEntriesReply{}
+				ok := rf.sendAppendEntries(server, &args, &reply)
+				if !ok {
+					return
+				}
+
+				rf.mu.Lock()
+				defer rf.mu.Unlock()
+
+				// ignore stale replies: only act if still leader in the term the RPC was sent
+				if rf.CurrentState != "L" || rf.currentTerm != args.Term {
+					return
+				}
+
+				if reply.Success {
+					newMatchIndex := args.PrevLogIndex + len(args.Entries)
+					rf.matchIndex[server] = newMatchIndex
+					rf.nextIndex[server] = newMatchIndex + 1
+				} else {
+					// optimized nextIndex calculation using the conflict hints
+					if reply.ConflictTerm != -1 {
+						lastIndexOfTerm := -1
+						for i := len(rf.log) - 1; i >= 0; i-- {
+							if rf.log[i].Term == reply.ConflictTerm {
+								lastIndexOfTerm = i
+								break
+							}
+						}
+						if lastIndexOfTerm >= 0 {
+							rf.nextIndex[server] = lastIndexOfTerm + 1
+						} else {
+							rf.nextIndex[server] = reply.ConflictIndex
+						}
+					} else {
+						rf.nextIndex[server] = reply.ConflictIndex
+					}
+
+					if rf.nextIndex[server] < 1 {
+						rf.nextIndex[server] = 1
+					}
+				}
+
+				// advance commitIndex to the highest N replicated on a majority
+				// whose entry was written in the current term
+				for N := rf.commitIndex + 1; N < len(rf.log); N++ {
+					majority := len(rf.peers)/2 + 1
+					count := 1
+					for i := 0; i < len(rf.peers); i++ {
+						if rf.matchIndex[i] >= N {
+							count++
+						}
+					}
+					if count >= majority && rf.log[N].Term == rf.currentTerm {
+						rf.commitIndex = N
+					}
+				}
+			}(server, argsServer[server])
+		}
+
+		// heartbeat interval
+		time.Sleep(100 * time.Millisecond)
+	}
+}
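+
+// commitIndexExample is a hypothetical illustration (not called by the protocol
+// code) of the leader commit rule applied in sendHeartbeats: an index N becomes
+// committed once a majority of servers have matchIndex >= N and log[N] was written
+// in the leader's current term.
+func commitIndexExample() int {
+	leaderLog := []LogEntry{{nil, 0}, {nil, 1}, {nil, 2}, {nil, 2}}
+	matchIndex := []int{0, 3, 2, 1, 0} // index 0 is the leader itself in this sketch
+	currentTerm, commitIndex := 2, 0
+
+	for N := commitIndex + 1; N < len(leaderLog); N++ {
+		count := 1 // the leader always holds its own entry
+		for server := 1; server < len(matchIndex); server++ {
+			if matchIndex[server] >= N {
+				count++
+			}
+		}
+		if count >= len(matchIndex)/2+1 && leaderLog[N].Term == currentTerm {
+			commitIndex = N
+		}
+	}
+	return commitIndex // 2: index 2 is on 3 of 5 servers and carries the current term
+}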
+
+func (rf *Raft) resetElectionTimer() {
+	// stop any pending election timer
+	if rf.electionTimeout != nil {
+		rf.electionTimeout.Stop()
+	}
+
+	// set a new election timeout (randomized between 300ms and 600ms)
+	timeoutValue := time.Duration(300+rand.Intn(300)) * time.Millisecond
+	rf.electionTimeout = time.AfterFunc(timeoutValue, func() {
+		rf.startElection()
+	})
+}
+
+func (rf *Raft) startElection() {
+	rf.mu.Lock()
+
+	// start election only if server is not leader
+	if rf.CurrentState == "L" {
+		rf.mu.Unlock()
+		return
+	}
+
+	// increment current term and change state to candidate
+	rf.currentTerm++
+	rf.votedFor = rf.me
+	rf.CurrentState = "C"
+	rf.resetElectionTimer()
+
+	// vote count
+	voteCount := 1
+	votesRequired := len(rf.peers)/2 + 1
+
+	currentTermCopy := rf.currentTerm
+	candidateIdCopy := rf.me
+	lastLogIndexCopy := len(rf.log) - 1
+	lastLogTermCopy := rf.log[lastLogIndexCopy].Term
+
+	rf.mu.Unlock()
+
+	// send RequestVote RPC to all peers
+	for i := 0; i < len(rf.peers); i++ {
+		if i != rf.me {
+			// one goroutine per peer to send and handle the RequestVote RPC
+			go func(server int) {
+				args := RequestVoteArgs{
+					Term:         currentTermCopy,
+					CandidateId:  candidateIdCopy,
+					LastLogIndex: lastLogIndexCopy,
+					LastLogTerm:  lastLogTermCopy,
+				}
+
+				reply := RequestVoteReply{}
+				ok := rf.sendRequestVote(server, &args, &reply)
+
+				if ok {
+					rf.mu.Lock()
+
+					// only count votes if still a candidate in the term this RPC was sent
+					if rf.CurrentState == "C" && rf.currentTerm == args.Term {
+						if reply.VoteGranted {
+							voteCount++
+							if voteCount >= votesRequired {
+								// become leader
+								rf.CurrentState = "L"
+
+								// reinitialize volatile state on leaders
+								rf.nextIndex = make([]int, len(rf.peers))
+								rf.matchIndex = make([]int, len(rf.peers))
+								for i := 0; i < len(rf.peers); i++ {
+									rf.nextIndex[i] = len(rf.log)
+									rf.matchIndex[i] = 0
+								}
+
+								// reset election timer
+								rf.resetElectionTimer()
+							}
+						}
+					}
+
+					rf.mu.Unlock()
+				}
+			}(i)
+		}
+	}
+}
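+
+// timerResetExample is a hypothetical, self-contained illustration (not called by
+// the protocol code) of the reset pattern used by resetElectionTimer: stop any
+// pending time.AfterFunc timer and install a fresh one with a new random delay,
+// so the callback only fires after a full timeout with no reset in between.
+func timerResetExample(fire func()) *time.Timer {
+	var t *time.Timer
+	reset := func() {
+		if t != nil {
+			t.Stop()
+		}
+		t = time.AfterFunc(time.Duration(300+rand.Intn(300))*time.Millisecond, fire)
+	}
+	reset() // arm the timer
+	reset() // re-arming pushes the deadline back; fire runs only after a quiet period
+	return t
+}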
+
+//
+// the service using Raft (e.g. a k/v server) wants to start
+// agreement on the next command to be appended to Raft's log. if this
+// server isn't the leader, returns false. otherwise start the
+// agreement and return immediately. there is no guarantee that this
+// command will ever be committed to the Raft log, since the leader
+// may fail or lose an election. even if the Raft instance has been killed,
+// this function should return gracefully.
+//
+// the first return value is the index that the command will appear at
+// if it's ever committed. the second return value is the current
+// term. the third return value is true if this server believes it is
+// the leader.
+//
+func (rf *Raft) Start(command interface{}) (int, int, bool) {
+	index := -1
+	term := -1
+	isLeader := true
+
+	// Your code here (2B).
+	rf.mu.Lock()
+	if rf.CurrentState != "L" {
+		isLeader = false
+	} else {
+		logEntry := LogEntry{
+			Command: command,
+			Term:    rf.currentTerm,
+		}
+		rf.log = append(rf.log, logEntry)
+		index = len(rf.log) - 1
+	}
+
+	term = rf.currentTerm
+	rf.mu.Unlock()
+
+	return index, term, isLeader
+}
+
+//
+// the tester doesn't halt goroutines created by Raft after each test,
+// but it does call the Kill() method. your code can use killed() to
+// check whether Kill() has been called. the use of atomic avoids the
+// need for a lock.
+//
+// the issue is that long-running goroutines use memory and may chew
+// up CPU time, perhaps causing later tests to fail and generating
+// confusing debug output. any goroutine with a long-running loop
+// should call killed() to check whether it should stop.
+//
+func (rf *Raft) Kill() {
+	atomic.StoreInt32(&rf.dead, 1)
+	// Your code here, if desired.
+}
+
+func (rf *Raft) killed() bool {
+	z := atomic.LoadInt32(&rf.dead)
+	return z == 1
+}
+
+func (rf *Raft) applyLogEntries() {
+	for !rf.killed() {
+		rf.mu.Lock()
+		if rf.commitIndex > rf.lastApplied {
+			rf.lastApplied++
+			msg := ApplyMsg{
+				CommandValid: true,
+				Command:      rf.log[rf.lastApplied].Command,
+				CommandIndex: rf.lastApplied,
+			}
+			// release the lock before sending on applyCh so a slow consumer
+			// cannot block other Raft operations
+			rf.mu.Unlock()
+			rf.applyCh <- msg
+		} else {
+			rf.mu.Unlock()
+
+			time.Sleep(5 * time.Millisecond) // sleep to avoid busy waiting
+		}
+	}
+}
+
+//
+// the service or tester wants to create a Raft server. the ports
+// of all the Raft servers (including this one) are in peers[]. this
+// server's port is peers[me]. all the servers' peers[] arrays
+// have the same order. applyCh is a channel on which the
+// tester or service expects Raft to send ApplyMsg messages.
+// Make() must return quickly, so it should start goroutines
+// for any long-running work.
+//
+func Make(peers []*labrpc.ClientEnd, me int,
+	applyCh chan ApplyMsg) *Raft {
+	rf := &Raft{}
+	rf.peers = peers
+	rf.me = me
+
+	// Your initialization code here (2A, 2B).
+	rf.currentTerm = 0
+	rf.votedFor = -1
+	rf.log = make([]LogEntry, 1) // index 0 is a sentinel; real entries start at index 1
+	rf.log[0] = LogEntry{Command: nil, Term: 0}
+
+	rf.commitIndex = 0
+	rf.lastApplied = 0
+
+	rf.CurrentState = "F" // initial state is follower
+
+	rf.nextIndex = make([]int, len(peers))
+	rf.matchIndex = make([]int, len(peers))
+	for i := 0; i < len(peers); i++ {
+		rf.nextIndex[i] = len(rf.log)
+		rf.matchIndex[i] = 0
+	}
+
+	rf.applyCh = applyCh
+
+	// start election timer
+	rf.resetElectionTimer()
+
+	// start goroutine to send heartbeats
+	go rf.sendHeartbeats()
+
+	// start goroutine to apply log entries
+	go rf.applyLogEntries()
+
+	return rf
+}
\ No newline at end of file