AI Chatbots With TensorFlow.js: Training a Trivia Expert AI - CodeProject
Mon Dec 19 2022 02:13:37 GMT+0000 (Coordinated Universal Time)
Saved by @modesto59 #html
<html>
  <head>
    <title>Trivia Know-It-All: Chatbots in the Browser with TensorFlow.js</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@2.0.0/dist/tf.min.js"></script>
  </head>
  <body>
    <h1 id="status">Trivia Know-It-All Bot</h1>
    <label for="question">Ask a trivia question:</label>
    <input id="question" type="text" />
    <button id="submit">Submit</button>
    <p id="bot-question"></p>
    <p id="bot-answer"></p>
    <script>
      "use strict";

      // Number of word slots per question vector: longer questions are truncated,
      // shorter ones are right-padded with PADDING_TOKEN.
      const MAX_SENTENCE_LENGTH = 30;

      // Id 0 is reserved for padding (the embedding layer is built with maskZero: true),
      // which is why real word ids start at 1.
      const PADDING_TOKEN = 0;

      /**
       * Shows a status message in the page heading.
       * @param {string} text - Message to display in the #status element.
       */
      function setText( text ) {
        document.getElementById( "status" ).innerText = text;
      }

      /**
       * Splits a sentence into lowercase words, keeping only the letters a-z.
       * @param {string} sentence - Raw question text.
       * @returns {string[]} The non-empty words, in order.
       */
      function tokenize( sentence ) {
        return sentence
          .replace( /[^a-z ]/gi, "" )
          .toLowerCase()
          .split( " " )
          .filter( word => word.length > 0 );
      }

      /**
       * Assigns every distinct word a 1-based id in order of first appearance.
       * A Map is used instead of a plain object so that words such as
       * "constructor" cannot collide with inherited Object properties.
       * @param {string[]} sentences - All training questions.
       * @returns {Map<string, number>} Word-to-id lookup.
       */
      function buildVocabulary( sentences ) {
        const wordReference = new Map();
        sentences.forEach( sentence => {
          tokenize( sentence ).forEach( word => {
            if( !wordReference.has( word ) ) {
              wordReference.set( word, wordReference.size + 1 );
            }
          });
        });
        return wordReference;
      }

      /**
       * Converts a sentence into a fixed-length vector of word ids.
       * Words missing from the vocabulary are dropped, so the result always has
       * the same layout as the training vectors: known ids first, padding after.
       * @param {string} sentence - Raw question text.
       * @param {Map<string, number>} wordReference - Word-to-id lookup.
       * @returns {number[]} Exactly MAX_SENTENCE_LENGTH ids.
       */
      function vectorize( sentence, wordReference ) {
        const ids = tokenize( sentence )
          .map( word => wordReference.get( word ) )
          .filter( id => id !== undefined )
          .slice( 0, MAX_SENTENCE_LENGTH );
        while( ids.length < MAX_SENTENCE_LENGTH ) {
          ids.push( PADDING_TOKEN );
        }
        return ids;
      }

      /**
       * Returns the index of the first largest value.
       * A plain loop is used rather than Math.max( ...values ) because spreading
       * one argument per question can exceed the engine's argument limit.
       * @param {ArrayLike<number>} values - Prediction scores.
       * @returns {number} Index of the maximum (0 when nothing compares greater).
       */
      function indexOfMax( values ) {
        let best = 0;
        for( let i = 1; i < values.length; i++ ) {
          if( values[ i ] > values[ best ] ) {
            best = i;
          }
        }
        return best;
      }

      // Loads the TriviaQA data, trains a classifier that maps a question to the
      // index of the matching known question, then wires up the input controls.
      (async () => {
        try {
          // Load TriviaQA data
          const response = await fetch( "web/verified-wikipedia-dev.json" );
          if( !response.ok ) {
            throw new Error( `Could not load trivia data (HTTP ${response.status})` );
          }
          const triviaData = await response.json();
          const data = triviaData.Data;
          if( !Array.isArray( data ) || data.length === 0 ) {
            throw new Error( "Trivia data contains no questions" );
          }

          // Each question is its own output class
          const questions = data.map( qa => qa.Question );
          const wordReference = buildVocabulary( questions );

          // Create a tokenized vector for each question
          const vectors = questions.map( q => vectorize( q, wordReference ) );

          // Define our RNN model with several hidden layers
          const model = tf.sequential();
          // Add 1 to inputDim for the "padding" character
          model.add(tf.layers.embedding( { inputDim: wordReference.size + 1, outputDim: 128, inputLength: MAX_SENTENCE_LENGTH, maskZero: true } ) );
          model.add(tf.layers.simpleRNN( { units: 32 } ) );
          // Alternative recurrent layer to experiment with:
          // model.add(tf.layers.bidirectional( { layer: tf.layers.simpleRNN( { units: 32 } ), mergeMode: "concat" } ) );
          model.add(tf.layers.dense( { units: 50 } ) );
          model.add(tf.layers.dense( { units: 25 } ) );
          model.add(tf.layers.dense( { units: questions.length, activation: "softmax" } ) );

          model.compile({
            optimizer: tf.train.adam(),
            loss: "categoricalCrossentropy",
            metrics: [ "accuracy" ]
          });

          const xs = tf.tensor2d( vectors, [ questions.length, MAX_SENTENCE_LENGTH ] );
          // Row i of the identity matrix is the one-hot label for question i.
          const ys = tf.eye( questions.length );

          try {
            await model.fit( xs, ys, {
              epochs: 20,
              shuffle: true,
              callbacks: {
                onEpochEnd: ( epoch, logs ) => {
                  setText( `Training... Epoch #${epoch} (${logs.acc})` );
                  console.log( "Epoch #", epoch, logs );
                }
              }
            } );
          } finally {
            // Training tensors are not needed once fitting ends; release their memory.
            tf.dispose( [ xs, ys ] );
          }

          setText( "Trivia Know-It-All Bot is Ready!" );

          const questionInput = document.getElementById( "question" );
          const submitButton = document.getElementById( "submit" );
          const botQuestion = document.getElementById( "bot-question" );
          const botAnswer = document.getElementById( "bot-answer" );

          // Pressing Enter in the text box behaves like clicking Submit
          questionInput.addEventListener( "keyup", event => {
            if( event.key === "Enter" ) {
              event.preventDefault();
              submitButton.click();
            }
          });

          submitButton.addEventListener( "click", async () => {
            const text = questionInput.value;
            questionInput.value = "";

            const qVec = vectorize( text, wordReference );
            if( qVec.every( id => id === PADDING_TOKEN ) ) {
              // Nothing the model was trained on: an all-padding input would
              // only produce an arbitrary answer.
              botQuestion.innerText = "Sorry, I don't recognize any words in that question.";
              botAnswer.innerText = "";
              return;
            }

            const input = tf.tensor2d( [ qVec ], [ 1, MAX_SENTENCE_LENGTH ] );
            let output = null;
            try {
              output = model.predict( input );
              const prediction = await output.data();

              // Get the index of the highest value in the prediction
              const id = indexOfMax( prediction );
              botQuestion.innerText = questions[ id ];
              botAnswer.innerText = data[ id ].Answer.Value;
            } catch( err ) {
              console.error( err );
              botQuestion.innerText = "Sorry, something went wrong while answering.";
              botAnswer.innerText = "";
            } finally {
              input.dispose();
              if( output ) {
                output.dispose();
              }
            }
          });
        } catch( err ) {
          console.error( err );
          setText( `Trivia Know-It-All Bot failed to start: ${err.message}` );
        }
      })();
    </script>
  </body>
</html>
Comments