diff --git a/framework/src/main/java/org/tron/core/net/service/relay/RelayService.java b/framework/src/main/java/org/tron/core/net/service/relay/RelayService.java index dfc5f2e89da..107d086c1b7 100644 --- a/framework/src/main/java/org/tron/core/net/service/relay/RelayService.java +++ b/framework/src/main/java/org/tron/core/net/service/relay/RelayService.java @@ -40,6 +40,8 @@ @Component public class RelayService { + private static final int MAX_PEER_COUNT_PER_ADDRESS = 2; + @Autowired private ChainBaseManager chainBaseManager; @@ -139,6 +141,13 @@ public boolean checkHelloMessage(HelloMessage message, Channel channel) { return false; } + if (getPeerCountByAddress(msg.getAddress()) >= MAX_PEER_COUNT_PER_ADDRESS) { + logger.warn("HelloMessage from {}, the number of peers of {} exceeds 2.", + channel.getInetAddress(), + ByteArray.toHexString(msg.getAddress().toByteArray())); + return false; + } + boolean flag; try { Sha256Hash hash = Sha256Hash.of(CommonParameter @@ -164,6 +173,12 @@ public boolean checkHelloMessage(HelloMessage message, Channel channel) { } } + private long getPeerCountByAddress(ByteString address) { + return tronNetDelegate.getActivePeer().stream() + .filter(peer -> peer.getAddress() != null && peer.getAddress().equals(address)) + .count(); + } + private boolean isActiveWitness() { return parameter.isWitness() && keySize > 0 diff --git a/framework/src/test/java/org/tron/core/net/services/RelayServiceTest.java b/framework/src/test/java/org/tron/core/net/services/RelayServiceTest.java index 777472bdc35..5e22e538e80 100644 --- a/framework/src/test/java/org/tron/core/net/services/RelayServiceTest.java +++ b/framework/src/test/java/org/tron/core/net/services/RelayServiceTest.java @@ -1,31 +1,45 @@ package org.tron.core.net.services; +import static org.mockito.Mockito.mock; + import com.google.common.collect.Lists; import com.google.protobuf.ByteString; +import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Set; import javax.annotation.Resource; + +import lombok.extern.slf4j.Slf4j; import org.bouncycastle.util.encoders.Hex; import org.junit.Assert; -import org.junit.Before; import org.junit.BeforeClass; import org.junit.Test; +import org.mockito.Mockito; +import org.springframework.context.ApplicationContext; import org.tron.common.BaseTest; import org.tron.common.utils.ReflectUtils; +import org.tron.core.ChainBaseManager; import org.tron.core.Constant; import org.tron.core.capsule.BlockCapsule; import org.tron.core.capsule.WitnessCapsule; import org.tron.core.config.args.Args; import org.tron.core.net.P2pEventHandlerImpl; import org.tron.core.net.message.adv.BlockMessage; +import org.tron.core.net.message.handshake.HelloMessage; import org.tron.core.net.peer.Item; import org.tron.core.net.peer.PeerConnection; +import org.tron.core.net.peer.PeerManager; import org.tron.core.net.service.relay.RelayService; +import org.tron.p2p.connection.Channel; +import org.tron.p2p.discover.Node; +import org.tron.p2p.utils.NetUtil; import org.tron.protos.Protocol; +@Slf4j(topic = "net") public class RelayServiceTest extends BaseTest { @Resource @@ -49,6 +63,7 @@ public void test() throws Exception { initWitness(); testGetNextWitnesses(); testBroadcast(); + testCheckHelloMessage(); } private void initWitness() { @@ -119,4 +134,38 @@ private ByteString getFromHexString(String s) { return ByteString.copyFrom(Hex.decode(s)); } -} + private void testCheckHelloMessage() { + ByteString address = getFromHexString("A04711BF7AFBDF44557DEFBDF4C4E7AA6138C6331F"); + InetSocketAddress a1 = new InetSocketAddress("127.0.0.1", 10001); + Node node = new Node(NetUtil.getNodeId(), a1.getAddress().getHostAddress(), + null, a1.getPort()); + HelloMessage helloMessage = new HelloMessage(node, System.currentTimeMillis(), + ChainBaseManager.getChainBaseManager()); + helloMessage.setHelloMessage(helloMessage.getHelloMessage().toBuilder() + .setAddress(address).build()); + Channel c1 = mock(Channel.class); + Mockito.when(c1.getInetSocketAddress()).thenReturn(a1); + Mockito.when(c1.getInetAddress()).thenReturn(a1.getAddress()); + Channel c2 = mock(Channel.class); + Mockito.when(c2.getInetSocketAddress()).thenReturn(a1); + Mockito.when(c2.getInetAddress()).thenReturn(a1.getAddress()); + Args.getInstance().fastForward = true; + ApplicationContext ctx = (ApplicationContext) ReflectUtils.getFieldObject(p2pEventHandler, + "ctx"); + PeerConnection peer1 = PeerManager.add(ctx, c1); + assert peer1 != null; + peer1.setAddress(address); + PeerConnection peer2 = PeerManager.add(ctx, c2); + assert peer2 != null; + peer2.setAddress(address); + try { + Field field = service.getClass().getDeclaredField("witnessScheduleStore"); + field.setAccessible(true); + field.set(service, chainBaseManager.getWitnessScheduleStore()); + boolean res = service.checkHelloMessage(helloMessage, c1); + Assert.assertFalse(res); + } catch (Exception e) { + logger.info("{}", e.getMessage()); + } + } +} \ No newline at end of file