diff --git a/src/backends/backend.ts b/src/backends/backend.ts index 434d0e5fd8..048f2ba954 100644 --- a/src/backends/backend.ts +++ b/src/backends/backend.ts @@ -17,7 +17,7 @@ import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util'; import {Activation} from '../ops/fused_util'; -import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; +import {Backend, DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor'; import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types'; export const EPSILON_FLOAT32 = 1e-7; @@ -623,4 +623,13 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer { dispose(): void { throw new Error('Not yet implemented'); } + + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + throw new Error('Not yet implemented'); + } + + decodeBase64(str: StringTensor|Tensor): T { + throw new Error('Not yet implemented'); + } } diff --git a/src/backends/cpu/backend_cpu.ts b/src/backends/cpu/backend_cpu.ts index 4d05b7bbf1..fecca34781 100644 --- a/src/backends/cpu/backend_cpu.ts +++ b/src/backends/cpu/backend_cpu.ts @@ -33,7 +33,7 @@ import {buffer, scalar, tensor, tensor3d, tensor4d} from '../../ops/ops'; import * as scatter_nd_util from '../../ops/scatter_nd_util'; import * as selu_util from '../../ops/selu_util'; import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; -import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; +import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor'; import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util'; @@ -42,6 +42,7 @@ import * as backend_util from '../backend_util'; import * as complex_util from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {decodeBase64Impl, encodeBase64Impl} from '../string_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -80,8 +81,8 @@ export class MathBackendCPU implements KernelBackend { public blockSize = 48; private data: DataStorage>; - private fromPixels2DContext: CanvasRenderingContext2D - | OffscreenCanvasRenderingContext2D; + private fromPixels2DContext: CanvasRenderingContext2D| + OffscreenCanvasRenderingContext2D; private firstUse = true; constructor() { @@ -133,12 +134,11 @@ export class MathBackendCPU implements KernelBackend { const isPixelData = (pixels as PixelData).data instanceof Uint8Array; const isImageData = - typeof(ImageData) !== 'undefined' && pixels instanceof ImageData; - const isVideo = - typeof(HTMLVideoElement) !== 'undefined' - && pixels instanceof HTMLVideoElement; - const isImage = typeof(HTMLImageElement) !== 'undefined' - && pixels instanceof HTMLImageElement; + typeof (ImageData) !== 'undefined' && pixels instanceof ImageData; + const isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + const isImage = typeof (HTMLImageElement) !== 'undefined' && + pixels instanceof HTMLImageElement; let vals: Uint8ClampedArray|Uint8Array; // tslint:disable-next-line:no-any @@ -3194,6 +3194,17 @@ export class MathBackendCPU implements KernelBackend { dispose() {} + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return encodeBase64Impl(sVals, str.shape, pad); + } + + decodeBase64(str: StringTensor|Tensor): T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return decodeBase64Impl(sVals, str.shape); + } + floatPrecision(): 16|32 { return 32; } diff --git a/src/backends/string_shared.ts b/src/backends/string_shared.ts new file mode 100644 index 0000000000..c0416b3fcf --- /dev/null +++ b/src/backends/string_shared.ts @@ -0,0 +1,57 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {arrayBufferToBase64String, base64StringToArrayBuffer, urlSafeBase64, urlUnsafeBase64} from '../io/io_utils'; +import {StringTensor, Tensor} from '../tensor'; +import {decodeString} from '../util'; + +/** Shared implementation of the encodeBase64 kernel across WebGL and CPU. */ +export function encodeBase64Impl( + values: Uint8Array[], shape: number[], pad = false): T { + const resultValues = new Array(values.length); + + for (let i = 0; i < values.length; ++i) { + const bStr = arrayBufferToBase64String(values[i].buffer); + const bStrUrl = urlSafeBase64(bStr); + + if (pad) { + resultValues[i] = bStrUrl; + } else { + // Remove padding + resultValues[i] = bStrUrl.replace(/=/g, ''); + } + } + + return Tensor.make(shape, {values: resultValues}, 'string') as T; +} + +/** Shared implementation of the decodeBase64 kernel across WebGL and CPU. */ +export function decodeBase64Impl( + values: Uint8Array[], shape: number[]): T { + const resultValues = new Array(values.length); + + for (let i = 0; i < values.length; ++i) { + // Undo URL safe and decode from Base64 to ArrayBuffer + const bStrUrl = decodeString(values[i]); + const bStr = urlUnsafeBase64(bStrUrl); + const aBuff = base64StringToArrayBuffer(bStr); + + resultValues[i] = decodeString(new Uint8Array(aBuff)); + } + + return Tensor.make(shape, {values: resultValues}, 'string') as T; +} diff --git a/src/backends/webgl/backend_webgl.ts b/src/backends/webgl/backend_webgl.ts index 41a24f6b1c..7db0747184 100644 --- a/src/backends/webgl/backend_webgl.ts +++ b/src/backends/webgl/backend_webgl.ts @@ -36,7 +36,7 @@ import * as segment_util from '../../ops/segment_util'; import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util'; import {softmax} from '../../ops/softmax'; import {range, scalar, tensor} from '../../ops/tensor_ops'; -import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; +import {DataId, Scalar, StringTensor, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor'; import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types'; import * as util from '../../util'; import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util'; @@ -45,6 +45,7 @@ import * as backend_util from '../backend_util'; import {mergeRealAndImagArrays} from '../complex_util'; import {nonMaxSuppressionImpl} from '../non_max_suppression_impl'; import {split} from '../split_shared'; +import {decodeBase64Impl, encodeBase64Impl} from '../string_shared'; import {tile} from '../tile_impl'; import {topkImpl} from '../topk_impl'; import {whereImpl} from '../where_impl'; @@ -62,7 +63,7 @@ import * as binaryop_gpu from './binaryop_gpu'; import {BinaryOpProgram} from './binaryop_gpu'; import * as binaryop_packed_gpu from './binaryop_packed_gpu'; import {BinaryOpPackedProgram} from './binaryop_packed_gpu'; -import {getWebGLContext, createCanvas} from './canvas_util'; +import {createCanvas, getWebGLContext} from './canvas_util'; import {ClipProgram} from './clip_gpu'; import {ClipPackedProgram} from './clip_packed_gpu'; import {ComplexAbsProgram} from './complex_abs_gpu'; @@ -222,8 +223,8 @@ export class MathBackendWebGL implements KernelBackend { private numBytesInGPU = 0; private canvas: HTMLCanvasElement; - private fromPixels2DContext: CanvasRenderingContext2D - | OffscreenCanvasRenderingContext2D; + private fromPixels2DContext: CanvasRenderingContext2D| + OffscreenCanvasRenderingContext2D; private programTimersStack: TimerNode[]; private activeTimers: TimerNode[]; @@ -272,9 +273,9 @@ export class MathBackendWebGL implements KernelBackend { } fromPixels( - pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| - HTMLVideoElement, - numChannels: number): Tensor3D { + pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement| + HTMLVideoElement, + numChannels: number): Tensor3D { if (pixels == null) { throw new Error( 'pixels passed to tf.browser.fromPixels() can not be null'); @@ -282,26 +283,25 @@ export class MathBackendWebGL implements KernelBackend { const texShape: [number, number] = [pixels.height, pixels.width]; const outShape = [pixels.height, pixels.width, numChannels]; - const isCanvas = (typeof(OffscreenCanvas) !== 'undefined' - && pixels instanceof OffscreenCanvas) - || (typeof(HTMLCanvasElement) !== 'undefined' - && pixels instanceof HTMLCanvasElement); + const isCanvas = (typeof (OffscreenCanvas) !== 'undefined' && + pixels instanceof OffscreenCanvas) || + (typeof (HTMLCanvasElement) !== 'undefined' && + pixels instanceof HTMLCanvasElement); const isPixelData = (pixels as PixelData).data instanceof Uint8Array; const isImageData = - typeof(ImageData) !== 'undefined' && pixels instanceof ImageData; - const isVideo = - typeof(HTMLVideoElement) !== 'undefined' - && pixels instanceof HTMLVideoElement; - const isImage = typeof(HTMLImageElement) !== 'undefined' - && pixels instanceof HTMLImageElement; + typeof (ImageData) !== 'undefined' && pixels instanceof ImageData; + const isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + const isImage = typeof (HTMLImageElement) !== 'undefined' && + pixels instanceof HTMLImageElement; if (!isCanvas && !isPixelData && !isImageData && !isVideo && !isImage) { throw new Error( - 'pixels passed to tf.browser.fromPixels() must be either an ' + - `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` + - `in browser, or OffscreenCanvas, ImageData in webworker` + - ` or {data: Uint32Array, width: number, height: number}, ` + - `but was ${(pixels as {}).constructor.name}`); + 'pixels passed to tf.browser.fromPixels() must be either an ' + + `HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData ` + + `in browser, or OffscreenCanvas, ImageData in webworker` + + ` or {data: Uint32Array, width: number, height: number}, ` + + `but was ${(pixels as {}).constructor.name}`); } if (isVideo) { @@ -314,14 +314,14 @@ export class MathBackendWebGL implements KernelBackend { 'on the document object'); } //@ts-ignore - this.fromPixels2DContext = createCanvas(ENV.getNumber('WEBGL_VERSION')) - .getContext('2d'); + this.fromPixels2DContext = + createCanvas(ENV.getNumber('WEBGL_VERSION')).getContext('2d'); } this.fromPixels2DContext.canvas.width = pixels.width; this.fromPixels2DContext.canvas.height = pixels.height; this.fromPixels2DContext.drawImage( - pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height); - //@ts-ignore + pixels as HTMLVideoElement, 0, 0, pixels.width, pixels.height); + //@ts-ignore pixels = this.fromPixels2DContext.canvas; } @@ -2176,6 +2176,17 @@ export class MathBackendWebGL implements KernelBackend { return split(x, sizeSplits, axis); } + encodeBase64(str: StringTensor|Tensor, pad = false): + T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return encodeBase64Impl(sVals, str.shape, pad); + } + + decodeBase64(str: StringTensor|Tensor): T { + const sVals = this.readSync(str.dataId) as Uint8Array[]; + return decodeBase64Impl(sVals, str.shape); + } + scatterND( indices: Tensor, updates: Tensor, shape: ShapeMap[R]): Tensor { const {sliceRank, numUpdates, sliceSize, strides, outputSize} = diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index a4a3b26763..4653619a1f 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -335,3 +335,22 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): modelArtifacts.weightData.byteLength, }; } + +/** + * Make Base64 string URL safe by replacing `+` with `-` and `/` with `_`. + * + * @param str Base64 string to make URL safe. + */ +export function urlSafeBase64(str: string): string { + return str.replace(/\+/g, '-').replace(/\//g, '_'); +} + +// revert Base64 URL safe replacement of + and / +/** + * Revert Base64 URL safe changes by replacing `-` with `+` and `_` with `/`. + * + * @param str URL safe Base string to revert changes. + */ +export function urlUnsafeBase64(str: string): string { + return str.replace(/-/g, '+').replace(/_/g, '/'); +} diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 135eeecb98..71a79e7273 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -22,6 +22,7 @@ import {NamedTensor, NamedTensorMap} from '../tensor_types'; import {expectArraysEqual} from '../test_util'; import {expectArraysClose} from '../test_util'; import {encodeString} from '../util'; + import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils'; import {WeightsManifestEntry} from './types'; diff --git a/src/ops/ops.ts b/src/ops/ops.ts index 7714ba320b..7b4b9252bf 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -46,6 +46,7 @@ export * from './sparse_to_dense'; export * from './gather_nd'; export * from './dropout'; export * from './signal_ops'; +export * from './string_ops'; export {op} from './operation'; diff --git a/src/ops/string_ops.ts b/src/ops/string_ops.ts new file mode 100644 index 0000000000..c1c48ec711 --- /dev/null +++ b/src/ops/string_ops.ts @@ -0,0 +1,78 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {StringTensor, Tensor} from '../tensor'; +import {convertToTensor} from '../tensor_util_env'; + +import {op} from './operation'; + +/** + * Encodes the values of a `tf.Tensor` (of dtype `string`) to Base64. + * + * Given a String tensor, returns a new tensor with the values encoded into + * web-safe base64 format. + * + * Web-safe means that the encoder uses `-` and `_` instead of `+` and `/`: + * + * en.wikipedia.org/wiki/Base64 + * + * ```js + * const x = tf.tensor1d(['Hello world!'], 'string'); + * + * x.encodeBase64().print(); + * ``` + * @param str The input `tf.Tensor` of dtype `string` to encode. + * @param pad Whether to add padding (`=`) to the end of the encoded string. + */ +/** @doc {heading: 'Operations', subheading: 'String'} */ +function encodeBase64_( + str: StringTensor|Tensor, pad = false): T { + const $str = convertToTensor(str, 'str', 'encodeBase64', 'string'); + + const backwardsFunc = (dy: T) => ({$str: () => decodeBase64(dy)}); + + return ENGINE.runKernel( + backend => backend.encodeBase64($str, pad), {$str}, backwardsFunc); +} + +/** + * Decodes the values of a `tf.Tensor` (of dtype `string`) from Base64. + * + * Given a String tensor of Base64 encoded values, returns a new tensor with the + * decoded values. + * + * en.wikipedia.org/wiki/Base64 + * + * ```js + * const y = tf.scalar('SGVsbG8gd29ybGQh', 'string'); + * + * y.decodeBase64().print(); + * ``` + * @param str The input `tf.Tensor` of dtype `string` to decode. + */ +/** @doc {heading: 'Operations', subheading: 'String'} */ +function decodeBase64_(str: StringTensor|Tensor): T { + const $str = convertToTensor(str, 'str', 'decodeBase64', 'string'); + + const backwardsFunc = (dy: T) => ({$str: () => encodeBase64(dy)}); + + return ENGINE.runKernel( + backend => backend.decodeBase64($str), {$str}, backwardsFunc); +} + +export const encodeBase64 = op({encodeBase64_}); +export const decodeBase64 = op({decodeBase64_}); diff --git a/src/ops/string_ops_test.ts b/src/ops/string_ops_test.ts new file mode 100644 index 0000000000..5eecc442d1 --- /dev/null +++ b/src/ops/string_ops_test.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysEqual} from '../test_util'; + +const txtArr = [ + 'Hello TensorFlow.js!', '𝌆', 'Pre\u2014trained models with Base64 ops\u002e', + 'how about these? 🌍💻🍕', 'https://www.tensorflow.org/js', 'àβÇdéf', + '你好, 世界', `Build, train, & deploy +ML models in JS` +]; +const urlSafeB64 = [ + 'SGVsbG8gVGVuc29yRmxvdy5qcyE', '8J2Mhg', + 'UHJl4oCUdHJhaW5lZCBtb2RlbHMgd2l0aCBCYXNlNjQgb3BzLg', + 'aG93IGFib3V0IHRoZXNlPyDwn4yN8J-Su_CfjZU', + 'aHR0cHM6Ly93d3cudGVuc29yZmxvdy5vcmcvanM', 'w6DOssOHZMOpZg', + '5L2g5aW9LCDkuJbnlYw', 'QnVpbGQsIHRyYWluLCAmIGRlcGxveQpNTCBtb2RlbHMgaW4gSlM' +]; +const urlSafeB64Pad = [ + 'SGVsbG8gVGVuc29yRmxvdy5qcyE=', '8J2Mhg==', + 'UHJl4oCUdHJhaW5lZCBtb2RlbHMgd2l0aCBCYXNlNjQgb3BzLg==', + 'aG93IGFib3V0IHRoZXNlPyDwn4yN8J-Su_CfjZU=', + 'aHR0cHM6Ly93d3cudGVuc29yZmxvdy5vcmcvanM=', 'w6DOssOHZMOpZg==', + '5L2g5aW9LCDkuJbnlYw=', 'QnVpbGQsIHRyYWluLCAmIGRlcGxveQpNTCBtb2RlbHMgaW4gSlM=' +]; + +describeWithFlags('encodeBase64', ALL_ENVS, () => { + it('scalar', async () => { + const a = tf.scalar(txtArr[1], 'string'); + const r = tf.encodeBase64(a); + expect(r.shape).toEqual([]); + expectArraysEqual(await r.data(), urlSafeB64[1]); + }); + it('1D padded', async () => { + const a = tf.tensor1d([txtArr[2]], 'string'); + const r = tf.encodeBase64(a, true); + expect(r.shape).toEqual([1]); + expectArraysEqual(await r.data(), [urlSafeB64Pad[2]]); + }); + it('2D', async () => { + const a = tf.tensor2d(txtArr, [2, 4], 'string'); + const r = tf.encodeBase64(a, false); + expect(r.shape).toEqual([2, 4]); + expectArraysEqual(await r.data(), urlSafeB64); + }); + it('3D padded', async () => { + const a = tf.tensor3d(txtArr, [2, 2, 2], 'string'); + const r = tf.encodeBase64(a, true); + expect(r.shape).toEqual([2, 2, 2]); + expectArraysEqual(await r.data(), urlSafeB64Pad); + }); +}); + +describeWithFlags('decodeBase64', ALL_ENVS, () => { + it('scalar', async () => { + const a = tf.scalar(urlSafeB64[1], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([]); + expectArraysEqual(await r.data(), txtArr[1]); + }); + it('1D padded', async () => { + const a = tf.tensor1d([urlSafeB64Pad[2]], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([1]); + expectArraysEqual(await r.data(), [txtArr[2]]); + }); + it('2D', async () => { + const a = tf.tensor2d(urlSafeB64, [2, 4], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([2, 4]); + expectArraysEqual(await r.data(), txtArr); + }); + it('3D padded', async () => { + const a = tf.tensor3d(urlSafeB64Pad, [2, 2, 2], 'string'); + const r = tf.decodeBase64(a); + expect(r.shape).toEqual([2, 2, 2]); + expectArraysEqual(await r.data(), txtArr); + }); +}); + +describeWithFlags('encodeBase64-decodeBase64', ALL_ENVS, () => { + it('round-trip', async () => { + const s = [txtArr.join('')]; + const a = tf.tensor(s, [1], 'string'); + const b = tf.encodeBase64(a); + const c = tf.decodeBase64(b); + expectArraysEqual(await c.data(), s); + }); +}); diff --git a/src/tensor.ts b/src/tensor.ts index 0cd8fdbe57..845012108e 100644 --- a/src/tensor.ts +++ b/src/tensor.ts @@ -367,6 +367,8 @@ export interface OpHandler { fft(x: Tensor): Tensor; ifft(x: Tensor): Tensor; rfft(x: Tensor): Tensor; irfft(x: Tensor): Tensor }; + encodeBase64(x: T, pad: boolean): T; + decodeBase64(x: T): T; } // For tracking tensor creation and disposal. @@ -1425,6 +1427,16 @@ export class Tensor { this.throwIfDisposed(); return opHandler.spectral.irfft(this); } + + encodeBase64(this: T, pad = false): T { + this.throwIfDisposed(); + return opHandler.encodeBase64(this, pad); + } + + decodeBase64(this: T): T { + this.throwIfDisposed(); + return opHandler.decodeBase64(this); + } } Object.defineProperty(Tensor, Symbol.hasInstance, { value: (instance: Tensor) => { diff --git a/src/tests.ts b/src/tests.ts index eba4c3b71b..6792238fef 100644 --- a/src/tests.ts +++ b/src/tests.ts @@ -84,6 +84,7 @@ import './ops/softmax_test'; import './ops/sparse_to_dense_test'; import './ops/spectral_ops_test'; import './ops/strided_slice_test'; +import './ops/string_ops_test'; import './ops/topk_test'; import './ops/transpose_test'; import './ops/unary_ops_test';