Skip to content

Commit 4944d54

Browse files
authored
[apache#1751][0.9] improvement: support gluten (apache#1753)
* support gluten * optimize * fix bug * nit * fix spotless * nit * nit * fix bug * optimize * optimize * nit * nit * nit * nit * nit * Update RssShuffleWriter.java
1 parent a6a715f commit 4944d54

File tree

4 files changed

+44
-42
lines changed

4 files changed

+44
-42
lines changed

client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java

+13-11
Original file line numberDiff line numberDiff line change
@@ -475,15 +475,6 @@ public <K, V> ShuffleWriter<K, V> getWriter(
475475

476476
int shuffleId = rssHandle.getShuffleId();
477477
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
478-
ShuffleHandleInfo shuffleHandleInfo;
479-
if (shuffleManagerRpcServiceEnabled) {
480-
// Get the ShuffleServer list from the Driver based on the shuffleId
481-
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
482-
} else {
483-
shuffleHandleInfo =
484-
new ShuffleHandleInfo(
485-
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
486-
}
487478
ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
488479
return new RssShuffleWriter<>(
489480
rssHandle.getAppId(),
@@ -496,8 +487,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
496487
shuffleWriteClient,
497488
rssHandle,
498489
this::markFailedTask,
499-
context,
500-
shuffleHandleInfo);
490+
context);
501491
} else {
502492
throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
503493
}
@@ -806,6 +796,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
806796
.createShuffleManagerClient(ClientType.GRPC, host, port);
807797
}
808798

799+
public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
800+
if (shuffleManagerRpcServiceEnabled) {
801+
// Get the ShuffleServer list from the Driver based on the shuffleId
802+
return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
803+
} else {
804+
return new ShuffleHandleInfo(
805+
rssHandle.getShuffleId(),
806+
rssHandle.getPartitionToServers(),
807+
rssHandle.getRemoteStorage());
808+
}
809+
}
810+
809811
/**
810812
* Get the ShuffleServer list from the Driver based on the shuffleId
811813
*

client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ public RssShuffleWriter(
188188
ShuffleWriteClient shuffleWriteClient,
189189
RssShuffleHandle<K, V, C> rssHandle,
190190
Function<String, Boolean> taskFailureCallback,
191-
TaskContext context,
192-
ShuffleHandleInfo shuffleHandleInfo) {
191+
TaskContext context) {
193192
this(
194193
appId,
195194
shuffleId,
@@ -201,9 +200,10 @@ public RssShuffleWriter(
201200
shuffleWriteClient,
202201
rssHandle,
203202
taskFailureCallback,
204-
shuffleHandleInfo,
203+
shuffleManager.getShuffleHandleInfo(rssHandle),
205204
context);
206205
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
206+
ShuffleHandleInfo shuffleHandleInfo = shuffleManager.getShuffleHandleInfo(rssHandle);
207207
final WriteBufferManager bufferManager =
208208
new WriteBufferManager(
209209
shuffleId,

client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java

+14-23
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
141141
private boolean rssResubmitStage;
142142

143143
private boolean taskBlockSendFailureRetryEnabled;
144-
145144
private boolean shuffleManagerRpcServiceEnabled;
146145
/** A list of shuffleServer for Write failures */
147146
private Set<String> failuresShuffleServerIds;
@@ -514,15 +513,6 @@ public <K, V> ShuffleWriter<K, V> getWriter(
514513
} else {
515514
writeMetrics = context.taskMetrics().shuffleWriteMetrics();
516515
}
517-
ShuffleHandleInfo shuffleHandleInfo;
518-
if (shuffleManagerRpcServiceEnabled) {
519-
// Get the ShuffleServer list from the Driver based on the shuffleId
520-
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
521-
} else {
522-
shuffleHandleInfo =
523-
new ShuffleHandleInfo(
524-
shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
525-
}
526516
String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
527517
LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), rssHandle.getShuffleId());
528518
return new RssShuffleWriter<>(
@@ -536,8 +526,7 @@ public <K, V> ShuffleWriter<K, V> getWriter(
536526
shuffleWriteClient,
537527
rssHandle,
538528
this::markFailedTask,
539-
context,
540-
shuffleHandleInfo);
529+
context);
541530
}
542531

543532
@Override
@@ -656,17 +645,7 @@ public <K, C> ShuffleReader<K, C> getReaderImpl(
656645
RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
657646
final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
658647
int shuffleId = rssShuffleHandle.getShuffleId();
659-
ShuffleHandleInfo shuffleHandleInfo;
660-
if (shuffleManagerRpcServiceEnabled) {
661-
// Get the ShuffleServer list from the Driver based on the shuffleId
662-
shuffleHandleInfo = getRemoteShuffleHandleInfo(shuffleId);
663-
} else {
664-
shuffleHandleInfo =
665-
new ShuffleHandleInfo(
666-
shuffleId,
667-
rssShuffleHandle.getPartitionToServers(),
668-
rssShuffleHandle.getRemoteStorage());
669-
}
648+
ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
670649
Map<Integer, List<ShuffleServerInfo>> allPartitionToServers =
671650
shuffleHandleInfo.getPartitionToServers();
672651
Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers =
@@ -1101,6 +1080,18 @@ private ShuffleManagerClient createShuffleManagerClient(String host, int port) {
11011080
.createShuffleManagerClient(ClientType.GRPC, host, port);
11021081
}
11031082

1083+
public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
1084+
if (shuffleManagerRpcServiceEnabled) {
1085+
// Get the ShuffleServer list from the Driver based on the shuffleId
1086+
return getRemoteShuffleHandleInfo(rssHandle.getShuffleId());
1087+
} else {
1088+
return new ShuffleHandleInfo(
1089+
rssHandle.getShuffleId(),
1090+
rssHandle.getPartitionToServers(),
1091+
rssHandle.getRemoteStorage());
1092+
}
1093+
}
1094+
11041095
/**
11051096
* Get the ShuffleServer list from the Driver based on the shuffleId
11061097
*

client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java

+14-5
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
9595

9696
private final String appId;
9797
private final int shuffleId;
98+
private final ShuffleHandleInfo shuffleHandleInfo;
9899
private WriteBufferManager bufferManager;
99100
private String taskId;
100101
private final int numMaps;
@@ -110,7 +111,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
110111
private final ShuffleWriteClient shuffleWriteClient;
111112
private final Set<ShuffleServerInfo> shuffleServersForData;
112113
private final long[] partitionLengths;
113-
private final boolean isMemoryShuffleEnabled;
114+
// Gluten needs this variable
115+
protected final boolean isMemoryShuffleEnabled;
114116
private final Function<String, Boolean> taskFailureCallback;
115117
private final Set<Long> blockIds = Sets.newConcurrentHashSet();
116118
private TaskContext taskContext;
@@ -195,6 +197,7 @@ private RssShuffleWriter(
195197
this.isMemoryShuffleEnabled =
196198
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
197199
this.taskFailureCallback = taskFailureCallback;
200+
this.shuffleHandleInfo = shuffleHandleInfo;
198201
this.taskContext = context;
199202
this.sparkConf = sparkConf;
200203
this.blockFailSentRetryEnabled =
@@ -204,6 +207,7 @@ private RssShuffleWriter(
204207
RssClientConf.RSS_CLIENT_BLOCK_SEND_FAILURE_RETRY_ENABLED.defaultValue());
205208
}
206209

210+
// Gluten needs this constructor
207211
public RssShuffleWriter(
208212
String appId,
209213
int shuffleId,
@@ -215,8 +219,7 @@ public RssShuffleWriter(
215219
ShuffleWriteClient shuffleWriteClient,
216220
RssShuffleHandle<K, V, C> rssHandle,
217221
Function<String, Boolean> taskFailureCallback,
218-
TaskContext context,
219-
ShuffleHandleInfo shuffleHandleInfo) {
222+
TaskContext context) {
220223
this(
221224
appId,
222225
shuffleId,
@@ -228,7 +231,7 @@ public RssShuffleWriter(
228231
shuffleWriteClient,
229232
rssHandle,
230233
taskFailureCallback,
231-
shuffleHandleInfo,
234+
shuffleManager.getShuffleHandleInfo(rssHandle),
232235
context);
233236
BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
234237
final WriteBufferManager bufferManager =
@@ -264,7 +267,8 @@ public void write(Iterator<Product2<K, V>> records) {
264267
}
265268
}
266269

267-
private void writeImpl(Iterator<Product2<K, V>> records) {
270+
// Gluten needs this method.
271+
protected void writeImpl(Iterator<Product2<K, V>> records) {
268272
List<ShuffleBlockInfo> shuffleBlockInfos;
269273
boolean isCombine = shuffleDependency.mapSideCombine();
270274
Function1<V, C> createCombiner = null;
@@ -322,6 +326,11 @@ private void writeImpl(Iterator<Product2<K, V>> records) {
322326
+ bufferManager.getManagerCostInfo());
323327
}
324328

329+
// Gluten needs this method
330+
protected void internalCheckBlockSendResult() {
331+
this.checkBlockSendResult(this.blockIds);
332+
}
333+
325334
private void checkSentRecordCount(long recordCount) {
326335
if (recordCount != bufferManager.getRecordCount()) {
327336
String errorMsg =

0 commit comments

Comments
 (0)