Skip to content

Commit 6f59e89

Browse files
authored
Android backend used by method
Differential Revision: D74913386 Pull Request resolved: #10934
1 parent f8218d1 commit 6f59e89

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/ModuleE2ETest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
package org.pytorch.executorch;
1010

11+
import static org.junit.Assert.assertArrayEquals;
1112
import static org.junit.Assert.assertEquals;
1213
import static org.junit.Assert.assertTrue;
1314
import static org.junit.Assert.assertFalse;
@@ -89,6 +90,18 @@ public void testClassification(String filePath) throws IOException, URISyntaxExc
8990
assertEquals(bananaClass, argmax(scores));
9091
}
9192

93+
@Test
94+
public void testXnnpackBackendRequired() throws IOException, URISyntaxException {
95+
File pteFile = new File(getTestFilePath("/mv3_xnnpack_fp32.pte"));
96+
InputStream inputStream = getClass().getResourceAsStream("/mv3_xnnpack_fp32.pte");
97+
FileUtils.copyInputStreamToFile(inputStream, pteFile);
98+
inputStream.close();
99+
100+
Module module = Module.load(getTestFilePath("/mv3_xnnpack_fp32.pte"));
101+
String[] expectedBackends = new String[] {"XnnpackBackend"};
102+
assertArrayEquals(expectedBackends, module.getUsedBackends("forward"));
103+
}
104+
92105
@Test
93106
public void testMv2Fp32() throws IOException, URISyntaxException {
94107
testClassification("/mv2_xnnpack_fp32.pte");

extension/android/executorch_android/src/main/java/org/pytorch/executorch/Module.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,16 @@ public int loadMethod(String methodName) {
137137
}
138138
}
139139

140+
/**
141+
* Returns the names of the methods in a certain method.
142+
*
143+
* @param methodName method name to query
144+
* @return an array of backend name
145+
*/
146+
public String[] getUsedBackends(String methodName) {
147+
return mNativePeer.getUsedBackends(methodName);
148+
}
149+
140150
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
141151
public String[] readLogBuffer() {
142152
return mNativePeer.readLogBuffer();

extension/android/executorch_android/src/main/java/org/pytorch/executorch/NativePeer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ public void resetNative() {
5555
@DoNotStrip
5656
public native int loadMethod(String methodName);
5757

58+
/** Return the list of backends used by a method */
59+
@DoNotStrip
60+
public native String[] getUsedBackends(String methodName);
61+
5862
/** Retrieve the in-memory log buffer, containing the most recent ExecuTorch log entries. */
5963
@DoNotStrip
6064
public native String[] readLogBuffer();

extension/android/jni/jni_layer.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <sstream>
1414
#include <string>
1515
#include <unordered_map>
16+
#include <unordered_set>
1617
#include <vector>
1718

1819
#include "jni_layer_constants.h"
@@ -395,13 +396,34 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
395396
#endif
396397
}
397398

399+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> getUsedBackends(
400+
facebook::jni::alias_ref<jstring> methodName) {
401+
auto methodMeta = module_->method_meta(methodName->toStdString()).get();
402+
std::unordered_set<std::string> backends;
403+
for (auto i = 0; i < methodMeta.num_backends(); i++) {
404+
backends.insert(methodMeta.get_backend_name(i).get());
405+
}
406+
407+
facebook::jni::local_ref<facebook::jni::JArrayClass<jstring>> ret =
408+
facebook::jni::JArrayClass<jstring>::newArray(backends.size());
409+
int i = 0;
410+
for (auto s : backends) {
411+
facebook::jni::local_ref<facebook::jni::JString> backend_name =
412+
facebook::jni::make_jstring(s.c_str());
413+
(*ret)[i] = backend_name;
414+
i++;
415+
}
416+
return ret;
417+
}
418+
398419
static void registerNatives() {
399420
registerHybrid({
400421
makeNativeMethod("initHybrid", ExecuTorchJni::initHybrid),
401422
makeNativeMethod("forward", ExecuTorchJni::forward),
402423
makeNativeMethod("execute", ExecuTorchJni::execute),
403424
makeNativeMethod("loadMethod", ExecuTorchJni::load_method),
404425
makeNativeMethod("readLogBuffer", ExecuTorchJni::readLogBuffer),
426+
makeNativeMethod("getUsedBackends", ExecuTorchJni::getUsedBackends),
405427
});
406428
}
407429
};

0 commit comments

Comments
 (0)