Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use ShardingSphereMetaDataIdentifier on ShardingSphereMetaData.databases #33945

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
import org.apache.shardingsphere.broadcast.constant.BroadcastOrder;
import org.apache.shardingsphere.broadcast.rule.attribute.BroadcastDataNodeRuleAttribute;
import org.apache.shardingsphere.broadcast.rule.attribute.BroadcastTableNamesRuleAttribute;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.rule.attribute.RuleAttributes;
import org.apache.shardingsphere.infra.rule.attribute.datasource.DataSourceMapperRuleAttribute;
import org.apache.shardingsphere.infra.rule.scope.DatabaseRule;

import javax.sql.DataSource;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
Expand All @@ -57,7 +57,7 @@ public BroadcastRule(final BroadcastRuleConfiguration config, final Map<String,
}

private Collection<String> getAggregatedDataSourceNames(final Map<String, DataSource> dataSources, final Collection<ShardingSphereRule> builtRules) {
Collection<String> result = new LinkedList<>(dataSources.keySet());
Collection<String> result = new CaseInsensitiveSet<>(dataSources.keySet());
for (ShardingSphereRule each : builtRules) {
Optional<DataSourceMapperRuleAttribute> ruleAttribute = each.getAttributes().findAttribute(DataSourceMapperRuleAttribute.class);
if (ruleAttribute.isPresent()) {
Expand All @@ -68,7 +68,7 @@ private Collection<String> getAggregatedDataSourceNames(final Map<String, DataSo
}

private Collection<String> getAggregatedDataSourceNames(final Collection<String> dataSourceNames, final DataSourceMapperRuleAttribute ruleAttribute) {
Collection<String> result = new LinkedList<>();
Collection<String> result = new CaseInsensitiveSet<>();
for (Entry<String, Collection<String>> entry : ruleAttribute.getDataSourceMapper().entrySet()) {
for (String each : entry.getValue()) {
if (dataSourceNames.contains(each)) {
Expand All @@ -89,6 +89,7 @@ private Collection<String> getAggregatedDataSourceNames(final Collection<String>
* @param logicTableNames all logic table names
* @return broadcast table names
*/
@HighFrequencyInvocation
public Collection<String> getBroadcastTableNames(final Collection<String> logicTableNames) {
Collection<String> result = new CaseInsensitiveSet<>();
for (String each : logicTableNames) {
Expand All @@ -99,16 +100,6 @@ public Collection<String> getBroadcastTableNames(final Collection<String> logicT
return result;
}

/**
* Judge whether logic tables are all broadcast tables.
*
* @param logicTableNames logic table names
* @return logic tables are all broadcast tables or not
*/
public boolean isAllBroadcastTables(final Collection<String> logicTableNames) {
return !logicTableNames.isEmpty() && tables.containsAll(logicTableNames);
}

@Override
public int getOrder() {
return BroadcastOrder.ORDER;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class BroadcastRouteEngineFactoryTest {
@BeforeEach
void setUp() {
when(rule.getBroadcastTableNames(Collections.singleton("foo_tbl"))).thenReturn(Collections.singleton("foo_tbl"));
when(rule.isAllBroadcastTables(Collections.singleton("foo_tbl"))).thenReturn(true);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand All @@ -45,7 +43,7 @@ class BroadcastRuleTest {
void assertGetDataSourceNames() {
BroadcastRule rule = new BroadcastRule(
new BroadcastRuleConfiguration(Collections.emptyList()), mockDataSourceMap(), Arrays.asList(mockBuiltRule(), mock(ShardingSphereRule.class, RETURNS_DEEP_STUBS)));
assertThat(rule.getDataSourceNames(), is(Collections.singletonList("foo_ds")));
assertThat(rule.getDataSourceNames(), is(Collections.singleton("foo_ds")));
}

private static Map<String, DataSource> mockDataSourceMap() {
Expand All @@ -71,20 +69,4 @@ void assertGetBroadcastTableNames() {
BroadcastRule rule = new BroadcastRule(new BroadcastRuleConfiguration(Collections.singleton("foo_tbl")), Collections.emptyMap(), Collections.emptyList());
assertThat(rule.getBroadcastTableNames(Arrays.asList("foo_tbl", "bar_tbl")), is(Collections.singleton("foo_tbl")));
}

@Test
void assertIsAllBroadcastTables() {
BroadcastRule rule = new BroadcastRule(new BroadcastRuleConfiguration(Collections.singleton("foo_tbl")), Collections.emptyMap(), Collections.emptyList());
assertFalse(rule.isAllBroadcastTables(Collections.emptyList()));
assertTrue(rule.isAllBroadcastTables(Collections.singleton("foo_tbl")));
assertFalse(rule.isAllBroadcastTables(Arrays.asList("foo_tbl", "bar_tbl")));
}

@Test
void assertIsAllBroadcastTablesWhenEmptyRule() {
BroadcastRule rule = new BroadcastRule(new BroadcastRuleConfiguration(Collections.emptyList()), Collections.emptyMap(), Collections.emptyList());
assertFalse(rule.isAllBroadcastTables(Collections.emptyList()));
assertFalse(rule.isAllBroadcastTables(Collections.singleton("foo_tbl")));
assertFalse(rule.isAllBroadcastTables(Arrays.asList("foo_tbl", "bar_tbl")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.shardingsphere.infra.metadata;

import com.cedarsoftware.util.CaseInsensitiveMap;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.SneakyThrows;
Expand All @@ -28,6 +27,7 @@
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.metadata.database.resource.ResourceMetaData;
import org.apache.shardingsphere.infra.metadata.database.rule.RuleMetaData;
import org.apache.shardingsphere.infra.metadata.identifier.ShardingSphereMetaDataIdentifier;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.rule.attribute.datasource.StaticDataSourceRuleAttribute;
import org.apache.shardingsphere.infra.rule.scope.GlobalRule;
Expand All @@ -48,7 +48,7 @@
public final class ShardingSphereMetaData {

@Getter(AccessLevel.NONE)
private final Map<String, ShardingSphereDatabase> databases;
private final Map<ShardingSphereMetaDataIdentifier, ShardingSphereDatabase> databases;

private final ResourceMetaData globalResourceMetaData;

Expand All @@ -64,7 +64,7 @@ public ShardingSphereMetaData() {

public ShardingSphereMetaData(final Collection<ShardingSphereDatabase> databases, final ResourceMetaData globalResourceMetaData,
final RuleMetaData globalRuleMetaData, final ConfigurationProperties props) {
this.databases = new CaseInsensitiveMap<>(databases.stream().collect(Collectors.toMap(ShardingSphereDatabase::getName, each -> each)), new ConcurrentHashMap<>());
this.databases = new ConcurrentHashMap<>(databases.stream().collect(Collectors.toMap(each -> new ShardingSphereMetaDataIdentifier(each.getName()), each -> each)));
this.globalResourceMetaData = globalResourceMetaData;
this.globalRuleMetaData = globalRuleMetaData;
this.props = props;
Expand All @@ -87,7 +87,7 @@ public Collection<ShardingSphereDatabase> getAllDatabases() {
* @return contains database from meta data or not
*/
public boolean containsDatabase(final String databaseName) {
return databases.containsKey(databaseName);
return databases.containsKey(new ShardingSphereMetaDataIdentifier(databaseName));
}

/**
Expand All @@ -97,7 +97,7 @@ public boolean containsDatabase(final String databaseName) {
* @return meta data database
*/
public ShardingSphereDatabase getDatabase(final String databaseName) {
return databases.get(databaseName);
return databases.get(new ShardingSphereMetaDataIdentifier(databaseName));
}

/**
Expand All @@ -109,7 +109,7 @@ public ShardingSphereDatabase getDatabase(final String databaseName) {
*/
public void addDatabase(final String databaseName, final DatabaseType protocolType, final ConfigurationProperties props) {
ShardingSphereDatabase database = ShardingSphereDatabase.create(databaseName, protocolType, props);
databases.put(database.getName(), database);
databases.put(new ShardingSphereMetaDataIdentifier(database.getName()), database);
globalRuleMetaData.getRules().forEach(each -> ((GlobalRule) each).refresh(databases.values(), GlobalRuleChangedType.DATABASE_CHANGED));
}

Expand All @@ -119,7 +119,7 @@ public void addDatabase(final String databaseName, final DatabaseType protocolTy
* @param database database
*/
public void putDatabase(final ShardingSphereDatabase database) {
databases.put(database.getName(), database);
databases.put(new ShardingSphereMetaDataIdentifier(database.getName()), database);
}

/**
Expand All @@ -128,7 +128,7 @@ public void putDatabase(final ShardingSphereDatabase database) {
* @param databaseName database name
*/
public void dropDatabase(final String databaseName) {
cleanResources(databases.remove(databaseName));
cleanResources(databases.remove(new ShardingSphereMetaDataIdentifier(databaseName)));
}

@SneakyThrows(Exception.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ private String generateExportData(final ShardingSphereMetaData metaData) {
private Map<String, String> getDatabases(final ProxyContext proxyContext) {
Collection<String> databaseNames = proxyContext.getAllDatabaseNames();
Map<String, String> result = new LinkedHashMap<>(databaseNames.size(), 1F);
databaseNames.forEach(each -> {
for (String each : databaseNames) {
ShardingSphereDatabase database = proxyContext.getContextManager().getDatabase(each);
if (database.getResourceMetaData().getAllInstanceDataSourceNames().isEmpty()) {
return;
continue;
}
result.put(each, ExportUtils.generateExportDatabaseData(database));
});
}
return result;
}

Expand Down
Loading