Netty--使用TCP协议传输文件

简介:

用于通过 TCP 协议将文件从一台机器传输到另一台机器，两台机器之间需要网络互通。

实现:

使用 Netty 进行文件传输：客户端通过命令（ls、cd、pwd、get）浏览服务端目录并请求文件；服务端读取文件，将文件拆分为多个数据块依次发送；客户端（接收端）接收数据块，按序号排序后将数据写入文件。

工程结构:

Netty--使用TCP协议传输文件

Maven配置:

<dependencies>
        <!-- Netty: bootstrap, channel, codec and ByteBuf classes used by both client and server.
             NOTE(review): 4.1.15.Final is an old release; consider moving to a current 4.1.x. -->
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.15.Final</version>
        </dependency>

        <!-- JUnit: test scope only, not on the runtime classpath. -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

传输对象：type 为数据块类型，index 为数据块序号，length 为数据块大小（字节），data 为要传输的数据。编码后的帧格式为：2 字节 type + 4 字节 index + 4 字节 length + length 字节的 data。

/**
 * One protocol frame exchanged between client and server.
 *
 * <p>{@code type} says how to interpret the payload, {@code index} orders DATA chunks,
 * {@code length} is the payload size as filled in by the decoder (the encoder derives the
 * length from {@code data.readableBytes()} instead), and {@code data} holds the payload.
 */
public class TransData {

    private TypeEnum type;
    private int index;
    private int length;
    private ByteBuf data;

    /** Creates an empty frame; populate it through the setters. */
    public TransData() {
    }

    /** Creates a frame of the given type carrying {@code data}; index and length stay 0. */
    public TransData(TypeEnum type, ByteBuf data) {
        this.data = data;
        this.type = type;
    }

    public TypeEnum getType() {
        return this.type;
    }

    public void setType(TypeEnum type) {
        this.type = type;
    }

    public int getIndex() {
        return this.index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public int getLength() {
        return this.length;
    }

    public void setLength(int length) {
        this.length = length;
    }

    public ByteBuf getData() {
        return this.data;
    }

    public void setData(ByteBuf data) {
        this.data = data;
    }

    /** Describes the frame header; the payload bytes are deliberately left out. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("TransData{");
        text.append("type=").append(type)
                .append(", index=").append(index)
                .append(", length=").append(length)
                .append('}');
        return text.toString();
    }
}

---

类型枚举:

/**
 * Frame types of the transfer protocol. Each constant carries the 16-bit code that the
 * encoder writes with {@code writeShort} and the decoder maps back through {@link #get(short)}.
 */
public enum TypeEnum {

    /** Returned by {@link #get(short)} for codes that match no other constant. */
    UNKNOW(0),

    /** Command text sent from client to server (ls, cd, pwd, get). */
    CMD(1),

    /** Text message or ls listing sent from server to client. */
    MSG(2),

    /** One chunk of file content; chunks are ordered by their index. */
    DATA(3),

    /** Start of a file transfer; the payload is "name/:/size". */
    BEGIN(4),

    /** End of a file transfer; its index equals the number of DATA chunks sent. */
    END(5);

    /** Wire code. Final because enum constants are shared singletons and must not change. */
    final short value;

    TypeEnum(int value) {
        this.value = (short) value;
    }

    /**
     * Maps a wire code back to its constant.
     *
     * @param s code read from the wire
     * @return the matching constant, or {@link #UNKNOW} when no constant uses {@code s}
     */
    public static TypeEnum get(short s) {
        for (TypeEnum e : TypeEnum.values()) {
            if (e.value == s) {
                return e;
            }
        }
        return UNKNOW;
    }

    /** Returns the wire code of this type. */
    public short value() {
        return this.value;
    }

}

---

解码器：从字节流中还原 TransData 对象。

/**
 * Rebuilds TransData frames from the byte stream.
 *
 * <p>Frame layout (mirrors Encode): short type, int index, int length, then {@code length}
 * payload bytes. ReplayingDecoder re-runs {@link #decode} from the frame start until the
 * whole frame has arrived.
 */
public class Decode extends ReplayingDecoder<Void> {

    /**
     * Largest payload accepted. This is far above any frame the protocol produces, so a
     * bigger (or negative) length means the stream is corrupt or hostile. Without this
     * check the decoder would buffer without bound while waiting for the payload.
     */
    private static final int MAX_DATA_LENGTH = 16 * 1024 * 1024;

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        TypeEnum type = TypeEnum.get(in.readShort());
        int index = in.readInt();
        int length = in.readInt();
        if (length < 0 || length > MAX_DATA_LENGTH) {
            // The stream is out of sync; nothing after this point can be framed reliably.
            ctx.close();
            throw new IllegalStateException("invalid frame length: " + length);
        }
        TransData data = new TransData();
        data.setType(type);
        data.setIndex(index);
        data.setLength(length);
        data.setData(in.readBytes(length));
        out.add(data);
    }
}

---

编码器:

/**
 * Serializes a TransData frame as: short type, int index, int payload length, payload bytes.
 * The length field comes from the payload's readable bytes, not from TransData.getLength().
 */
public class Encode extends MessageToByteEncoder<TransData> {

    @Override
    protected void encode(ChannelHandlerContext ctx, TransData msg, ByteBuf out) throws Exception {
        out.writeShort(msg.getType().value())
                .writeInt(msg.getIndex());
        ByteBuf payload = msg.getData();
        if (payload == null) {
            // No payload: emit a zero-length body instead of failing.
            out.writeInt(0);
            return;
        }
        try {
            out.writeInt(payload.readableBytes())
                    .writeBytes(payload);
        } finally {
            // TransData is not ReferenceCounted, so MessageToByteEncoder cannot release the
            // payload itself. Each frame's buffer is built fresh for one write, so free it here.
            payload.release();
        }
    }
}

--- 

数据接收器

/**
 * Receives one file transfer. DATA chunks are reordered by index through a SortedQueue and
 * written to {@code <user.dir>/received/<fileName>} on a background thread until END arrives.
 * A new instance is created for every BEGIN frame.
 */
public class Receiver {

    private SortedQueue queue = new SortedQueue();
    private String dstPath = System.getProperty("user.dir") + File.separator + "received";
    private String fileName;
    private long receivedSize = 0;
    private long totalSize = 0;
    private int chunkIndex = 0;
    private FileOutputStream out;
    private FileChannel ch;
    private long t1;
    private int process = 0;
    /** Set when nothing is (or will be) written; incoming chunks are then released immediately. */
    private volatile boolean discard = false;

    /**
     * Starts a transfer described by a BEGIN frame.
     *
     * @param data BEGIN frame whose payload is "name/:/size"
     */
    public Receiver(TransData data) {
        init(Tool.getMsg(data));
    }

    /**
     * Opens the target file and starts the writer thread.
     *
     * @param msg "name/:/size" as sent by the peer
     */
    public void init(String msg) {
        String[] ss = msg.split("/:/");
        // Keep only the last path component so a crafted name cannot escape dstPath.
        fileName = new File(ss[0].trim()).getName();
        totalSize = Long.parseLong(ss[1].trim());
        queue.clear();

        File dir = new File(dstPath);
        File f = new File(dir, fileName);
        try {
            if (!dir.isDirectory() && !dir.mkdirs()) {
                throw new IOException("cannot create directory " + dir);
            }
            if (f.exists()) {
                f.delete();
            }
            out = new FileOutputStream(f);
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("receive failed: " + fileName);
            // Drop the chunks that follow instead of queueing them forever.
            discard = true;
            return;
        }
        ch = out.getChannel();
        System.out.println("receive begin: " + fileName + "  " + Tool.size(totalSize));
        t1 = System.currentTimeMillis();
        new MyThread().start();
    }

    /** Queues a DATA/END frame for the writer thread, or releases it when discarding. */
    public void receiver(TransData data) {
        if (discard) {
            if (data.getData() != null) {
                data.getData().release();
            }
            return;
        }
        queue.put(data);
    }

    /** Prints the elapsed time, closes the file and resets the per-transfer state. */
    public void end() throws IOException {
        long cost = Math.round((System.currentTimeMillis() - t1) / 1000f);
        System.out.println("receive over: " + fileName + "   Cost Time: " + cost + "s");
        if (out != null) {
            out.close();
            out = null;
            ch = null;
        }
        fileName = null;
        chunkIndex = 0;
        totalSize = 0;
        receivedSize = 0;
        process = 0;
        queue.clear();
    }

    /** Prints the completed percentage whenever it changes; a line break every 10%. */
    private void printProcess() {
        if (totalSize <= 0) {
            return; // nothing meaningful to report, and avoids dividing by zero
        }
        int ps = (int) (receivedSize * 100 / totalSize);
        if (ps != process) {
            this.process = ps;
            System.out.print(process + "% ");
            if (this.process % 10 == 0 || process >= 100) {
                System.out.println();
            }
        }
    }

    /** Writes one DATA chunk fully to the file, then releases its buffer. */
    private void writeChunk(TransData data) throws IOException {
        ByteBuf bfn = data.getData();
        try {
            ByteBuffer bf = bfn.nioBuffer();
            // FileChannel.write may write fewer bytes than remain; loop until drained.
            while (bf.hasRemaining()) {
                ch.write(bf);
            }
            receivedSize += data.getLength();
            printProcess();
        } finally {
            bfn.release();
        }
    }

    /** Consumes chunks in index order until END, then always closes the file. */
    private class MyThread extends Thread {
        @Override
        public void run() {
            boolean completed = false;
            try {
                while (true) {
                    TransData data = queue.offer(chunkIndex++);
                    if (data == null) {
                        // SortedQueue.offer returns null when the wait was interrupted.
                        break;
                    }
                    if (data.getType() == TypeEnum.END) {
                        completed = true;
                        break;
                    }
                    writeChunk(data);
                }
            } catch (IOException e) {
                e.printStackTrace();
            } finally {
                // No further chunks will be consumed; release any that still arrive.
                discard = true;
                if (!completed) {
                    System.out.println("receive failed: " + fileName);
                }
                try {
                    end();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

}

---

数据发送器：异步发送，每个传输任务实例化一个发送器对象，并提交到发送器线程池中执行。

/**
 * Sends one file to the peer as a BEGIN frame ("name/:/size"), DATA frames numbered from 0,
 * and a final END frame whose index is the number of DATA frames. One instance per transfer;
 * it runs on a worker thread (SenderThreadPool), not on the event loop.
 */
public class Sender implements Runnable {

    /** Bytes read from the file per DATA frame. */
    private static final int CHUNK_SIZE = 20480;
    /** Back-off while the channel's outbound buffer is full. */
    private static final long WRITABLE_POLL_MILLIS = 5;

    String path;
    String name;
    File f;
    FileInputStream in;
    FileChannel ch;
    ByteBuffer bf;
    int index = 0;
    Channel channel;
    private long t0;

    public Sender(String path, String name, Channel c) {
        this.path = path;
        this.name = name;
        this.channel = c;
    }

    @Override
    public void run() {
        begin(path);
        send();
    }

    /**
     * Opens {@code path/name} and announces the transfer with a BEGIN frame.
     * If the file is missing or unreadable, only an error MSG is sent and {@code in} stays
     * null, so {@link #send()} does nothing.
     */
    public void begin(String path) {
        f = new File(path + File.separator + name);
        if (!f.isFile() || !f.canRead()) {
            // Stop here: announcing BEGIN without ever sending END would leave the
            // receiver's writer thread waiting forever.
            Tool.sendMsg(channel, "file can not read.");
            return;
        }

        try {
            in = new FileInputStream(f);
        } catch (FileNotFoundException e) {
            e.printStackTrace();
            Tool.sendMsg(channel, "file can not read.");
            return;
        }
        ch = in.getChannel();
        bf = ByteBuffer.allocate(CHUNK_SIZE);

        long size = f.length();
        t0 = System.currentTimeMillis();
        System.out.println("send begin: " + name + "  " + Tool.size(size));
        Tool.sendBegin(channel, name + "/:/" + size);
    }

    /**
     * Streams the opened file as DATA frames, then END. END is also sent when reading fails
     * or the thread is interrupted, so the receiver closes its file instead of blocking.
     */
    public void send() {
        if (in == null) {
            return;
        }
        try {
            while (ch.read(bf) != -1) {
                awaitWritable();
                bf.flip();
                Tool.sendData(channel, bf, index);
                index++;
                bf.clear();
            }
            Tool.sendEnd(channel, index);

            long cost = Math.round((System.currentTimeMillis() - t0) / 1000f);
            System.out.println("send over: " + f.getName() + "   Cost Time: " + cost + "s");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            abort(e);
        } catch (IOException e) {
            abort(e);
        } finally {
            clear();
        }
    }

    /**
     * Waits until the channel accepts more data. isWritable() stays false once the channel
     * has closed, so the channel state is checked to avoid sleeping forever.
     */
    private void awaitWritable() throws IOException, InterruptedException {
        while (!channel.isWritable()) {
            if (!channel.isActive()) {
                throw new IOException("channel closed while sending " + name);
            }
            TimeUnit.MILLISECONDS.sleep(WRITABLE_POLL_MILLIS);
        }
    }

    /** Reports a failed transfer and terminates the stream so the receiver stops waiting. */
    private void abort(Exception cause) {
        cause.printStackTrace();
        System.out.println("send failed: " + name);
        Tool.sendMsg(channel, "send failed: " + name);
        Tool.sendEnd(channel, index);
    }

    /** Closes the file and resets the per-transfer state. */
    private void clear() {
        if (in != null) {
            try {
                in.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
        if (bf != null) {
            bf.clear();
        }
        f = null;
        in = null;
        ch = null;
        index = 0;
    }

}

---

发送器线程池:

/**
 * Shared fixed pool of 10 threads that runs file-sending tasks off the Netty event loop.
 * Static-only utility; not instantiable.
 */
public final class SenderThreadPool {

    private static final ExecutorService EXECUTOR = Executors.newFixedThreadPool(10);

    private SenderThreadPool() {
    }

    /**
     * Queues {@code run} for execution on the pool.
     *
     * @param run task to execute, typically a Sender
     */
    public static void exe(Runnable run) {
        EXECUTOR.execute(run);
    }

}

---

有序数据队列：缓存接收到的数据块并按序号排序；按序号取出，若下一个序号的数据块尚未到达则阻塞等待。

/**
 * Buffers received chunks and hands them out strictly in index order.
 * {@link #offer(int)} blocks until the head of the queue has exactly the requested index,
 * so a missing index keeps the consumer waiting. All list access happens under {@code lock}.
 */
public class SortedQueue {

    final Lock lock = new ReentrantLock();
    final Condition canTake = lock.newCondition();
    private final AtomicInteger count = new AtomicInteger();
    private LinkedList<TransData> list = new LinkedList<TransData>();

    /**
     * Inserts {@code node} keeping the list sorted by index. Among equal indices, the
     * arrival order is kept.
     */
    public void put(TransData node) {
        lock.lock();
        try {
            // Chunks normally arrive in order, so scan from the tail. The common case is O(1),
            // instead of LinkedList.get(i) in a forward loop, which is O(n) per step.
            java.util.ListIterator<TransData> it = list.listIterator(list.size());
            while (it.hasPrevious()) {
                if (it.previous().getIndex() <= node.getIndex()) {
                    it.next(); // step back over that element so we insert after it
                    break;
                }
            }
            it.add(node);
            count.incrementAndGet();
            // A waiter proceeds only if the head matches its index, so wake every waiter.
            canTake.signalAll();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Removes and returns the head once its index equals {@code index}, blocking until then.
     *
     * @param index the index the caller expects next
     * @return the matching element, or null if the calling thread was interrupted while
     *         waiting (the interrupt status is restored)
     */
    public TransData offer(int index) {
        lock.lock();
        try {
            while (list.isEmpty() || list.getFirst().getIndex() != index) {
                canTake.await();
            }
            count.getAndDecrement();
            return list.removeFirst();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        } finally {
            lock.unlock();
        }
    }

    /** Returns the number of buffered elements. */
    public int size() {
        return count.get();
    }

    /** Drops all buffered elements. The buffers they hold are not released. */
    public void clear() {
        lock.lock();
        try {
            list.clear();
            count.set(0);
        } finally {
            lock.unlock();
        }
    }
}

---

服务端Handler:

public class ServerHandler extends SimpleChannelInboundHandler<TransData> {

    private String cpath = System.getProperty("user.dir");

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TransData data) throws Exception {
        handle(ctx, data);
    }

    private void handle(ChannelHandlerContext ctx, TransData data) throws Exception {
        if (data.getType() == TypeEnum.CMD) {
            String cmd = Tool.getMsg(data);
            if (cmd.equalsIgnoreCase("ls")) {
                ls(ctx.channel());
            } else if (cmd.startsWith("cd ")) {
                cd(ctx.channel(), cmd);
            } else if (cmd.startsWith("get ")) {
                String name = cmd.substring(4);
                Sender sender = new Sender(cpath, name, ctx.channel());
                SenderThreadPool.exe(sender);
            } else if (cmd.equalsIgnoreCase("pwd")) {
                Tool.sendMsg(ctx.channel(), "now at: " + cpath);
            } else {
                Tool.sendMsg(ctx.channel(), "unknow command!");
            }
        }
    }

    private void ls(Channel channel) {
        int k = 0;
        StringBuilder sb = new StringBuilder();
        File file = new File(cpath);
        for (File f : file.listFiles()) {
            if (f.isDirectory()) {
                sb.append(k);
                sb.append("/:/");
                sb.append("目录");
                sb.append("/:/");
                sb.append(f.getName());
                sb.append("
");
                k++;
            }
        }
        for (File f : file.listFiles()) {
            if (f.isFile()) {
                sb.append(k);
                sb.append("/:/");
                sb.append(Tool.size(f.length()));
                sb.append("/:/");
                sb.append(f.getName());
                sb.append("
");
                k++;
            }
        }
        Tool.sendMsg(channel, "ls " + sb.toString());
    }

    private void cd(Channel channel, String cmd) {
        String dir = cmd.substring(3).trim();
        if (dir.equals("..")) {
            File f = new File(cpath);
            f = f.getParentFile();
            cpath = f.getAbsolutePath();
            Tool.sendMsg(channel, "new path " + cpath);
            ls(channel);
        } else {
            String path1 = cpath + File.separator + dir;
            File f1 = new File(path1);
            if (f1.exists()) {
                cpath = path1;
                Tool.sendMsg(channel, "new path " + cpath);
                ls(channel);
            } else {
                Tool.sendMsg(channel, "error, path not found");
            }
        }
    }

}

---

客户端Handler:

public class ClientHandler extends SimpleChannelInboundHandler<TransData> {

    private static Map<Integer, String> map = new HashMap();

    Receiver receiver;

    public static String getName(int i) {
        return map.get(Integer.valueOf(i));
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Tool.sendCmd(ctx.channel(), "pwd");
        Tool.sendCmd(ctx.channel(), "ls");
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TransData data) throws Exception {
        TypeEnum type = data.getType();
        if (type == TypeEnum.MSG) {
            String msg = Tool.getMsg(data);
            if (msg.startsWith("ls ")) {
                praseLs(msg);
            } else if (msg.startsWith("msg ")) {
                System.out.println(msg.substring(4));
            } else {
                System.out.println(msg);
            }
        } else if (type == TypeEnum.DATA || type == TypeEnum.END) {
            receiver.receiver(data);
        } else if (type == TypeEnum.BEGIN) {
            receiver = new Receiver(data);
        } else {
            System.out.println(Tool.getMsg(data));
        }
    }

    private void praseLs(String msg) {
        map.clear();
        String ss = msg.substring(3).trim();
        String[] paths = ss.split("
");
        for (String p : paths) {
            p = p.trim();
            String[] dd = p.split("/:/");
            if (dd.length == 3) {
                System.out.println(dd[0] + " " + dd[1] + " " + dd[2]);
                map.put(Integer.valueOf(dd[0].trim()), dd[2].trim());
            }
        }
    }

}

---

客户端启动:

/**
 * Singleton console client. {@link #start} connects on a background thread; {@link #readCmd}
 * then forwards commands typed on stdin to the server until "exit" or end of input.
 */
public class TransClient {

    private static TransClient client = new TransClient();
    private String ip;
    private int port;
    /** Assigned by ClientThread once connected and read by the console thread, hence volatile. */
    private volatile Channel channel = null;
    private Thread t = new ClientThread();

    private TransClient() {
    }

    public static TransClient instance() {
        return client;
    }

    /**
     * Starts the connection thread. Later calls are ignored, because a Thread can only be
     * started once.
     */
    public void start(String ip, int port) {
        if (t.getState() != Thread.State.NEW) {
            return;
        }
        this.ip = ip;
        this.port = port;
        t.start();
    }

    /**
     * Reads commands from stdin and sends them to the server.
     * "get N" and "cd N" translate index N from the last ls listing into a name.
     * "exit" closes the connection and returns.
     */
    public void readCmd() {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNextLine()) {
            String cmd = sc.nextLine().trim();
            if (cmd.equalsIgnoreCase("exit")) {
                Channel c = channel;
                if (c != null) {
                    // closeFuture() only returns a future. close() actually disconnects,
                    // which lets ClientThread finish and shut down its event loop.
                    c.close();
                }
                return;
            }
            if (cmd.startsWith("get ")) {
                String arg = cmd.substring(4).trim();
                String name = nameForIndex(arg);
                if (name == null) {
                    // Previously a bad index threw out of this loop and killed the console.
                    System.out.println("invalid file index: " + arg);
                    continue;
                }
                cmd = "get " + name;
            } else if (cmd.startsWith("cd ")) {
                String arg = cmd.substring(3).trim();
                if (!arg.equals("..")) {
                    String name = nameForIndex(arg);
                    if (name != null) {
                        cmd = "cd " + name;
                    }
                    // Otherwise send the argument unchanged as a literal directory name.
                }
            }
            Channel c = channel;
            if (c == null) {
                System.out.println("not connected yet");
                continue;
            }
            Tool.sendCmd(c, cmd);
        }
    }

    /** Returns the ls-listing name for a numeric index, or null if arg is not a known index. */
    private static String nameForIndex(String arg) {
        try {
            return ClientHandler.getName(Integer.parseInt(arg));
        } catch (NumberFormatException e) {
            return null;
        }
    }

    /** Connects to ip:port and blocks until the connection closes, then frees the event loop. */
    private class ClientThread extends Thread {
        @Override
        public void run() {
            Bootstrap bootstrap = new Bootstrap();
            EventLoopGroup group = new NioEventLoopGroup();
            try {
                bootstrap.group(group)
                        .channel(NioSocketChannel.class)
                        .option(ChannelOption.SO_KEEPALIVE, true)
                        .handler(new ChannelInitializer<Channel>() {
                            @Override
                            protected void initChannel(Channel ch) throws Exception {
                                ChannelPipeline pipeline = ch.pipeline();
                                pipeline.addLast(new Decode());
                                pipeline.addLast(new Encode());
                                pipeline.addLast(new ClientHandler());
                            }
                        });

                channel = bootstrap.connect(ip, port).sync().channel();
                System.out.println("Trans Client connect to " + ip + ":" + port);
                channel.closeFuture().sync();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                group.shutdownGracefully();
                System.out.println("Trans Client stoped.");
            }
        }
    }

}

---

服务端启动:

/**
 * Singleton TCP server. {@link #start(int)} launches a background thread that binds the port
 * and gives each accepted connection a Decode, Encode, ServerHandler pipeline. The thread
 * runs until the server channel closes.
 */
public class TransServer {

    private static TransServer server = new TransServer();

    private int port;
    private Thread t = new ServerThread();

    private TransServer() {
    }

    /** Returns the single server instance. */
    public static TransServer instance() {
        return server;
    }

    /** Records the port and starts the server thread; Thread.start allows only one call. */
    public void start(int port) {
        this.port = port;
        t.start();
    }

    /** Binds the port and blocks until the server channel closes, then frees both groups. */
    private class ServerThread extends Thread {
        @Override
        public void run() {
            EventLoopGroup acceptors = new NioEventLoopGroup(1);
            EventLoopGroup workers = new NioEventLoopGroup(1);
            try {
                ChannelFuture bound = new ServerBootstrap()
                        .group(acceptors, workers)
                        .channel(NioServerSocketChannel.class)
                        .option(ChannelOption.SO_BACKLOG, 100)
                        .childHandler(new ChannelInitializer<SocketChannel>() {
                            @Override
                            protected void initChannel(SocketChannel ch) throws Exception {
                                ch.pipeline().addLast(new Decode(), new Encode(), new ServerHandler());
                            }
                        })
                        .bind(port)
                        .sync();
                System.out.println("Trans Server started, port: " + port);
                bound.channel().closeFuture().sync();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                System.out.println("Trans Server shuting down");
                acceptors.shutdownGracefully();
                workers.shutdownGracefully();
            }
        }
    }

}

---

工具类：发送消息和数据。

/**
 * Static helpers that build TransData frames and write them to a channel, plus a size
 * formatter. Text payloads are encoded with {@link #CHARSET} (UTF-8). Not instantiable.
 */
public class Tool {

    public static final Charset CHARSET = Charset.forName("UTF8");

    private Tool() {
    }

    /** Sends {@code msg} to the peer as a MSG frame. */
    public static void sendMsg(Channel ch, String msg) {
        sendText(ch, TypeEnum.MSG, msg);
    }

    /** Sends {@code msg} to the peer as a CMD frame. */
    public static void sendCmd(Channel ch, String msg) {
        sendText(ch, TypeEnum.CMD, msg);
    }

    /** Decodes the frame payload as UTF-8 text and trims it; the buffer is not released here. */
    public static String getMsg(TransData data) {
        CharBuffer cb = CHARSET.decode(data.getData().nioBuffer());
        return cb.toString().trim();
    }

    /** Sends a BEGIN frame; the caller supplies the "name/:/size" text. */
    public static void sendBegin(Channel ch, String msg) {
        sendText(ch, TypeEnum.BEGIN, msg);
    }

    /**
     * Sends the remaining bytes of {@code bf} as a DATA frame with the given index.
     * The bytes are copied, and {@code bf}'s position is left unchanged.
     */
    public static void sendData(Channel ch, ByteBuffer bf, int index) {
        TransData data = new TransData();
        data.setType(TypeEnum.DATA);
        ByteBuf bfn = Unpooled.copiedBuffer(bf);
        data.setData(bfn);
        data.setIndex(index);
        ch.writeAndFlush(data);
    }

    /** Sends an END frame with an empty payload; {@code index} is the number of DATA frames. */
    public static void sendEnd(Channel ch, int index) {
        TransData data = new TransData(TypeEnum.END, Unpooled.EMPTY_BUFFER);
        data.setIndex(index);
        ch.writeAndFlush(data);
    }

    /**
     * Formats a byte count as whole MB when it is at least 1 MiB, otherwise as whole KB.
     * Values are truncated, so sizes below 1 KiB print as "0KB".
     */
    public static String size(long num) {
        long m = 1 << 20;
        if (num / m == 0) {
            return (num / 1024) + "KB";
        }
        return num / m + "MB";
    }

    /** Encodes {@code text} with CHARSET and writes it as a frame of the given type. */
    private static void sendText(Channel ch, TypeEnum type, String text) {
        ByteBuffer encoded = CHARSET.encode(text);
        ch.writeAndFlush(new TransData(type, Unpooled.copiedBuffer(encoded)));
    }

}

---

配置读取类:

/**
 * Reads settings from {@code <user.dir>/config/app.properties}. The file is loaded once when
 * the class initializes and again on {@link #reload()}. Static-only; not instantiable.
 */
public class ConfigTool {

    private static final String CONFIG_PATH = System.getProperty("user.dir") + File.separator + "config" + File.separator + "app.properties";

    private static final Properties ppt = new Properties();

    static {
        load();
    }

    private ConfigTool() {
    }

    /** Discards the current settings and reads the file again. */
    public static void reload() {
        ppt.clear();
        load();
    }

    /**
     * Loads CONFIG_PATH into ppt and always closes the stream, which the original code
     * leaked. I/O errors (e.g. a missing file) are printed and otherwise ignored.
     */
    private static void load() {
        FileInputStream in = null;
        try {
            in = new FileInputStream(CONFIG_PATH);
            ppt.load(in);
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // Nothing useful to do if closing a read-only stream fails.
                }
            }
        }
    }

    /**
     * Returns the trimmed value for {@code key} (the key is trimmed too), or null when the
     * key is absent or its value is blank.
     */
    public static String getValue(String key) {
        key = key.trim();
        String value = ppt.getProperty(key);
        if (value == null) {
            return null;
        }
        value = value.trim();
        return value.isEmpty() ? null : value;
    }

    /**
     * Returns the value for {@code key} parsed as an int.
     *
     * @throws NumberFormatException if the key is missing or blank, or the value is not an int
     */
    public static int getInt(String key) {
        String s = getValue(key);
        if (s == null) {
            throw new NumberFormatException("missing config value: " + key);
        }
        return Integer.parseInt(s);
    }

}

--- 

启动类:

/**
 * Entry point. Reads "mode", "server.ip" and "server.port" through ConfigTool, then starts
 * the server, the client with its interactive console, or both.
 */
public class CmdApp {

    public static void main(String[] args) {

        String mode = ConfigTool.getValue("mode");
        if (mode == null) {
            // Check this before reading the port: getInt() throws when the value is missing,
            // which would otherwise hide this message.
            System.out.println("error");
            return;
        }
        String ip = ConfigTool.getValue("server.ip");
        int port = ConfigTool.getInt("server.port");

        if (mode.equals("server")) {
            TransServer.instance().start(port);
        } else if (mode.equals("client")) {
            TransClient.instance().start(ip, port);
            TransClient.instance().readCmd();
        } else if (mode.equals("both")) {
            TransServer.instance().start(port);
            TransClient.instance().start(ip, port);
            TransClient.instance().readCmd();
        } else {
            System.out.println("error, unknown mode: " + mode);
        }

    }
}

--- 

end