package org.rx.socks;
import lombok.extern.slf4j.Slf4j;
import org.rx.core.LogWriter;
import org.rx.core.NQuery;
import org.rx.beans.DateTime;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import static org.rx.core.Contract.require;
@Slf4j
public final class SocketPool extends Traceable implements AutoCloseable {
public static final class PooledSocket implements AutoCloseable {
private final SocketPool owner;
private DateTime lastActive;
public final Socket socket;
public boolean isConnected() {
return !owner.isClosed() && !socket.isClosed() && socket.isConnected();
}
public DateTime getLastActive() {
return lastActive;
}
public void setLastActive(DateTime lastActive) {
this.lastActive = lastActive;
}
private PooledSocket(SocketPool owner, Socket socket) {
this.owner = owner;
this.socket = socket;
lastActive = DateTime.utcNow();
}
@Override
public void close() {
owner.returnSocket(this);
}
}
public static final SocketPool Pool = new SocketPool();
private static final int DefaultConnectTimeout = 30000;
private static final int DefaultMaxIdleMillis = 120000;
private static final int DefaultMaxSocketsCount = 64;
private final ConcurrentHashMap<InetSocketAddress, ConcurrentLinkedDeque<PooledSocket>> pool;
private volatile int connectTimeout;
private volatile int maxIdleMillis;
private volatile int maxSocketsCount;
private final Timer timer;
private volatile boolean isTimerRun;
public int getConnectTimeout() {
return connectTimeout;
}
public void setConnectTimeout(int connectTimeout) {
this.connectTimeout = connectTimeout;
}
public int getMaxIdleMillis() {
return maxIdleMillis;
}
public void setMaxIdleMillis(int maxIdleMillis) {
if (maxIdleMillis <= 0) {
maxIdleMillis = DefaultMaxIdleMillis;
}
this.maxIdleMillis = maxIdleMillis;
}
public int getMaxSocketsCount() {
return maxSocketsCount;
}
public void setMaxSocketsCount(int maxSocketsCount) {
if (maxSocketsCount < 0) {
maxSocketsCount = 0;
}
this.maxSocketsCount = maxSocketsCount;
}
private SocketPool() {
pool = new ConcurrentHashMap<>();
connectTimeout = DefaultConnectTimeout;
maxIdleMillis = DefaultMaxIdleMillis;
maxSocketsCount = DefaultMaxSocketsCount;
String n = "SocketPool";
timer = new Timer(n, true);
LogWriter tracer = new LogWriter();
tracer.setPrefix(n + " ");
tracer.info("started..");
setTracer(tracer);
}
@Override
protected void freeObjects() {
clear();
}
private void runTimer() {
if (isTimerRun) {
return;
}
synchronized (timer) {
if (isTimerRun) {
return;
}
long period = 90000;
timer.schedule(new TimerTask() {
@Override
public void run() {
clearIdleSockets();
}
}, period, period);
isTimerRun = true;
}
getTracer().info("runTimer..");
}
private void clearIdleSockets() {
for (Map.Entry<InetSocketAddress, ConcurrentLinkedDeque<PooledSocket>> entry : NQuery.of(pool.entrySet())) {
ConcurrentLinkedDeque<PooledSocket> sockets = entry.getValue();
if (sockets == null) {
continue;
}
for (PooledSocket socket : NQuery.of(sockets)) {
if (!socket.isConnected()
|| DateTime.utcNow().subtract(socket.getLastActive()).getTotalMilliseconds() >= maxIdleMillis) {
sockets.remove(socket);
getTracer().info("clear idle socket[local=%s, remote=%s]..",
Sockets.getId(socket.socket, false), Sockets.getId(socket.socket, true));
}
}
if (sockets.isEmpty()) {
pool.remove(entry.getKey());
}
}
if (pool.size() == 0) {
stopTimer();
}
}
private void stopTimer() {
synchronized (timer) {
timer.cancel();
timer.purge();
isTimerRun = false;
}
getTracer().info("stopTimer..");
}
private ConcurrentLinkedDeque<PooledSocket> getSockets(InetSocketAddress remoteAddr) {
ConcurrentLinkedDeque<PooledSocket> sockets = pool.get(remoteAddr);
if (sockets == null) {
pool.put(remoteAddr, sockets = new ConcurrentLinkedDeque<>());
runTimer();
}
return sockets;
}
public PooledSocket borrowSocket(InetSocketAddress remoteAddr) {
checkNotClosed();
require(remoteAddr);
boolean isExisted = true;
ConcurrentLinkedDeque<PooledSocket> sockets = getSockets(remoteAddr);
PooledSocket pooledSocket;
if ((pooledSocket = sockets.pollFirst()) == null) {
Socket sock = new Socket();
try {
sock.connect(remoteAddr, connectTimeout);
} catch (IOException ex) {
throw new SocketException(remoteAddr, ex);
}
pooledSocket = new PooledSocket(this, sock);
isExisted = false;
}
if (!pooledSocket.isConnected()) {
if (isExisted) {
sockets.remove(pooledSocket);
}
return borrowSocket(remoteAddr);
}
Socket sock = pooledSocket.socket;
getTracer().info("borrow %s socket[local=%s, remote=%s]..", isExisted ? "existed" : "new",
Sockets.getId(sock, false), Sockets.getId(sock, true));
return pooledSocket;
}
public void returnSocket(PooledSocket pooledSocket) {
checkNotClosed();
require(pooledSocket);
String action = "return";
try {
if (!pooledSocket.isConnected()) {
action = "discard closed";
return;
}
pooledSocket.setLastActive(DateTime.utcNow());
ConcurrentLinkedDeque<PooledSocket> sockets = getSockets(
(InetSocketAddress) pooledSocket.socket.getRemoteSocketAddress());
if (sockets.size() >= maxSocketsCount || sockets.contains(pooledSocket)) {
action = "discard contains";
return;
}
sockets.addFirst(pooledSocket);
} finally {
Socket sock = pooledSocket.socket;
getTracer().info("%s socket[local=%s, remote=%s]..", action, Sockets.getId(sock, false),
Sockets.getId(sock, true));
}
}
public void clear() {
checkNotClosed();
for (Socket socket : NQuery.of(pool.values()).selectMany(p -> p).select(p -> p.socket)) {
try {
getTracer().info("clear socket[local=%s, remote=%s]..", Sockets.getId(socket, false),
Sockets.getId(socket, true));
Sockets.close(socket);
} catch (Exception ex) {
log.error("SocketPool clear", ex);
}
}
pool.clear();
}
}
package org.rx.socks;
import org.rx.core.Disposable;
import org.rx.core.LogWriter;
import static org.rx.core.Contract.isNull;
/**
 * Base class for disposable components that write diagnostics through a replaceable {@link LogWriter}.
 */
public abstract class Traceable extends Disposable {
    // volatile: setTracer() writes under a lock, but getTracer() reads without one and may run on other threads.
    private volatile LogWriter tracer;

    /**
     * @return the current tracer; null until {@link #setTracer(LogWriter)} has been called
     */
    public LogWriter getTracer() {
        return tracer;
    }

    /**
     * Replaces the tracer.
     *
     * @param tracer the new tracer; null installs a fresh default {@link LogWriter}
     */
    public synchronized void setTracer(LogWriter tracer) {
        // Only allocate the fallback when it is actually needed.
        this.tracer = tracer != null ? tracer : new LogWriter();
    }
}
package org.rx.socks;
import lombok.extern.slf4j.Slf4j;
import org.rx.beans.$;
import org.rx.beans.Tuple;
import org.rx.util.BufferSegment;
import org.rx.util.BytesSegment;
import org.rx.core.*;
import org.rx.core.AsyncTask;
import org.rx.io.MemoryStream;
import java.io.IOException;
import java.net.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.rx.beans.$.$;
import static org.rx.core.Contract.isNull;
import static org.rx.core.Contract.require;
@Slf4j
public class DirectSocket extends Traceable implements AutoCloseable {
@FunctionalInterface
public interface SocketSupplier {
Tuple<AutoCloseable, Socket> get(MemoryStream pack);
}
private static class ClientItem {
private final DirectSocket owner;
private final BufferSegment segment;
public final NetworkStream stream;
public final AutoCloseable toSock;
public final NetworkStream toStream;
public ClientItem(Socket client, DirectSocket owner) {
this.owner = owner;
segment = new BufferSegment(Contract.config.getDefaultBufferSize(), 2);
try {
stream = new NetworkStream(client, segment.alloc());
if (owner.directAddress != null) {
SocketPool.PooledSocket pooledSocket = App.retry(owner.connectRetryCount,
p -> SocketPool.Pool.borrowSocket(p.directAddress), owner);
toSock = pooledSocket;
toStream = new NetworkStream(pooledSocket.socket, segment.alloc(), false);
return;
}
if (owner.directSupplier != null) {
MemoryStream firstPack = new MemoryStream(32, true);
BytesSegment buffer = stream.getSegment();
int read;
while ((read = stream.readSegment()) > 0) {
System.out.println("----:" + Bytes.toString(buffer.array, buffer.offset, read));
firstPack.write(buffer.array, buffer.offset, read);
Tuple<AutoCloseable, Socket> toSocks;
if ((toSocks = owner.directSupplier.get(firstPack)) != null) {
toSock = toSocks.left;
firstPack.writeTo(toStream = new NetworkStream(toSocks.right, segment.alloc(), false));
return;
}
}
log.info("DirectSocket ClientState directSupplier read: {}
content: {}", read,
Bytes.toString(firstPack.toArray(), 0, firstPack.getLength()));
}
} catch (IOException ex) {
throw new SocketException((InetSocketAddress) client.getLocalSocketAddress(), ex);
}
throw new SocketException((InetSocketAddress) client.getLocalSocketAddress(),
"DirectSocket directSupplier error");
}
public void closeSocket() {
owner.getTracer().info("client close socket[%s->%s]..", Sockets.getId(stream.getSocket(), false),
Sockets.getId(stream.getSocket(), true));
owner.clients.remove(this);
stream.close();
}
public void closeToSocket(boolean pooling) {
owner.getTracer().info("client %s socket[%s->%s]..", pooling ? "pooling" : "close",
Sockets.getId(toStream.getSocket(), false), Sockets.getId(toStream.getSocket(), true));
if (pooling) {
try {
toSock.close();
} catch (Exception ex) {
ex.printStackTrace();
}
} else {
Sockets.close(toStream.getSocket());
}
}
}
public static final SocketSupplier HttpSupplier = pack -> {
String line = Bytes.readLine(pack.getBuffer());
if (line == null) {
return null;
}
InetSocketAddress authority;
try {
authority = Sockets.parseEndpoint(
new URL(line.split(" ")[1])
.getAuthority());
} catch (MalformedURLException ex) {
throw SystemException.wrap(ex);
}
SocketPool.PooledSocket pooledSocket = App.retry(2,
p -> SocketPool.Pool.borrowSocket(p),
authority);
return Tuple.of(pooledSocket, pooledSocket.socket);
};
private static final int DefaultBacklog = 128;
private static final int DefaultConnectRetryCount = 4;
private final ServerSocket server;
private final List<ClientItem> clients;
private volatile int connectRetryCount;
private InetSocketAddress directAddress;
private SocketSupplier directSupplier;
@Override
public boolean isClosed() {
return !(!super.isClosed() && !server.isClosed());
}
public InetSocketAddress getLocalAddress() {
return (InetSocketAddress) server.getLocalSocketAddress();
}
public NQuery<Tuple<Socket, Socket>> getClients() {
return NQuery.of(clients).select(p -> Tuple.of(p.stream.getSocket(), p.toStream.getSocket()));
}
public int getConnectRetryCount() {
return connectRetryCount;
}
public void setConnectRetryCount(int connectRetryCount) {
if (connectRetryCount <= 0) {
connectRetryCount = 1;
}
this.connectRetryCount = connectRetryCount;
}
public DirectSocket(int listenPort, InetSocketAddress directAddr) {
this(new InetSocketAddress(Sockets.AnyAddress, listenPort), directAddr, null);
}
public DirectSocket(InetSocketAddress listenAddr, InetSocketAddress directAddr, SocketSupplier directSupplier) {
require(listenAddr);
require(this, directAddr != null || directSupplier != null);
try {
server = new ServerSocket();
server.setReuseAddress(true);
server.bind(listenAddr, DefaultBacklog);
} catch (IOException ex) {
throw new SocketException(listenAddr, ex);
}
directAddress = directAddr;
this.directSupplier = directSupplier;
clients = Collections.synchronizedList(new ArrayList<>());
connectRetryCount = DefaultConnectRetryCount;
String taskName = String.format("DirectSocket[%s->%s]", listenAddr, isNull(directAddress, "autoAddress"));
LogWriter tracer = new LogWriter();
tracer.setPrefix(taskName + " ");
setTracer(tracer);
AsyncTask.TaskFactory.run(() -> {
getTracer().info("start..");
while (!isClosed()) {
try {
ClientItem client = new ClientItem(server.accept(), this);
clients.add(client);
onReceive(client, taskName);
} catch (IOException ex) {
log.error(taskName, ex);
}
}
close();
}, taskName);
}
@Override
protected void freeObjects() {
try {
for (ClientItem client : NQuery.of(clients)) {
client.closeSocket();
}
clients.clear();
server.close();
} catch (IOException ex) {
log.error("DirectSocket close", ex);
}
getTracer().info("stop..");
}
private void onReceive(ClientItem client, String taskName) {
AsyncTask.TaskFactory.run(() -> {
try {
int recv = client.stream.directTo(client.toStream, (p1, p2) -> {
getTracer().info("sent %s bytes from %s to %s..", p2,
Sockets.getId(client.stream.getSocket(), true),
Sockets.getId(client.toStream.getSocket(), false));
return true;
});
getTracer().info("socket[%s->%s] closing with %s", Sockets.getId(client.stream.getSocket(), false),
Sockets.getId(client.stream.getSocket(), true), recv);
} catch (SystemException ex) {
$<java.net.SocketException> out = $();
if (ex.tryGet(out, java.net.SocketException.class)) {
if (out.v.getMessage().contains("Socket closed")) {
//ignore
log.debug("DirectTo ignore socket closed");
return;
}
}
throw ex;
} finally {
client.closeSocket();
}
}, String.format("%s[networkStream]", taskName));
AsyncTask.TaskFactory.run(() -> {
int recv = NetworkStream.StreamEOF;
try {
recv = client.toStream.directTo(client.stream, (p1, p2) -> {
getTracer().info("recv %s bytes from %s to %s..", p2,
Sockets.getId(client.toStream.getSocket(), false),
Sockets.getId(client.stream.getSocket(), true));
return true;
});
getTracer().info("socket[%s->%s] closing with %s", Sockets.getId(client.toStream.getSocket(), false),
Sockets.getId(client.toStream.getSocket(), true), recv);
} catch (SystemException ex) {
$<java.net.SocketException> out = $();
if (ex.tryGet(out, java.net.SocketException.class)) {
if (out.v.getMessage().contains("Socket closed")) {
//ignore
log.debug("DirectTo ignore socket closed");
return;
}
}
throw ex;
} finally {
client.closeToSocket(recv == NetworkStream.CannotWrite);
}
}, String.format("%s[toNetworkStream]", taskName));
}
}
package org.rx.socks;
import lombok.extern.slf4j.Slf4j;
import org.rx.util.BytesSegment;
import org.rx.io.IOStream;
import java.io.IOException;
import java.net.Socket;
import static org.rx.core.Contract.require;
import static org.rx.socks.Sockets.shutdown;
/**
 * {@link IOStream} over a {@link Socket} with its own transfer buffer; {@link #directTo} pumps this stream's input
 * into another {@code NetworkStream}.
 */
@Slf4j
public final class NetworkStream extends IOStream {
    /**
     * Called after each chunk is forwarded by {@link #directTo}; returning false stops the transfer.
     */
    @FunctionalInterface
    public interface DirectPredicate {
        boolean test(BytesSegment buffer, int count);
    }

    // Result codes returned by directTo().
    public static final int SocketEOF = 0;
    public static final int StreamEOF = -1;
    public static final int CannotWrite = -2;

    private final boolean ownsSocket;
    private final Socket socket;
    private final BytesSegment segment;

    /**
     * @return true while this stream is open and the socket is open and connected
     */
    public boolean isConnected() {
        if (isClosed() || socket.isClosed()) {
            return false;
        }
        return socket.isConnected();
    }

    @Override
    public boolean canRead() {
        if (!super.canRead()) {
            return false;
        }
        return checkSocket(socket, false);
    }

    @Override
    public boolean canWrite() {
        if (!super.canWrite()) {
            return false;
        }
        return checkSocket(socket, true);
    }

    /**
     * @return true if the socket is open, connected and the requested direction has not been shut down
     */
    private static boolean checkSocket(Socket target, boolean forWrite) {
        if (target.isClosed() || !target.isConnected()) {
            return false;
        }
        boolean halfClosed = forWrite ? target.isOutputShutdown() : target.isInputShutdown();
        return !halfClosed;
    }

    public Socket getSocket() {
        return socket;
    }

    public BytesSegment getSegment() {
        return segment;
    }

    public NetworkStream(Socket socket, BytesSegment segment) throws IOException {
        this(socket, segment, true);
    }

    /**
     * @param ownsSocket whether disposing this stream (and reaching EOF in directTo) may shut down the socket
     */
    public NetworkStream(Socket socket, BytesSegment segment, boolean ownsSocket) throws IOException {
        super(socket.getInputStream(), socket.getOutputStream());
        this.socket = socket;
        this.segment = segment;
        this.ownsSocket = ownsSocket;
    }

    @Override
    protected void freeObjects() {
        try {
            log.info("NetworkStream freeObjects ownsSocket={} socket[{}][closed={}]", ownsSocket,
                    Sockets.getId(socket, false), socket.isClosed());
            if (!ownsSocket) {
                return;
            }
            // super.freeObjects() is intentionally not called here.
            Sockets.close(socket, 1);
        } finally {
            segment.close();
        }
    }

    int readSegment() {
        return read(segment.array, segment.offset, segment.count);
    }

    void writeSegment(int count) {
        write(segment.array, segment.offset, count);
    }

    /**
     * Copies data from this stream to {@code to} until EOF, a write-side failure, or {@code onEach} returns false.
     *
     * @return the last read count (0 or -1 on EOF), {@link #CannotWrite} if {@code to} stopped accepting data,
     * or {@link #StreamEOF} if {@code onEach} stopped the transfer
     */
    public int directTo(NetworkStream to, DirectPredicate onEach) {
        checkNotClosed();
        require(to);

        int recv = StreamEOF;
        while (canRead()) {
            recv = readSegment();
            if (recv < StreamEOF) {
                break;
            }
            if (recv <= SocketEOF) {
                if (ownsSocket) {
                    log.debug("DirectTo read {} flag and shutdown send", recv);
                    shutdown(socket, 1);
                }
                break;
            }
            if (!to.canWrite()) {
                log.debug("DirectTo read {} bytes and can't write", recv);
                recv = CannotWrite;
                break;
            }
            to.write(segment.array, segment.offset, recv);
            boolean keepGoing = onEach == null || onEach.test(segment, recv);
            if (!keepGoing) {
                recv = StreamEOF;
                break;
            }
        }
        if (to.canWrite()) {
            to.flush();
        }
        return recv;
    }
}
package org.rx.socks;
import org.rx.core.SystemException;
import java.net.InetSocketAddress;
/**
 * {@link SystemException} raised for socket failures, carrying the endpoint involved.
 * <p>
 * NOTE(review): despite the name, callers pass either a local address (e.g. a listening or accepted socket) or the
 * remote address being connected to; the value may be null if the socket reported none.
 */
public class SocketException extends SystemException {
    private final InetSocketAddress localAddress;

    /**
     * @return the endpoint associated with this failure, as supplied by the thrower
     */
    public InetSocketAddress getLocalAddress() {
        return localAddress;
    }

    /**
     * @param localAddress endpoint associated with the failure
     * @param ex           underlying cause
     */
    public SocketException(InetSocketAddress localAddress, Exception ex) {
        super(ex);
        this.localAddress = localAddress;
    }

    /**
     * @param localAddress endpoint associated with the failure
     * @param msg          error description
     */
    public SocketException(InetSocketAddress localAddress, String msg) {
        super(msg);
        this.localAddress = localAddress;
    }
}
package org.rx.socks;
import java.io.IOException;
import java.net.*;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import org.rx.core.Strings;
import org.rx.core.SystemException;
import org.rx.core.WeakCache;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Properties;
import java.util.function.Function;
import static org.rx.core.Contract.require;
import static org.rx.core.Contract.values;
/**
 * Static helpers for plain sockets, Netty bootstrapping and JVM-wide HTTP proxy settings.
 */
public final class Sockets {
    public static final InetAddress LocalAddress, AnyAddress;

    static {
        LocalAddress = InetAddress.getLoopbackAddress();
        try {
            AnyAddress = InetAddress.getByName("0.0.0.0");
        } catch (Exception ex) {
            throw SystemException.wrap(ex);
        }
    }

    /**
     * Resolves all addresses of {@code host}, memoized through {@link WeakCache}.
     * NOTE(review): this is an instance method on an otherwise static utility class; it was probably meant to be
     * static, but changing it would break binary compatibility for existing callers.
     */
    public InetAddress[] getAddresses(String host) {
        return (InetAddress[]) WeakCache.getOrStore("Sockets.getAddresses", values(host), p -> {
            try {
                return InetAddress.getAllByName(host);
            } catch (UnknownHostException ex) {
                throw SystemException.wrap(ex);
            }
        });
    }

    public static InetSocketAddress getAnyEndpoint(int port) {
        return new InetSocketAddress(AnyAddress, port);
    }

    /**
     * Parses {@code "host:port"} into a socket address (the host is resolved by the InetSocketAddress constructor).
     */
    public static InetSocketAddress parseEndpoint(String endpoint) {
        require(endpoint);

        String[] arr = Strings.split(endpoint, ":", 2);
        return new InetSocketAddress(arr[0], Integer.parseInt(arr[1]));
    }

    /**
     * Writes all {@code packs} and flushes once, on the channel's event loop.
     */
    public static void writeAndFlush(Channel channel, Object... packs) {
        require(channel);

        channel.eventLoop().execute(() -> {
            for (Object pack : packs) {
                channel.write(pack);
            }
            channel.flush();
        });
    }

    public static EventLoopGroup bossEventLoop() {
        return eventLoopGroup(1);
    }

    public static EventLoopGroup workEventLoop() {
        return eventLoopGroup(0);
    }

    /**
     * @param threadAmount number of event-loop threads; 0 lets Netty pick its default
     */
    public static EventLoopGroup eventLoopGroup(int threadAmount) {
        return Epoll.isAvailable() ? new EpollEventLoopGroup(threadAmount) : new NioEventLoopGroup(threadAmount); //NioEventLoopGroup(0, TaskFactory.getExecutor());
    }

    public static Bootstrap bootstrap() {
        return bootstrap(getChannelClass());
    }

    /**
     * Creates a client bootstrap whose event loop type matches {@code channelClass}.
     */
    public static Bootstrap bootstrap(Class<? extends Channel> channelClass) {
        require(channelClass);

        // getName() is fully qualified (e.g. "io.netty.channel.epoll.EpollSocketChannel") and never starts with
        // "Epoll", so match on the simple name; epoll channels must be paired with an epoll event loop.
        EventLoopGroup group = channelClass.getSimpleName().startsWith("Epoll")
                ? new EpollEventLoopGroup()
                : new NioEventLoopGroup();
        return new Bootstrap().group(group).channel(channelClass);
    }

    /**
     * Creates a bootstrap that shares {@code channel}'s event loop and channel type.
     */
    public static Bootstrap bootstrap(Channel channel) {
        require(channel);

        return new Bootstrap().group(channel.eventLoop()).channel(channel.getClass());
    }

    public static Class<? extends ServerChannel> getServerChannelClass() {
        return Epoll.isAvailable() ? EpollServerSocketChannel.class : NioServerSocketChannel.class;
    }

    public static Class<? extends Channel> getChannelClass() {
        return Epoll.isAvailable() ? EpollSocketChannel.class : NioSocketChannel.class;
    }

    /**
     * Closes an active channel after all pending writes have been flushed.
     */
    public static void closeOnFlushed(Channel channel) {
        require(channel);
        if (!channel.isActive()) {
            return;
        }

        channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
    }

    /**
     * Shuts down both directions and closes the socket.
     */
    public static void close(Socket socket) {
        close(socket, 1 | 2);
    }

    /**
     * Shuts down the directions selected by {@code flags}, then closes the socket. The socket is closed even if the
     * shutdown fails; that failure is still reported afterwards.
     *
     * @param flags Send=1, Receive=2
     */
    public static void close(Socket socket, int flags) {
        require(socket);
        if (socket.isClosed()) {
            return;
        }

        try {
            shutdown(socket, flags);
        } finally {
            lingerAndClose(socket);
        }
    }

    /**
     * Sets a 2-second linger and closes the socket; the close runs even if setting the linger fails.
     */
    private static void lingerAndClose(Socket socket) {
        try {
            try {
                socket.setSoLinger(true, 2);
            } finally {
                socket.close();
            }
        } catch (IOException ex) {
            throw SystemException.wrap(ex);
        }
    }

    /**
     * Half-closes a connected socket; no-op for closed or unconnected sockets.
     *
     * @param socket
     * @param flags Send=1, Receive=2
     */
    public static void shutdown(Socket socket, int flags) {
        require(socket);
        if (!socket.isClosed() && socket.isConnected()) {
            try {
                if ((flags & 1) == 1 && !socket.isOutputShutdown()) {
                    socket.shutdownOutput();
                }
                if ((flags & 2) == 2 && !socket.isInputShutdown()) {
                    socket.shutdownInput();
                }
            } catch (IOException ex) {
                throw SystemException.wrap(ex);
            }
        }
    }

    /**
     * Formats the local or remote endpoint as {@code host:port} for logging.
     *
     * @return the endpoint, or {@code "null"} when the socket has no such address (unbound / never connected)
     */
    public static String getId(Socket sock, boolean isRemote) {
        require(sock);

        InetSocketAddress addr = (InetSocketAddress) (isRemote ? sock.getRemoteSocketAddress()
                : sock.getLocalSocketAddress());
        if (addr == null) {
            return "null";
        }
        return addr.getHostString() + ":" + addr.getPort();
    }

    /**
     * Runs {@code func} with the JVM-wide HTTP(S) proxy set to {@code proxyAddr}, clearing it afterwards.
     * The proxy is process-global, so concurrent callers affect each other.
     */
    public static <T> T httpProxyInvoke(String proxyAddr, Function<String, T> func) {
        setHttpProxy(proxyAddr);
        try {
            return func.apply(proxyAddr);
        } finally {
            clearHttpProxy();
        }
    }

    public static void setHttpProxy(String proxyAddr) {
        setHttpProxy(proxyAddr, null, null, null);
    }

    /**
     * Sets the JVM-wide HTTP and HTTPS proxy system properties.
     *
     * @param proxyAddr     proxy endpoint as {@code host:port}
     * @param nonProxyHosts hosts that bypass the proxy, e.g. {@code "localhost"}, {@code "192.168.0.*"}; may be null
     * @param userName      proxy user; credentials are installed only when both user and password are non-null
     * @param password      proxy password
     */
    public static void setHttpProxy(String proxyAddr, List<String> nonProxyHosts, String userName, String password) {
        InetSocketAddress ipe = parseEndpoint(proxyAddr);
        // An unresolvable host leaves getAddress() null; fall back to the host name instead of failing with an NPE.
        String host = ipe.getAddress() != null ? ipe.getAddress().getHostAddress() : ipe.getHostString();
        String port = String.valueOf(ipe.getPort());
        Properties prop = System.getProperties();
        prop.setProperty("http.proxyHost", host);
        prop.setProperty("http.proxyPort", port);
        prop.setProperty("https.proxyHost", host);
        prop.setProperty("https.proxyPort", port);
        if (!CollectionUtils.isEmpty(nonProxyHosts)) {
            // e.g. "localhost|192.168.0.*"
            prop.setProperty("http.nonProxyHosts", String.join("|", nonProxyHosts));
        }
        if (userName != null && password != null) {
            Authenticator.setDefault(new UserAuthenticator(userName, password));
        }
    }

    /**
     * Removes the proxy system properties set by {@link #setHttpProxy}. The default Authenticator is left as is.
     */
    public static void clearHttpProxy() {
        System.clearProperty("http.proxyHost");
        System.clearProperty("http.proxyPort");
        System.clearProperty("https.proxyHost");
        System.clearProperty("https.proxyPort");
        System.clearProperty("http.nonProxyHosts");
    }

    /**
     * Supplies fixed credentials to any authentication request.
     */
    static class UserAuthenticator extends Authenticator {
        private final String userName;
        private final String password;

        public UserAuthenticator(String userName, String password) {
            this.userName = userName;
            this.password = password;
        }

        @Override
        protected PasswordAuthentication getPasswordAuthentication() {
            return new PasswordAuthentication(userName, password.toCharArray());
        }
    }
}