@@ -37,6 +37,8 @@ import org.apache.spark.ml.param.{IntArrayParam, IntParam, Param, StringArrayPar
37
37
import org .apache .spark .ml .util .Identifiable
38
38
import org .apache .spark .sql .SparkSession
39
39
40
+ import scala .jdk .CollectionConverters .asScalaBufferConverter
41
+
40
42
/** MarianTransformer: Fast Neural Machine Translation
41
43
*
42
44
* Marian is an efficient, free Neural Machine Translation framework written in pure C++ with
@@ -317,6 +319,13 @@ class MarianTransformer(override val uid: String)
317
319
/** @group setParam * */
318
320
def getModelIfNotSet : Marian = _model.get.value
319
321
322
+ def getVocabulary : Array [String ] = {
323
+ if ($(vocabulary).isInstanceOf [java.util.ArrayList [String ]]) {
324
+ val arrayListValue = $(vocabulary).asInstanceOf [java.util.ArrayList [String ]]
325
+ arrayListValue.asScala.toArray
326
+ } else $(vocabulary)
327
+ }
328
+
320
329
setDefault(
321
330
maxInputLength -> 40 ,
322
331
maxOutputLength -> 40 ,
@@ -349,7 +358,7 @@ class MarianTransformer(override val uid: String)
349
358
sentences = allAnnotations.map(_._1),
350
359
maxInputLength = $(maxInputLength),
351
360
maxOutputLength = $(maxOutputLength),
352
- vocabs = $(vocabulary) ,
361
+ vocabs = getVocabulary ,
353
362
langId = $(langId),
354
363
batchSize = $(batchSize),
355
364
ignoreTokenIds = $(ignoreTokenIds))
@@ -441,7 +450,6 @@ trait ReadMarianMTDLModel extends ReadTensorflowModel with ReadSentencePieceMode
441
450
addReader(readModel)
442
451
443
452
def loadSavedModel (modelPath : String , spark : SparkSession ): MarianTransformer = {
444
-
445
453
val (localModelPath, detectedEngine) = modelSanityCheck(modelPath)
446
454
447
455
val sppSrc = loadSentencePieceAsset(localModelPath, " source.spm" )
0 commit comments