Hi,
I used Scala to restructure the attention layers on top of javacpp-pytorch 15.1-2.1, with Java 23 and Scala 3.6. I am hitting three problems:
1. All attention layers are unusable: running them exhausts memory and ends in a stack overflow.
2. FractionalMaxPool2d and FractionalMaxPool3d cannot be given correct values for output_size and output_ratio.
3. For AdaptiveMaxPool2d, AdaptiveMaxPool3d, AdaptiveAvgPool2d, and AdaptiveAvgPool3d, the last dimension of the output is always zero.
package com.hideep.example
import org.bytedeco.javacpp.{FloatPointer, PointerScope}
import org.bytedeco.pytorch.{OutputArchive, TensorExampleVectorIterator}
import torch.{Float32, *}
import torch.Device.{CPU, CUDA}
import torch.data.dataset.ChunkSharedBatchDataset
import torch.nn.functional as F
import torch.nn.modules.HasParams
import torch.nn.modules.attention.Transformer.TransformerActivation.kGELU
import torch.optim.Adam
import torchvision.datasets.FashionMNIST
import java.nio.file.Paths
//import scala.runtime.stdLibPatches.Predef.nn
import scala.util.{Random, Using}
import torch.internal.NativeConverters.{fromNative, toNative}
object testLayer:
  def transformerTest(): Unit = {
    val transformer = nn.Transformer(nhead = 16, num_encoder_layers = 12, activation = kGELU)
    val src = torch.rand(Seq(10, 32, 512))
    val tgt = torch.rand(Seq(20, 32, 512))
    val out = transformer(src, tgt)
    println(s"transformer out shape ${out.shape}")
    // val input = torch.randn(Seq(1, 64, 8, 9)) // unused leftover from a pooling test
  }
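  // A possible mitigation sketch for the memory blow-up, assuming the wrapper
  // allocates JavaCPP-managed native tensors: closing a PointerScope deallocates
  // every pointer registered within it, which bounds off-heap usage per run.
  def transformerTestScoped(): Unit =
    Using.resource(new PointerScope()) { _ =>
      transformerTest()
    }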
  def multiheadAttentionTest(): Unit = {
    // val m1 = nn.TransformerEncoderLayer(d_model = 8, n_head = 4)
    val batchSize = 2
    val seqLength = 10
    val embedDim = 64
    val numHeads = 8
    // throws: java.lang.RuntimeException: from is out of bounds for float
    val multiheadAttention = nn.MultiheadAttention(embed_dim = 64, num_heads = 8, dropout = 0.1, kdim = 8, bias = true, vdim = 8)
    val input = torch.randn(Seq(batchSize, seqLength, embedDim))
    val out = multiheadAttention(input, input, input)
    println(s"multiheadAttention attn_output ${out._1.shape} attn_weight ${out._2.shape}")
  }
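  // A hedged variant (shape assumption, following PyTorch semantics): kdim and
  // vdim give the key/value feature sizes, so with kdim = 8 and vdim = 8 the
  // key and value tensors need a last dimension of 8, not embedDim = 64.
  def multiheadAttentionKdimTest(): Unit = {
    val mha = nn.MultiheadAttention(embed_dim = 64, num_heads = 8, dropout = 0.1, kdim = 8, bias = true, vdim = 8)
    val query = torch.randn(Seq(2, 10, 64))
    val key = torch.randn(Seq(2, 10, 8))
    val value = torch.randn(Seq(2, 10, 8))
    val out = mha(query, key, value)
    println(s"attn_output ${out._1.shape} attn_weight ${out._2.shape}")
  }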
  def transformerEncoderLayerTest(): Unit = {
    val input = torch.randn(Seq(10, 32, 512))
    // assertEquals(m12(input).shape, Seq(1, 64, 5, 7)) // lacks parameters
    val layer = nn.TransformerEncoderLayer(d_model = 512, n_head = 8)
    // val encoder = nn.TransformerEncoder(encoder_layer = layer, num_layers = 6)
    val out = layer(input)
    println(s"out.shape ${out.shape}")
  }
  def TransformerEncoderLayerSuite2(): Unit = {
    val input = torch.randn(Seq(32, 10, 512))
    // assertEquals(m12(input).shape, Seq(1, 64, 5, 7)) // lacks parameters
    val layer = nn.TransformerEncoderLayer(d_model = 512, n_head = 8, batch_first = true)
    // val encoder = nn.TransformerEncoder(encoder_layer = layer, num_layers = 6)
    val out = layer(input)
    println(s"out.shape ${out.shape}")
  }
  def TransformerDecoderLayerSuite(): Unit = {
    val layer = nn.TransformerDecoderLayer(d_model = 512, n_head = 8)
    // val decoder = nn.TransformerDecoder(decoder_layer = layer, num_layers = 6)
    val memory = torch.randn(Seq(10, 32, 512))
    val tgt = torch.randn(Seq(20, 32, 512))
    val out = layer(tgt, memory)
    println(s"out.shape ${out.shape}")
  }
  def TransformerDecoderLayerSuite2(): Unit = {
    val layer = nn.TransformerDecoderLayer(d_model = 512, n_head = 8, batch_first = true)
    // val decoder = nn.TransformerDecoder(decoder_layer = layer, num_layers = 6)
    val memory = torch.randn(Seq(32, 10, 512))
    val tgt = torch.randn(Seq(32, 20, 512))
    val out = layer(tgt, memory)
    println(s"out.shape ${out.shape}")
  }
  def TransformerEncoderSuite(): Unit = {
    val input = torch.randn(Seq(10, 32, 512))
    // assertEquals(m12(input).shape, Seq(1, 64, 5, 7)) // lacks parameters
    val layer = nn.TransformerEncoderLayer(d_model = 512, n_head = 8)
    val encoder = nn.TransformerEncoder(encoder_layer = layer, num_layers = 6)
    val out = encoder(input)
    println(s"out.shape ${out.shape}")
  }
  def TransformerDecoderSuite(): Unit = {
    val layer = nn.TransformerDecoderLayer(d_model = 512, n_head = 8)
    val decoder = nn.TransformerDecoder(decoder_layer = layer, num_layers = 6)
    val memory = torch.randn(Seq(10, 32, 512))
    val tgt = torch.randn(Seq(20, 32, 512))
    val out = decoder(tgt, memory)
    println(s"out.shape ${out.shape}")
  }
  def transformerCoderSuite(): Unit = {
    val m1 = nn.TransformerEncoderLayer(d_model = 8, n_head = 4)
    val m12 = nn.TransformerEncoder(encoder_layer = m1, num_layers = 4)
    val input = torch.randn(Seq(1, 64, 8, 9))
    println(m12(input).shape)
    // assertEquals(m12(input).shape, Seq(1, 64, 5, 7)) // lacks parameters
    val m2 = nn.TransformerDecoderLayer(d_model = 8, n_head = 4)
    val m22 = nn.TransformerDecoder(decoder_layer = m2, num_layers = 4)
    println(m22(input).shape)
    // assertEquals(m22(input).shape, Seq(1, 64, 1, 1))
  }
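  // A hedged sketch (assumption, matching PyTorch semantics): TransformerEncoder
  // and TransformerDecoder expect 3-D (seq_len, batch, d_model) input, so the
  // 4-D Seq(1, 64, 8, 9) tensor above may itself trigger the failure; a shape
  // like Seq(10, 2, 8) matches d_model = 8.
  def transformerCoder3dInputSketch(): Unit = {
    val layer = nn.TransformerEncoderLayer(d_model = 8, n_head = 4)
    val encoder = nn.TransformerEncoder(encoder_layer = layer, num_layers = 4)
    val input3d = torch.randn(Seq(10, 2, 8))
    println(encoder(input3d).shape)
  }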
  def main(args: Array[String]): Unit =
    transformerTest()
    multiheadAttentionTest()
    // transformerEncoderLayerTest()
    // TransformerEncoderLayerSuite2()
    // TransformerDecoderLayerSuite()
    // transformerCoderSuite()
    // TransformerDecoderSuite()
    TransformerEncoderSuite()
    // val m13 = nn.FractionalMaxPool2d(kernel_size = (7, 7), output_size = Some(5, 7), output_ratio = Some(0.57f, 0.57f))
    // val input = torch.randn(Seq(1, 64, 8, 9))
    // m13(input.to(torch.float64)).shape
    // val m23 = nn.FractionalMaxPool3d(kernel_size = (4, 8, 1), output_size = Some(5, 6, 7), output_ratio = Some(0.4f, 0.34f, 0.57f))
    // assertEquals(m23(input.to(torch.float64)).shape, Seq(1, 64, 1, 1))
    // val model = new LstmNet()
    // val input = Tensor.rand(1, 1, 10)
    // val output = model.forward(input)
    // println(output)
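    // A hedged sketch (assumption, per PyTorch semantics): fractional pooling
    // accepts exactly one of output_size or output_ratio, so passing only one
    // of them may sidestep the option-marshalling failure logged below.
    // val m14 = nn.FractionalMaxPool2d(kernel_size = (7, 7), output_size = Some(5, 7))
    // println(m14(torch.randn(Seq(1, 64, 8, 9)).to(torch.float64)).shape)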
FractionalMaxPool2d and FractionalMaxPool3d cannot be given the right parameters for output_size and output_ratio.
FractionalMaxPool2d console log:
Testing started at 15:25 ...
output_size two elements full
output ratio two elements full
randomSamples is None, outputSize: Some((7,7)) outputRatio Some((0.57,5.0))
FractionalMaxPool2d raw options kernel 7 k2 7 outsize true 7454142788699316554 out2 7598807741461061480 outRatio true 5.0 ratio2 1.9108424296605356E214
java.lang.RuntimeException: FractionalMaxPool2d requires specifying either an output size, or a pooling ratio
FractionalMaxPool3d console log:
Testing started at 15:21 ...
output_size three elements full
output ratio three elements full
randomSamples is None outputSize Some((5,6,7)) outputRatio Some((0.4,0.34,0.57))
FractionalMaxPool3d raw options kernel 4 k2 8 k3 1 outsize true 25895968444448860 out2 28147905699774563 out3 25896174605893733 outRatio true 0.5699999928474426 ratio2 1.8691964927058795E-306 ratio3 1.1125994018147494E-306
java.lang.RuntimeException: FractionalMaxPool2d requires specifying either an output size, or a pooling ratio
Please try to narrow down where the crash occurs
I have now tried to set the option values in what should be the correct way, but it still does not work. I want to confirm that these 8 pooling layers really do have bugs:
First, no matter how the values are set, nothing works.
Second, the return type of the options' output_size needs to change; it should not be LongOptional. That is a real bug: when you inspect the inner values you can only ever get one of them, even after casting to a vector, yet for pool2d/pool3d we must set two or three values. Please change the return type.
// repro for AdaptiveMaxPool3d; needs, in addition to the imports above:
// import org.bytedeco.javacpp.LongPointer
// import org.bytedeco.pytorch.{AdaptiveMaxPool3dImpl, AdaptiveMaxPool3dOptions, LongOptional, LongOptionalVector}
val input = torch.randn(Seq(1, 64, 8, 9, 10)).native
val s1 = LongPointer(1).put(5)
val s2 = LongPointer(2).put(7)
val s3 = LongPointer(3).put(9)
// val vec = LongVectorOptional(LongVector(s1, s2, s3))
val nativeOutputSize = new LongOptionalVector(LongOptional(5), LongOptional(7), LongOptional(9))
// val nativeOutputSize = new LongOptionalVector(LongOptional(s1), LongOptional(s2), LongOptional(s3))
val options = AdaptiveMaxPool3dOptions(nativeOutputSize)
println(s"options out ${options.output_size().get()}")
options.output_size().put(nativeOutputSize)
val model = AdaptiveMaxPool3dImpl(options)
val output = fromNative(model.forward(input))
println(s"output.shape ${output.shape}")
Console log:
options out 1308676889232
Exception in thread "main" java.lang.RuntimeException: Storage size calculation overflowed with sizes=[1, 64, 1308676889232, 1308676889280, 10]
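For comparison, here is a hedged sketch of what I would expect to work at the wrapper level (an assumption: that nn.AdaptiveMaxPool3d takes the target output size directly, like the other pooling layers above):
val pool = nn.AdaptiveMaxPool3d(output_size = (5, 7, 9))
val out = pool(torch.randn(Seq(1, 64, 8, 9, 10)))
println(out.shape) // expected Seq(1, 64, 5, 7, 9)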
Please run the test code and you will see the fault, or could you show me correct, working code? Thanks. I also need correct, working code for FractionalMaxPool2d and FractionalMaxPool3d. These pooling layers are important for computer vision, so please take fixing them seriously. @saudet