From 4c2e9d987aaef85882c8c2af266742ea82a197c9 Mon Sep 17 00:00:00 2001 From: AhaanKanaujia <kanaujia.ahaan@gmail.com> Date: Sun, 13 Apr 2025 19:22:41 -0500 Subject: [PATCH] added optimization --- raft/raft.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/raft/raft.go b/raft/raft.go index 3299033..c110a38 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -240,6 +240,9 @@ type AppendEntriesArgs struct { type AppendEntriesReply struct { Term int // current term (for leader to update itself) Success bool // true if follower contained entry matching prevLogIndex and prevLogTerm + + ConflictTerm int // term of conflicting entry + ConflictIndex int // index of first entry in log that conflicts with new entries } // AppendEntries RPC handler @@ -252,7 +255,11 @@ func (rf *Raft) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesReply defer rf.mu.Unlock() rf.resetElectionTimer() + + reply.Term = rf.currentTerm reply.Success = false + reply.ConflictTerm = -1 + reply.ConflictIndex = -1 // candidate term is less than current term if args.Term < rf.currentTerm { @@ -274,6 +281,23 @@ func (rf *Raft) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesReply return } + // check if log contains entry within log + if args.PrevLogIndex >= len(rf.log) { + reply.ConflictIndex = len(rf.log) + return + } + + // check if log entry at precLogIndex does not match prevLogTerm + if rf.log[args.PrevLogIndex].Term != args.PrevLogTerm { + reply.ConflictTerm = rf.log[args.PrevLogIndex].Term + conflictIndex := args.PrevLogIndex + for conflictIndex >= 0 && rf.log[conflictIndex].Term == reply.ConflictTerm { + conflictIndex-- + } + reply.ConflictIndex = conflictIndex + 1 + return + } + // // check if existing entry conflicts with new entries // // delete any conflicting entries // for j, entry := range args.Entries { @@ -318,7 +342,6 @@ func (rf *Raft) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesReply } } - reply.Term = rf.currentTerm reply.Success = true // fmt.Println(rf.me, rf.log) @@ -425,8 +448,31 @@ func (rf *Raft) sendHeartbeats() { rf.matchIndex[server] = newMatchIndex rf.nextIndex[server] = newMatchIndex + 1 } else { - if rf.nextIndex[server] > 1 { - rf.nextIndex[server]-- + // if rf.nextIndex[server] > 1 { + // rf.nextIndex[server]-- + // } + + // optimized nextIndex calculation + if reply.ConflictTerm != -1 { + conflictTerm := reply.ConflictTerm + lastIndexOfTerm := -1 + for i := len(rf.log) - 1; i >= 0; i-- { + if rf.log[i].Term == conflictTerm { + lastIndexOfTerm = i + break + } + } + if lastIndexOfTerm >= 0 { + rf.nextIndex[server] = lastIndexOfTerm + 1 + } else { + rf.nextIndex[server] = reply.ConflictIndex + } + } else { + rf.nextIndex[server] = reply.ConflictIndex + } + + if rf.nextIndex[server] < 1 { + rf.nextIndex[server] = 1 } } -- GitLab