diff --git a/nodes/ImageNode.py b/nodes/ImageNode.py
index 432d26b2..8b81f4c8 100644
--- a/nodes/ImageNode.py
+++ b/nodes/ImageNode.py
@@ -8,6 +8,7 @@
import base64,os,random
from io import BytesIO
import folder_paths
+import node_helpers
import json,io
import comfy.utils
from comfy.cli_args import args
@@ -491,6 +492,53 @@ def load_image(fp,white_bg=False):
    return images
+
+# Read the image data and convert it to tensors (image and mask)
+def load_image_to_tensor( image):
+    image_path = folder_paths.get_annotated_filepath(image)
+
+    img = node_helpers.pillow(Image.open, image_path)
+
+    output_images = []
+    output_masks = []
+    w, h = None, None
+
+    excluded_formats = ['MPO']
+
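+    # Iterate over every frame in the file (multi-frame formats such as GIF yield more than one)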
+    for i in ImageSequence.Iterator(img):
+        i = node_helpers.pillow(ImageOps.exif_transpose, i)
+
+        if i.mode == 'I':
+            i = i.point(lambda i: i * (1 / 255))
+        image = i.convert("RGB")
+
+        if len(output_images) == 0:
+            w = image.size[0]
+            h = image.size[1]
+
+        if image.size[0] != w or image.size[1] != h:
+            continue
+
+        image = np.array(image).astype(np.float32) / 255.0
+        image = torch.from_numpy(image)[None,]
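+        # Use the inverted alpha channel as the mask when present, otherwise an empty 64x64 placeholder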
+        if 'A' in i.getbands():
+            mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
+            mask = 1. - torch.from_numpy(mask)
+        else:
+            mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+        output_images.append(image)
+        output_masks.append(mask.unsqueeze(0))
+
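+    # Stack multiple frames into a batch; single-frame images (and MPO files) keep only the first frame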
+    if len(output_images) > 1 and img.format not in excluded_formats:
+        output_image = torch.cat(output_images, dim=0)
+        output_mask = torch.cat(output_masks, dim=0)
+    else:
+        output_image = output_images[0]
+        output_mask = output_masks[0]
+
+    return (output_image, output_mask)
+
+
def load_image_and_mask_from_url(url, timeout=10):
    # Load the image from the URL
    response = requests.get(url, timeout=timeout)
@@ -1687,28 +1735,51 @@ def INPUT_TYPES(s):
    def run(self,upload,material=None):
        # print('material',material)
        # print(upload )
-        image = base64_to_image(upload['image'])
-        mat=None
-        if 'material' in upload and upload['material']:
-            mat=base64_to_image(upload['material'])
-            mat=mat.convert('RGB')
-            mat=pil2tensor(mat)
+        # Screenshots captured from a series of camera angles
+        images=upload['images'] if "images" in upload else []
-        mask = image.split()[3]
-        image=image.convert('RGB')
+        ims=[]
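+        # Filenames may need the "[type]" suffix (e.g. "[input]") appended so folder_paths can resolve the annotated path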
+        for im in images:
+            if 'type' in im and (not f"[{im['type']}]" in im['name']):
+                im['name']=im['name']+" "+f"[{im['type']}]"
+            output_image, output_mask = load_image_to_tensor(im['name'])
+            ims.append(output_image)
-        mask=mask.convert('L')
-
+
+        mask=None
        bg_image=None
-        if 'bg_image' in upload and upload['bg_image']:
-            bg_image = base64_to_image(upload['bg_image'])
-            bg_image=bg_image.convert('RGB')
-            bg_image=pil2tensor(bg_image)
+        mat=None
+
+        # If there are no per-angle series screenshots
+        if len(ims)==0:
+            # This is the current screenshot of the 3D model
+            image = base64_to_image(upload['image'])
+
+            if 'material' in upload and upload['material']:
+                mat=base64_to_image(upload['material'])
+                mat=mat.convert('RGB')
+                mat=pil2tensor(mat)
+
+            mask = image.split()[3]
+            image=image.convert('RGB')
+
+            mask=mask.convert('L')
+
+
+            if 'bg_image' in upload and upload['bg_image']:
+                bg_image = base64_to_image(upload['bg_image'])
+                bg_image=bg_image.convert('RGB')
+                bg_image=pil2tensor(bg_image)
+
+
+            mask=pil2tensor(mask)
+            image=pil2tensor(image)
+        else:
+
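+            # Batch the per-angle screenshots loaded from disk into a single image tensor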
+            image = torch.cat(ims, dim=0)
-        mask=pil2tensor(mask)
-        image=pil2tensor(image)
        m=[]
        if not material is None:
diff --git a/nodes/P5.py b/nodes/P5.py
index 509a462e..615db2db 100644
--- a/nodes/P5.py
+++ b/nodes/P5.py
@@ -16,7 +16,7 @@ def tensor2pil(image):
def pil2tensor(image):
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
-def load_image( image):
+def load_image_to_tensor( image):
    image_path = folder_paths.get_annotated_filepath(image)
    img = node_helpers.pillow(Image.open, image_path)
@@ -88,7 +88,7 @@ def run(self, frames):
            if 'type' in im and (not f"[{im['type']}]" in im['name']):
                im['name']=im['name']+" "+f"[{im['type']}]"
-            output_image, output_mask = load_image(im['name'])
+            output_image, output_mask = load_image_to_tensor(im['name'])
            ims.append(output_image)
        if len(ims)==0:
diff --git a/web/javascript/3d_mixlab.js b/web/javascript/3d_mixlab.js
index 70cca447..a9f48c38 100644
--- a/web/javascript/3d_mixlab.js
+++ b/web/javascript/3d_mixlab.js
@@ -26,7 +26,8 @@ const setLocalDataOfWin = (key, value) => {
  localStorage.setItem(key, JSON.stringify(value))
  // window[key] = value
}
-async function uploadImage (blob, fileType = '.svg', filename) {
+
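+// Upload the blob via the ComfyUI upload API and return the server's parsed JSON response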
+async function uploadImage_ (blob, fileType = '.svg', filename) {
  // const blob = await (await fetch(src)).blob();
  const body = new FormData()
  body.append(
@@ -41,13 +42,17 @@ async function uploadImage (blob, fileType = '.svg', filename) {
  // console.log(resp)
  let data = await resp.json()
+  return data
+}
+
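+// Thin wrapper: upload the blob, then build a /view preview URL from the returned name and subfolder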
+async function uploadImage (blob, fileType = '.svg', filename) {
+  let data = await uploadImage_(blob, fileType, filename)
  let { name, subfolder } = data
  let src = api.apiURL(
    `/view?filename=${encodeURIComponent(
      name
    )}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}${app.getRandParam()}`
  )
-
  return src
}
@@ -189,7 +194,7 @@ app.registerExtension({
let d = getLocalData('_mixlab_3d_image')
// console.log('serializeValue', node)
if (d && d[node.id]) {
- let { url, bg, material } = d[node.id]
+ let { url, bg, material, images } = d[node.id]
let data = {}
if (url) {
data.image = await parseImage(url)
@@ -205,6 +210,10 @@ app.registerExtension({
data.material = await parseImage(material)
}
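+ // Pass the per-angle screenshot file references through so the Python node can load them from disk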
+ if (images) {
+ data.images = images
+ }
+
return JSON.parse(JSON.stringify(data))
} else {
return {}
@@ -276,6 +285,7 @@ app.registerExtension({
const fileURL = URL.createObjectURL(file)
// console.log('File URL: ', fileURL)
let html = `