在JavaScript項(xiàng)目中,TensorFlow.js的安裝方法有兩種:一種是通過(guò)script標(biāo)簽引入,另外一種就是通過(guò)npm進(jìn)行安裝。
如果不熟悉WEB開發(fā)的同學(xué),我們建議使用腳本標(biāo)簽來(lái)獲取。
使用Script Tag
將以下腳本標(biāo)簽添加到您的主HTML文件中:
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js" rel="external nofollow" ></script>
有關(guān)腳本標(biāo)簽的設(shè)置,請(qǐng)參閱代碼示例:
<html>
<head>
<!-- Load TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.9.0" rel="external nofollow" > </script>
<!-- Place your code in the script tag below. You can also use an external .js file -->
<script>
// Notice there is no 'import' statement. 'tf' is available on the index-page
// because of the script tag above.
// Define a model for linear regression.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// Prepare the model for training: Specify the loss and the optimizer.
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// Generate some synthetic data for training.
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// Train the model using the data.
model.fit(xs, ys).then(() => {
// Use the model to do inference on a data point the model hasn't seen before:
// Open the browser devtools to see the output
model.predict(tf.tensor2d([5], [1, 1])).print();
});
</script>
</head>
<body>
</body>
</html>
通過(guò)NPM(或yarn)
使用yarn或npm將TensorFlow.js添加到您的項(xiàng)目中。 注意:因?yàn)槭褂肊S2017語(yǔ)法(如import),所以此工作流程假定您使用打包程序/轉(zhuǎn)換程序?qū)⒋a轉(zhuǎn)換為瀏覽器可以理解的內(nèi)容。
yarn add @tensorflow/tfjs
或者
npm install @tensorflow/tfjs
在NPM中輸入以下代碼:
import * as tf from '@tensorflow/tfjs';
//定義一個(gè)線性回歸模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// 為訓(xùn)練生成一些合成數(shù)據(jù)
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// 使用數(shù)據(jù)訓(xùn)練模型
model.fit(xs, ys, {epochs: 10}).then(() => {
// 在該模型從未看到過(guò)的數(shù)據(jù)點(diǎn)上使用模型進(jìn)行推理
model.predict(tf.tensor2d([5], [1, 1])).print();
// 打開瀏覽器開發(fā)工具查看輸出
});
更多建議: