计网lab-基于udp实现可靠传输协议(ABP/GBN)

准备

Message和Packet类定义

分别作为应用层和传输层的数据包

其中Packet类的generateChecksum()方法根据报文数据生成校验和,selfcheck()方法根据报文数据和校验和检查数据是否损坏

public class Message {
  public char[] data = new char[20];
  public Message(char[] str) {
     for (int i = 0; i < 20; i++) data[i] = str[i];
  }
  public Message(Packet pkt) {
     for (int i = 0; i < 20; i++) data[i] = pkt.payload[i];
  }
}
package MyClass;
import java.io.Serializable;

public class Packet implements Serializable {
  private int seqnum;
  private int acknum;
  private int checksum;
  public char[] payload = new char[20];
  public void setSeq(int _seqnum) { seqnum = _seqnum; }
  public void setAck(int _acknum) { acknum = _acknum; }
  public int getSeq() { return seqnum; }
  public int getAck() { return acknum; }
  public void generateChecksum() {
​    checksum += seqnum;
​    checksum += acknum;
​    for (int i = 0; i < 20; i++) checksum += payload[i];
​    checksum = ~checksum;
  }
  public boolean selfcheck() {
​    int value = checksum;
​    value += seqnum;
​    value += acknum;
​    for (int i = 0; i < 20; i++) value +=payload[i];
​    if (value != -1) return false;
​    return true;
  }
  public Packet(){}
  public Packet(Message msg) {
​    for (int i = 0; i < 20; i++) {
​      payload[i] = msg.data[i];
​    }
  }
}

序列化模块Serialization

由于udp只能传输字节数组,需要定义序列化模块实现class与字节数组的相互转换

public class Message {
    public char[] data = new char[20];
    public Message(char[] str) {
        for (int i = 0; i < 20; i++) data[i] = str[i];
    }
    public Message(Packet pkt) {
        for (int i = 0; i < 20; i++) data[i] = pkt.payload[i];
    }
}
package MyClass;

import java.io.Serializable;

public class Packet implements Serializable {
    private int seqnum;
    private int acknum;
    private int checksum;
    public char[] payload = new char[20];
    public void setSeq(int _seqnum) { seqnum = _seqnum; }
    public void setAck(int _acknum) { acknum = _acknum; }
    public int getSeq() { return seqnum; }
    public int getAck() { return acknum; }
    public void generateChecksum() {
        checksum += seqnum;
        checksum += acknum;
        for (int i = 0; i < 20; i++) checksum += payload[i];
        checksum = ~checksum;
    }
    public boolean selfcheck() {
        int value = checksum;
        value += seqnum;
        value += acknum;
        for (int i = 0; i < 20; i++) value +=payload[i];
        if (value != -1) return false;
        return true;
    }
    public Packet(){}
    public Packet(Message msg) {
        for (int i = 0; i < 20; i++) {
            payload[i] = msg.data[i];
        }
    }
}

超时时间计算方法

参考书3.5.3

img

img

img

Alpha,beta分别取0.125和0.25

SampleRTT = receiveTime - sendTime;
EstimatedRTT = (long)(0.875 * EstimatedRTT + 0.125 * SampleRTT);
DevRTT = (long)(0.75 * DevRTT + 0.25 * Math.abs(SampleRTT - EstimatedRTT));
TimeoutInterval = EstimatedRTT + 4 * DevRTT;

将TImeoutInterval初始化为1000ms,EstimatedRTT和SampleRTT初始化为500ms

比特交替协议实现

通过变量记录当前0/1状态,通过线程交互实现停等和超时重发。

发送方

img

参考课本上rdt发送方的FSM进行设计。整体思路上使用三个线程分别处理上层调用、接收ACK响应包和超时重发。调用线程1将Message转化为Packet并将数据包通过A_output()发送给服务器端后,调用计时线程3并通过wait()进入等待状态,接收线程2收到响应包后,判断是否发送成功,若成功则中止计时线程3并唤醒线程1,否则线程3计时结束后通过toLayer3()重发数据包,再由线程2判断是否发送成功,直到接收到响应成功。

通过变量seqNum判断当前处于FSM的哪种状态,每接收到一个成功发送响应,seqNum变量的值将0/1交替。使用对象锁的方式实现线程交互(等待和唤醒)。

package rdt;

import java.io.IOException;
import java.net.DatagramSocket;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

import MyClass.Message;
import MyClass.Packet;
import MyClass.Serialization;

public class Client implements Runnable {
    class Timer implements Runnable {
        private long timeout;
        Timer(long _timeout) { timeout = _timeout; }
        void setTimeout() {}
        public void run () {
            try {
                TimeUnit.MILLISECONDS.sleep(timeout);
                A_timerinterrupt();
            } catch (InterruptedException e) {
            } catch (IOException e) {
                e.printStackTrace();
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
    }
    void A_timerinterrupt() throws InterruptedException, ClassNotFoundException, IOException{
        System.out.println("前一个Packet发送超时,开始重发。");
        TimeoutInterval *= 2;
        startTimer(TimeoutInterval);
        toLayer3(lastpkt);
    }
    void startTimer(long increment) {
        System.out.println("计时器启动: " + increment + "ms后重新发送。");
        if (timer != null && !timer.isInterrupted()) stopTimer();
        timer = new Thread(new Timer(increment));
        timer.start();
    }
    void stopTimer() {
        System.out.println("计时器中止。");
        timer.interrupt();
    }
    class Receiver implements Runnable {
        public void run() {
            try {
                DatagramSocket datagramSocket = new DatagramSocket(11000);
                int cnt = 0;
                while (cnt < MAXSEQ) {
                    byte[] buf = new byte[1024];
                    DatagramPacket dataPacket = new DatagramPacket(buf, buf.length);
                    datagramSocket.receive(dataPacket);
                    Packet receivePack = Serialization.BytesToPacket(dataPacket.getData());
                    if (receivePack.selfcheck() && receivePack.getAck() == seqNum) {
                        cnt++;
                        int ackNum = receivePack.getAck();
                        long receiveTime = System.currentTimeMillis();
                        SampleRTT = receiveTime - sendTime;
                        EstimatedRTT = (long)(0.875 * EstimatedRTT + 0.125 * SampleRTT);
                        DevRTT = (long)(0.75 * DevRTT + 0.25 * Math.abs(SampleRTT - EstimatedRTT));
                        TimeoutInterval = EstimatedRTT + 4 * DevRTT;
                        stopTimer();
                        synchronized (obj) {
                            obj.notify();
                        }
                    } else {
                        System.out.println("响应包损坏,或前一次发送未被收到。");
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
    }
    static final Object obj = new Object();
    static Thread timer;
    static Thread thread = new Thread(new Client());
    Packet lastpkt;
    int MAXSEQ;
    int targetPort;
    int seqNum;
    long sendTime;
    long EstimatedRTT, SampleRTT, DevRTT, TimeoutInterval;
    void A_init() throws IOException {
        MAXSEQ = 50;
        seqNum = 0;
        TimeoutInterval = 1000;
        EstimatedRTT = SampleRTT = 500;
        DevRTT = 0;
        targetPort = 12000;
    }

    void toLayer3(Packet pkt) throws IOException {
        DatagramSocket datagramSocket = new DatagramSocket();
        byte[] sendBytes = Serialization.PacketToBytes(pkt);
        DatagramPacket datagramPacket = new DatagramPacket(
                sendBytes,
                sendBytes.length,
                InetAddress.getLocalHost(),
                targetPort
        );
        try {
            lastpkt = pkt;
            sendTime = System.currentTimeMillis();
            datagramSocket.send(datagramPacket);
            System.out.printf("Packet发送成功,时间已记录。序列号:%d\n", pkt.getSeq());
            // System.out.println(pkt.getSeq());
        } catch (IOException e) {
            System.out.println("Packet发送失败。");
        }
    }
    void A_output(Message msg) throws IOException, ClassNotFoundException, InterruptedException {
        Packet pkt = new Packet(msg);
        pkt.setSeq(seqNum);
        pkt.generateChecksum();
        startTimer(TimeoutInterval);
        toLayer3(pkt);
    }
    public void run() {
        Thread receiver = new Thread(new Receiver());
        receiver.start();
        try {
            A_init();
            for (int i = 1; i <= MAXSEQ; i++) {
                char[] data = new char[20];
                for (int j = 0; j < 20; j++) data[j] = (char) ('a' + (i - 1) % 26);
                Message msg = new Message(data);
                A_output(msg);
                synchronized (obj) {
                    obj.wait();
                }
                seqNum ^= 1;
            }
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    public static void main(String[] args){
        thread.start();
    }
}

接收方

img

同样通过seqNum判断当前FSM状态。发送响应包时,通过发送与当前状态相反的状态值表示数据包发送失败。通过参数lossRate和simulatedDelay模拟真实网络中的丢包和延迟情况。

lossRate = 0.2;

simulatedDelay = 200;

当发送相应包时,通过Math.random()生成一个0~1的随机数,若Math.random()<lossRate则视作响应包发送失败,不进行发送。每次发送响应包前随机等待利用TimeUnit.MILLISECONDS.sleep 随机等待Math.random() * simulatedDelay毫秒,模拟网络延迟。

由于阻塞情况只有一种(即等待用户端发送数据包),所以单线程即可实现。

package rdt;


import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.util.concurrent.TimeUnit;

import MyClass.Message;
import MyClass.Packet;
import MyClass.Serialization;


public class Server {
    DatagramSocket datagramSocket;
    byte[] buf;
    int seqNum;
    int targetPort;
    double lossRate;
    long simulatedDelay;
    void B_init() throws IOException {
        datagramSocket = new DatagramSocket(12000);
        lossRate = 0.2;
        simulatedDelay = 200;
        buf = new byte[1024];
        seqNum = 0;
        targetPort = 11000;
    }
    void tolayer5(Packet pkt) {
        Message msg = new Message(pkt);
        System.out.print("信息接收成功。");
        msg.show();
    }
    void B_input(Packet pkt, DatagramPacket datagramPacket) throws IOException{
        System.out.println(seqNum + " " + pkt.getSeq() + " " + pkt.selfcheck());
        if (!pkt.selfcheck() || pkt.getSeq() != seqNum) {
            System.out.println("Packet损坏或重复。");

            if (Math.random() > lossRate) {
                System.out.println("损坏响应已发送。");
                Packet responsePacket = new Packet();
                responsePacket.setAck(seqNum ^ 1);
                responsePacket.generateChecksum();
                byte[] sendData = Serialization.PacketToBytes(responsePacket);
                DatagramPacket sendPacket = new DatagramPacket(
                        sendData,
                        sendData.length,
                        datagramPacket.getAddress(),
                        targetPort
                );
                datagramSocket.send(sendPacket);
            } else {
                System.out.println("响应消息发送失败。");
            }
        } else {
            tolayer5(pkt);
            if (Math.random() > lossRate) {
                Packet responsePacket = new Packet();
                responsePacket.setAck(seqNum);
                responsePacket.generateChecksum();
                byte[] sendData = Serialization.PacketToBytes(responsePacket);
                DatagramPacket sendPacket = new DatagramPacket(
                        sendData,
                        sendData.length,
                        datagramPacket.getAddress(),
                        targetPort
                );
                datagramSocket.send(sendPacket);
            } else {
                System.out.println("响应消息发送失败。");
            }
            seqNum ^= 1;
        }
    }
    void run() throws IOException, ClassNotFoundException, InterruptedException{
        B_init();
        while(true) {
            DatagramPacket datagramPacket = new DatagramPacket(buf, buf.length);
            datagramSocket.receive(datagramPacket);
            Packet pkt = Serialization.BytesToPacket(buf);
            TimeUnit.MILLISECONDS.sleep((long)(Math.random() * simulatedDelay));
            B_input(pkt, datagramPacket);
        }
    }
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException{
        Server server = new Server();
        server.run();
    }
}

GBN协议实现

实现方式与比特交替协议类似,区别在于GBN协议不是严格的停等协议,可能会出现连续发送多个包后再连续接收多个响应包的情况,数据包发送和接收情况变得更加复杂。因此没有采用阻塞的方式控制线程,而是给上层调用加了一定的间隔(300ms),并且当上层调用时滑动窗口大小以达到上限时,将丢弃一些上层调用。同时为了简化问题,序列号没有采用循环的方式,而是由1开始编号至n。

发送方

img

设置windowSize=8作为窗口最大值,通过base和nextSeqNum控制滑动窗口的大小和位置。收到上层调用时,先判断窗口大小是否已经达到上限,若还没有达到上限则发送数据包并将窗口扩大一位。由于响应包接收情况可能乱序,需要用数组的方式记录每个数据包发送的时间来计算RTT。触发超时重传时,需要循环重发base~nextSeqNum-1的所有数据包。其余类似比特交替协议。

package gbn;

import java.io.IOException;
import java.net.DatagramSocket;
import java.net.DatagramPacket;
import java.net.InetAddress;
import java.util.concurrent.TimeUnit;

import MyClass.Message;
import MyClass.Packet;
import MyClass.Serialization;


public class Client implements Runnable {
    class Timer implements Runnable {
        private long timeout;
        Timer(long _timeout) { timeout = _timeout; }
        public void run () {
            try {
                TimeUnit.MILLISECONDS.sleep(timeout);
                A_timerinterrupt();
            } catch (InterruptedException e) {
            } catch (IOException e) {
                e.printStackTrace();
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
    }
    void A_timerinterrupt() throws InterruptedException, ClassNotFoundException, IOException{
        System.out.println("计时器时限到,开始重发。");
        TimeoutInterval *= 2;
        startTimer(TimeoutInterval);
        for (int i = base; i < nextSeqNum; i++) {
            char[] data = new char[20];
            for (int j = 0; j < 20; j++) data[j] = (char) ('a' + i % 26);
            Message msg = new Message(data);
            Packet pkt = new Packet(msg);
            pkt.setSeq(i);
            pkt.generateChecksum();
            toLayer3(pkt);
        }
    }
    void startTimer(long increment) {
        System.out.println("计时器启动: " + increment + "ms后重新发送。");
        if (timer != null && !timer.isInterrupted()) stopTimer();
        timer = new Thread(new Timer(increment));
        timer.start();
    }
    void stopTimer() {
        System.out.println("计时器中止。");
        timer.interrupt();
    }
    class Receiver implements Runnable {
        public void run() {
            try {
                DatagramSocket datagramSocket = new DatagramSocket(11000);
                while (base != MAXSEQ + 1) {
                    byte[] buf = new byte[1024];
                    DatagramPacket dataPacket = new DatagramPacket(buf, buf.length);
                    datagramSocket.receive(dataPacket);
                    Packet receivePack = Serialization.BytesToPacket(dataPacket.getData());
                    if (receivePack.selfcheck()) {
                        int ackNum = receivePack.getAck();
                        base = Math.max(base, ackNum);
                        long receiveTime = System.currentTimeMillis();
                        SampleRTT = receiveTime - sendTime[ackNum - 1];
                        EstimatedRTT = (long)(0.875 * EstimatedRTT + 0.125 * SampleRTT);
                        DevRTT = (long)(0.75 * DevRTT + 0.25 * Math.abs(SampleRTT - EstimatedRTT));
                        TimeoutInterval = EstimatedRTT + 4 * DevRTT;
                        // System.out.println("TimeoutInterval被修改为" + TimeoutInterval + " ackNum为" + ackNum);
                        stopTimer();
                        if (base != nextSeqNum) startTimer(TimeoutInterval);
                    } else {
                        System.out.println("响应包损坏。");
                    }
                }
            } catch (IOException e) {
                e.printStackTrace();
            } catch (ClassNotFoundException e) {
                e.printStackTrace();
            }
        }
    }
    static Thread timer;
    static Thread thread = new Thread(new Client());
    int MAXSEQ;
    int targetPort;
    int windowSize;
    int base;
    int nextSeqNum;
    static long sendTime[];
    long EstimatedRTT, SampleRTT, DevRTT, TimeoutInterval;
    void A_init() throws IOException {
        MAXSEQ = 50;
        base = nextSeqNum = 1;
        windowSize = 8;
        TimeoutInterval = 1000;
        EstimatedRTT = SampleRTT = 500;
        DevRTT = 0;
        targetPort = 12000;
        sendTime = new long[MAXSEQ + 1];
    }

    void toLayer3(Packet pkt) throws IOException {
        DatagramSocket datagramSocket = new DatagramSocket();
        byte[] sendBytes = Serialization.PacketToBytes(pkt);
        DatagramPacket datagramPacket = new DatagramPacket(
                sendBytes,
                sendBytes.length,
                InetAddress.getLocalHost(),
                targetPort
        );
        try {

            sendTime[pkt.getSeq()] = System.currentTimeMillis();
            datagramSocket.send(datagramPacket);
            System.out.printf("第 %d 个Packet发送成功,时间已记录。\n", pkt.getSeq());
        } catch (IOException e) {
            System.out.println("Packet发送失败。");
        }
    }
   void A_output(Message msg) throws IOException, ClassNotFoundException, InterruptedException {
        if (nextSeqNum < base + windowSize) {
            Packet pkt = new Packet(msg);
            pkt.setSeq(nextSeqNum);
            pkt.generateChecksum();

            toLayer3(pkt);
            if (base == nextSeqNum) startTimer(TimeoutInterval);
            nextSeqNum++;
        }
    }
    public void run() {
        Thread receiver = new Thread(new Receiver());
        receiver.start();
        try {
            A_init();
            for (int i = 1; i <= MAXSEQ; i++) {
                char[] data = new char[20];
                for (int j = 0; j < 20; j++) data[j] = (char) ('a' + (i - 1) % 26);
                Message msg = new Message(data);
                A_output(msg);
                TimeUnit.MILLISECONDS.sleep(200);
            }
        } catch (IOException e) {
            e.printStackTrace();
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
    public static void main(String[] args){
        thread.start();
    }
}

接收方

img

和比特交替协议的区别在于,收到的序列号seqnum和响应包发送的acknum不再是0/1交替,而是每次自增一。GBN接收方采取累计接收方法,如果接收到的数据包有误,返回的acknum为当前已接收的最大的序列号,表示该序列号以及所有该序列号以前的包都已成功接收。

package gbn;

import java.io.IOException;
import java.net.DatagramPacket;
import java.net.DatagramSocket;
import java.util.concurrent.TimeUnit;

import MyClass.Message;
import MyClass.Packet;
import MyClass.Serialization;

public class Server {
    DatagramSocket datagramSocket;
    byte[] buf;
    int seqNum;
    int windowSize;
    int targetPort;
    double lossRate;
    long simulatedDelay;
    void B_init() throws IOException {
        windowSize = 8;
        datagramSocket = new DatagramSocket(12000);
        lossRate = 0.2;
        simulatedDelay = 200;
        buf = new byte[1024];
        seqNum = 1;
        targetPort = 11000;
    }
    void tolayer5(Packet pkt) {
        Message msg = new Message(pkt);
        System.out.print("信息接收成功。");
        msg.show();
    }
    void B_input(Packet pkt, DatagramPacket datagramPacket) throws IOException{
        if (!pkt.selfcheck() || pkt.getSeq() != seqNum) {
            System.out.println("Packet损坏或重复。");
            // System.out.println(seqNum + " " + pkt.getSeq() + " " + pkt.selfcheck());
            if (Math.random() > lossRate) {
                System.out.println("损坏响应已发送。");
                Packet responsePacket = new Packet();
                responsePacket.setAck(seqNum);
                responsePacket.generateChecksum();
                byte[] sendData = Serialization.PacketToBytes(responsePacket);
                DatagramPacket sendPacket = new DatagramPacket(
                        sendData,
                        sendData.length,
                        datagramPacket.getAddress(),
                        targetPort
                );
                datagramSocket.send(sendPacket);
            } else {
                System.out.println("响应消息发送失败。");
            }
        } else {
            tolayer5(pkt);
            seqNum++;
            if (Math.random() > lossRate) {
                Packet responsePacket = new Packet();
                responsePacket.setAck(seqNum);
                responsePacket.generateChecksum();
                byte[] sendData = Serialization.PacketToBytes(responsePacket);
                // 好坑
                DatagramPacket sendPacket = new DatagramPacket(
                        sendData,
                        sendData.length,
                        datagramPacket.getAddress(),
                        targetPort
                );
                datagramSocket.send(sendPacket);
            } else {
                System.out.println("响应消息发送失败。");
            }
        }
    }
    void run() throws IOException, ClassNotFoundException, InterruptedException{
        B_init();
        while(true) {
            DatagramPacket datagramPacket = new DatagramPacket(buf, buf.length);
            datagramSocket.receive(datagramPacket);
            Packet pkt = Serialization.BytesToPacket(buf);
            TimeUnit.MILLISECONDS.sleep((long)(Math.random() * simulatedDelay));
            B_input(pkt, datagramPacket);
        }
    }
    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException{
        Server server = new Server();
        server.run();
    }
}

结果和测试

可以通过参数设置响应包发送失败的概率,当丢包概率为0时,所有消息被正确接收,不会触发超时重发。(图为比特交替协议):

img

img

将丢包概率提高到30%后,服务器端频繁出现响应包发送失败的情况,客户端同时触发超时重传。(图为GBN协议):

img

img

总结

由于代码量较大,本次实验过程中对实验所需要的常用功能进行了模块化(如序列化模块实现class与字符数组相互转化、Packet校验和生成和检验等),同时加深了对java语言和多线程编程的运用和理解。

实验过程中调试解决了一些问题,如:

多线程编程中udp数据包发送线程和接收线程的端口号不同,这意味着服务器端不能通过receivePacket.getPort()来获取端口并作为响应包的发送端口,因为通过这方法获得的端口是发送线程而不是接收线程的。因此需要提前指定客户端的socket端口,如11000。