Skip to content

Commit 6fae7b8

Browse files
authored
SPARKNLP-873 Handling vocabulary type from Python side (#13908)
1 parent 95217a7 commit 6fae7b8

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala

+10-2
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,8 @@ import org.apache.spark.ml.param.{IntArrayParam, IntParam, Param, StringArrayPar
3737
import org.apache.spark.ml.util.Identifiable
3838
import org.apache.spark.sql.SparkSession
3939

40+
import scala.jdk.CollectionConverters.asScalaBufferConverter
41+
4042
/** MarianTransformer: Fast Neural Machine Translation
4143
*
4244
* Marian is an efficient, free Neural Machine Translation framework written in pure C++ with
@@ -317,6 +319,13 @@ class MarianTransformer(override val uid: String)
317319
/** @group setParam * */
318320
def getModelIfNotSet: Marian = _model.get.value
319321

322+
/** Returns the value of the `vocabulary` param as an `Array[String]`.
  *
  * If the stored value is a `java.util.ArrayList` at runtime, it is converted to a Scala array;
  * otherwise the param value is returned unchanged. Presumably the `ArrayList` case occurs when
  * the param is set from the Python side (per the commit title, SPARKNLP-873) -- confirm
  * against the Python wrapper.
  *
  * NOTE(review): generic type arguments are erased at runtime, so the `isInstanceOf` test only
  * checks for `java.util.ArrayList`; the `String` element type is not verified.
  *
  * @return
  *   the vocabulary entries as an array of strings
  */
def getVocabulary: Array[String] = {
323+
if ($(vocabulary).isInstanceOf[java.util.ArrayList[String]]) {
324+
val arrayListValue = $(vocabulary).asInstanceOf[java.util.ArrayList[String]]
325+
arrayListValue.asScala.toArray
326+
} else $(vocabulary)
327+
}
328+
320329
setDefault(
321330
maxInputLength -> 40,
322331
maxOutputLength -> 40,
@@ -349,7 +358,7 @@ class MarianTransformer(override val uid: String)
349358
sentences = allAnnotations.map(_._1),
350359
maxInputLength = $(maxInputLength),
351360
maxOutputLength = $(maxOutputLength),
352-
vocabs = $(vocabulary),
361+
vocabs = getVocabulary,
353362
langId = $(langId),
354363
batchSize = $(batchSize),
355364
ignoreTokenIds = $(ignoreTokenIds))
@@ -441,7 +450,6 @@ trait ReadMarianMTDLModel extends ReadTensorflowModel with ReadSentencePieceMode
441450
addReader(readModel)
442451

443452
def loadSavedModel(modelPath: String, spark: SparkSession): MarianTransformer = {
444-
445453
val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)
446454

447455
val sppSrc = loadSentencePieceAsset(localModelPath, "source.spm")

0 commit comments

Comments
 (0)