{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":["1iOzmc1kF03y","YTsHyCKlO5U5","xCW9ZEM6ptC_","-XJ2ulGNT04S","zU-Ay16vYvS6","szyejNAIojCW"],"authorship_tag":"ABX9TyPWsGxkCFJF1/4UN9cBi6EL"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Point clouds and deep learning\n","\n","In this part of the tutorial, you will learn how to train a deep neural network to segment plants into their different plant parts. \n","\n","After finishing this part of the tutorial, you will be able to:\n","* explain what training data for 3D semantic segmentation looks like\n","* use the PointNet deep neural network\n","* compare the performance of a model trained for only a few epochs to one trained for many epochs "],"metadata":{"id":"IlrVJp_uBM1-"}},{"cell_type":"markdown","source":["# 1: Getting started\n"],"metadata":{"id":"wszcFACQ9PQU"}},{"cell_type":"markdown","source":["**Use the GPU:** Click on the menu \"Runtime\", then \"Change runtime type\" and select **GPU** for \"Hardware accelerator\". This will make training later on much faster."],"metadata":{"id":"9JngCysfYNlJ"}},{"cell_type":"markdown","source":["We have to install Open3D again. 
Don't forget to press the \"restart runtime\" button that appears after the installation has completed:"],"metadata":{"id":"X4fl515UKJT4"}},{"cell_type":"code","source":["!pip install open3d"],"metadata":{"id":"ngiQS6aiKJhP"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import numpy as np\n","import matplotlib.pyplot as plt\n","import plotly.graph_objects as go\n","from plotly.subplots import make_subplots\n","import plotly.express as px\n","import pandas as pd\n","import os\n","import pickle\n","from pathlib import Path\n","import random\n","import h5py\n","\n","# TensorFlow is a library to implement and run neural networks\n","import tensorflow as tf\n","from tensorflow import keras\n","from tensorflow.keras import layers\n"],"metadata":{"id":"Pe9gvu1y9a3U"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!git clone https://git.wur.nl/koots006/summerschool-application-of-machine-learning-in-plant-sciences.git point_cloud"],"metadata":{"id":"LXRDXH1W9dHs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We need to include a few helper functions:"],"metadata":{"id":"_jIFeldTNMXC"}},{"cell_type":"code","source":["from point_cloud.pc_helper import *\n","from point_cloud.pn_helper import *"],"metadata":{"id":"DaqhJ1t4JUix"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 2: Training data for 3D semantic segmentation"],"metadata":{"id":"MRJpzzut9kGN"}},{"cell_type":"markdown","source":["We will work with a dataset consisting of point clouds of tomato seedlings, with the aim of segmenting the point clouds into leaf points and stem points. 
The dataset consists of annotated point clouds, where the leaf and stem points were labeled by hand.\n","\n","**Exercise:**\n","* Run the code below to view a labeled point cloud\n","* Try a few different plant IDs in the range from 1 to 377 (note, not all plant IDs exist)"],"metadata":{"id":"7zzFcAEQ9p50"}},{"cell_type":"code","source":["pc_dir = 'point_cloud/data'\n","plant_id = 10\n","pc, lab = load_seedling_pc(pc_dir, plant_id)\n","if(len(pc)>0):\n"," show_labeled_seedling_pc(pc, lab)"],"metadata":{"id":"4JSU8Wy9-gQ5"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Let's have a look at the matrix containing the point cloud data (point coordinates in mm) and at the array containing the point labels (0 for stem and 1 for leaf): "],"metadata":{"id":"8IxrKJsrDXjk"}},{"cell_type":"code","source":["print('The point cloud data:')\n","print(pc)\n","\n","print('\\nThe point labels:')\n","print(lab)"],"metadata":{"id":"r2DigO5bDTT7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Exercise:**\n","* Can you calculate the height of the plant? Tip: find the minimum and maximum z-coordinates\n","* How many stem points and how many leaf points are there?\n","* What do you think of the division of points in stem and leaf concerning the training of the neural network?"],"metadata":{"id":"GclA-udkEBE8"}},{"cell_type":"code","source":["# Add your code here\n"],"metadata":{"id":"4YYZ2RmOE9wp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 2.1: Setting up the training, validation and test data"],"metadata":{"id":"1iOzmc1kF03y"}},{"cell_type":"markdown","source":["To simplify loading of all the data, we put all labeled point clouds for training and validation in a single hdf5 (h5) file. 
Let's load this data:"],"metadata":{"id":"Spi_Wm9TG8Mr"}},{"cell_type":"code","source":["## - Load the training and validation set\n","train_val_data_path = \"point_cloud/seedling_train_val_set.h5\"\n","train_val_point_clouds, train_val_labels = get_dataset(train_val_data_path)\n","print(\"The size of the training and validation set: \", len(train_val_point_clouds))\n","print(\"Number of points per point cloud: \", len(train_val_point_clouds[0]))"],"metadata":{"id":"Q81qLBlF1Cj3"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Exercise**:\n","* The code above prints the number of points in the first point cloud. Also get the number of points for a few other point clouds.\n","* What do you observe? And why would that be? "],"metadata":{"id":"YrayIxRPHjI3"}},{"cell_type":"code","source":["# You can add your code here\n"],"metadata":{"id":"_ZNzS2oEH4mu"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The neural network that we will use for semantic segmentation, PointNet, needs a fixed number of points as input. This is why the point clouds of all the seedlings were sampled to a size of 1024 points.\n","\n","**Exercise:**\n","* Run the code below to look at the point labels.\n","* You see that the format is different from what we observed earlier. What do you think is the relation between the two formats? And why would we use this format for training the neural network?"],"metadata":{"id":"0xGDKWrlIeN0"}},{"cell_type":"code","source":["train_val_labels[0]"],"metadata":{"id":"h0CAkOheDKnK"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["PointNet predicts $c$ values per point, where $c$ is the number of class labels. In our case, we have two classes, stem and leaf. This means that the network predicts two values for every point: the probability of class stem and the probability of class leaf. This is why, in the labeled data, we see $[1,0]$ for stem and $[0,1]$ for leaf. 
This data will be used to train the network to make the right predictions."],"metadata":{"id":"88eGkTlxKX0i"}},{"cell_type":"markdown","source":["During training, we want to get a sense of how well the model will perform on new data. For this reason, we split the data into a training set and a validation set using an 80-20% split:\n"],"metadata":{"id":"wi--kV0JKyk6"}},{"cell_type":"code","source":["VAL_SPLIT = 0.2\n","\n","split_index = int(len(train_val_point_clouds) * (1 - VAL_SPLIT))\n","train_point_clouds = train_val_point_clouds[:split_index]\n","train_labels = train_val_labels[:split_index]\n","\n","val_point_clouds = train_val_point_clouds[split_index:]\n","val_labels = train_val_labels[split_index:]\n","\n","print(\"Num train point clouds:\", len(train_point_clouds))\n","print(\"Num train point cloud labels:\", len(train_labels))\n","print(\"Num val point clouds:\", len(val_point_clouds))\n","print(\"Num val point cloud labels:\", len(val_labels))"],"metadata":{"id":"KOJ4tHp93GXu"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["To use the training and validation sets, we need to wrap them in a TensorFlow format. For the training set, data augmentation is applied to enrich the data and help prevent overfitting. In this case, data augmentation consists of jittering, which adds some random noise to each point in the point cloud."],"metadata":{"id":"KfLsuq-KK6fw"}},{"cell_type":"code","source":["train_dataset = generate_dataset(train_point_clouds, train_labels)\n","val_dataset = generate_dataset(val_point_clouds, val_labels, is_training=False)\n","\n","print(\"Train Dataset:\", train_dataset)\n","print(\"Validation Dataset:\", val_dataset)"],"metadata":{"id":"3S93lh9fK3BY"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 3: Setting up the PointNet Model"],"metadata":{"id":"duaqXI4emzcY"}},{"cell_type":"markdown","source":["We will use the PointNet neural network. 
The figure below shows the network architecture:"],"metadata":{"id":"VX7i8vSZYZgV"}},{"cell_type":"markdown","source":["\n","\n","Here is a brief explanation of the different parts of the network:\n","* Calculation of features per point \n","  * As input, $n$ points are provided, with 3 dimensions (the position) per point. \n","  * The input is transformed using a 3x3 affine transformation matrix, which is provided by the T-Net. This rotates, scales and shears the point cloud, with the aim of normalizing the data and thus simplifying the task\n","  * Next, two multi-layer perceptron (mlp) layers are applied in sequence, both with 64 output features per point. The mlp is applied to each point individually with the same (shared) weights for each point. Now we have a point cloud of n points with 64 features per point (nx64)\n","  * This point cloud is transformed again, now with a 64x64 affine transformation matrix resulting from the second T-Net, again normalizing the data\n","  * Again, mlp layers are applied to each point individually, this time 3 layers with output dimensions of 64, 128 and 1024, respectively, resulting in an nx1024 point cloud\n","* Calculation of a global feature vector\n","  * Max pooling is applied to get a global feature vector. For each of the 1024 features, the maximum value is selected over all points in the point cloud. This results in a 1024-dimensional feature vector. Note that, because the max operator is used, this vector does not depend on the order of the points (it is permutation invariant).\n","* Classification head\n","  * Classification is done using 3 consecutive mlp layers with 512, 256 and k outputs, where k is the number of object classes. Note that we do not use the classification head in this tutorial\n","* Segmentation head (in yellow)\n","  * The nx64 representation resulting after the second T-Net is concatenated with the 1024-dimensional global feature vector for every point. 
This means that every point now carries a local descriptor and a global descriptor, resulting in an nx1088 point cloud. \n","  * Five mlp layers are applied to every point individually, with 512, 256, 128, 128 and m outputs respectively, where m is the number of segmentation classes\n","  * Per point, the class with the highest predicted score is taken as the class label \n","\n"],"metadata":{"id":"x-6lICxTAqp7"}},{"cell_type":"markdown","source":["The code block below defines the PointNet model.\n","\n","**Exercise (bonus):**\n","* If you feel comfortable, inspect the code below to understand the network architecture. Note that `conv_block` implements the mlp blocks that were shown in the figure.\n","* Compare the architecture defined in the block below with the architecture shown in the figure. There are differences. Can you find them? "],"metadata":{"id":"4X7GtLpIZBCz"}},{"cell_type":"code","source":["def get_pointnet_model(num_points: int, num_classes: int) -> keras.Model:\n","\n","    input_points = keras.Input(shape=(None, NUM_FEATURES))\n","\n","    # PointNet Feature Extractor\n","    transformed_inputs = transformation_block(\n","        input_points, num_features=NUM_FEATURES, name=\"input_transformation_block\"\n","    )\n","    features_64 = conv_block(transformed_inputs, filters=64, name=\"features_64\")\n","    features_128_1 = conv_block(features_64, filters=128, name=\"features_128_1\")\n","    features_128_2 = conv_block(features_128_1, filters=128, name=\"features_128_2\")\n","    transformed_features = transformation_block(\n","        features_128_2, num_features=128, name=\"transformed_features\"\n","    )\n","    features_512 = conv_block(transformed_features, filters=512, name=\"features_512\")\n","    features_2048 = conv_block(features_512, filters=2048, name=\"pre_maxpool_block\")\n","    global_features = layers.MaxPool1D(pool_size=num_points, name=\"global_features\")(\n","        features_2048\n","    )\n","    global_features = tf.tile(global_features, [1, num_points, 1])\n","\n","    # Segmentation head.\n","    segmentation_input = layers.Concatenate(name=\"segmentation_input\")(\n","        [\n","            features_64,\n","            features_128_1,\n","            features_128_2,\n","            transformed_features,\n","            features_512,\n","            global_features,\n","        ]\n","    )\n","    segmentation_features = conv_block(\n","        segmentation_input, filters=128, name=\"segmentation_features\"\n","    )\n","    outputs = layers.Conv1D(\n","        num_classes, kernel_size=1, activation=\"softmax\", name=\"segmentation_head\"\n","    )(segmentation_features)\n","    return keras.Model(input_points, outputs)"],"metadata":{"id":"pAw3DX85m5DX"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We will now create the model:"],"metadata":{"id":"EZLYjIacazbJ"}},{"cell_type":"code","source":["segmentation_model = get_pointnet_model(NUM_SAMPLE_POINTS, len(LABELS))"],"metadata":{"id":"RRzvn_c-7i3z"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["With `segmentation_model.summary()`, we can get an overview of the different layers and connections in the model, including the number of parameters (weights) in the network that need to be trained.\n","\n","**Exercise:**\n","* How many trainable parameters are there in the network?\n","* Which layer contains the most parameters?"],"metadata":{"id":"2ZVhGGeNa2fD"}},{"cell_type":"code","source":["segmentation_model.summary()\n","\n","# You can also visualize the connections if you remove the # in the line below\n","#keras.utils.plot_model(segmentation_model, \"model.png\", show_shapes=True)"],"metadata":{"id":"A3OFlQcy8Yrg"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 4: Training the PointNet model\n","\n","Now that we have defined the PointNet model, it is time to train it."],"metadata":{"id":"xxXgmyz9O9gH"}},{"cell_type":"markdown","source":["## 4.1: Setting up the training process\n","\n"],"metadata":{"id":"YTsHyCKlO5U5"}},{"cell_type":"markdown","source":["The code below defines a function that we can use for training. 
It contains a few steps:\n","1. Get the PointNet model with randomly initialized network weights\n","2. Set a learning schedule using exponential decay\n","3. Set Adam as optimizer with cross-entropy as loss function and accuracy as metric\n","4. Start the training of the network by fitting it to the training set for a number of epochs"],"metadata":{"id":"AYAQYPgaxghn"}},{"cell_type":"code","source":["def run_experiment(epochs):\n","    # Step 1: Get the PointNet model\n","    segmentation_model = get_pointnet_model(NUM_SAMPLE_POINTS, len(LABELS))\n","\n","    # Step 2: Set up the learning schedule using exponential decay\n","    INITIAL_LR = 0.001\n","    lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n","        initial_learning_rate=INITIAL_LR,\n","        decay_steps=200000,\n","        decay_rate=0.7,\n","        staircase=False,\n","    )\n","\n","    # Step 3: Set Adam as optimizer, using cross-entropy as loss function and accuracy as metric\n","    segmentation_model.compile(\n","        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),\n","        loss=keras.losses.CategoricalCrossentropy(),\n","        metrics=[\"accuracy\"],\n","    )\n","\n","    # Step 4: Start the training of the network by fitting it to the training set for a number of epochs\n","    history = segmentation_model.fit(\n","        train_dataset,\n","        validation_data=val_dataset,\n","        epochs=epochs\n","    )\n","\n","    return segmentation_model, history"],"metadata":{"id":"76Pb1ETSnq12"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 4.2: Execution of the Training Process\n","\n","We will now execute the training process. We will do this for 5 epochs, which means that all plants in the training set will be used 5 times to train the network. Training the network for 50 epochs will yield better performance, but this takes too long for the tutorial, as it takes roughly 1 minute per epoch on Google Colab. Even with 5 epochs, we can already see reasonable results. \n","\n","NB. 
Later, you will get the weights of the neural network trained for 50 epochs, so that you can see the results of that better optimized network."],"metadata":{"id":"xCW9ZEM6ptC_"}},{"cell_type":"code","source":["# Train the network for 5 epochs\n","EPOCHS = 5\n","segmentation_model, history = run_experiment(epochs=EPOCHS)\n","\n","\n","#tf.saved_model.save(segmentation_model, checkpoint_filepath)\n","\n","#with open(os.path.join(log_dir, 'segmentation_history.pickle'), 'wb') as file_pi:\n","# pickle.dump(history.history, file_pi)"],"metadata":{"id":"Ts29rZHlpsKg"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Exercise:**\n","* Let's have a look at the accuracy and loss during training in the code block below.\n","* What do you see?\n","* Are there signs of overfitting?\n","* Might it help to train for more epochs?\n"],"metadata":{"id":"I6ytbt3mk_F_"}},{"cell_type":"code","source":["plt.figure(figsize=(10,4))\n","plt.subplot(1,2,1)\n","plt.plot(history.history['accuracy'])\n","plt.plot(history.history['val_accuracy'])\n","plt.title('model accuracy')\n","plt.ylabel('accuracy')\n","plt.xlabel('epoch')\n","plt.legend(['training', 'validation'], loc='upper left')\n","\n","plt.subplot(1,2,2)\n","plt.plot(history.history['loss'])\n","plt.plot(history.history['val_loss'])\n","plt.title('model loss')\n","plt.ylabel('loss')\n","plt.xlabel('epoch')\n","plt.legend(['training', 'validation'], loc='upper left')\n","plt.show()"],"metadata":{"id":"UTGpxny0kE1p"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# 5: Test the network"],"metadata":{"id":"cJ3NEOYJTwW0"}},{"cell_type":"markdown","source":["## 5.1: Use the network to predict the segmentation of stem and leaf\n","\n","For a proper test of the performance of the network, we will use the test set, which was not used during the training of the network, nor for optimizing any hyperparameters.\n","\n","Getting test 
data:"],"metadata":{"id":"-XJ2ulGNT04S"}},{"cell_type":"code","source":["test_data_path = \"point_cloud/seedling_test_set.h5\"\n","\n","test_point_clouds, test_labels = get_dataset(test_data_path)\n","print(\"The size of the test set is:\", len(test_point_clouds), 'plants')\n","\n","# Create the test dataset\n","test_dataset = generate_dataset(test_point_clouds, test_labels, is_training=False)"],"metadata":{"id":"tmDAK9a-T3Gh"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the model to segment the first batch of 16 point clouds in the test set:"],"metadata":{"id":"G8uYtp4nUj29"}},{"cell_type":"code","source":["# Get the first test batch of 16 plants\n","test_batch = next(iter(test_dataset))\n","\n","# Process the point clouds (whole batch) with the trained segmentation model to get the predicted output \n","test_predictions = segmentation_model(test_batch[0])\n","\n","# Check the size of the input and output\n","print(f\"Shape of the test input: {test_batch[0].shape}\")\n","print(f\"Shape of the test output (ground truth): {test_batch[1].shape}\")\n","print(f\"Shape of the prediction made by the network: {test_predictions.shape}\")"],"metadata":{"id":"0PqavqyOUqfU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Exercise:**\n","* We can have a look at the predictions made for the first plant in the batch in the code block below\n","* How many predictions are made per point? How can you interpret these values? \n","* Based on this prediction, how can we get the predicted class label (0: stem, 1: leaf)?"],"metadata":{"id":"sifS-jvWWIdH"}},{"cell_type":"code","source":["test_predictions[0]"],"metadata":{"id":"fT695pZoWOyK"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["You see that for every point (every row), two values are provided. These are the predicted probabilities that the point is stem (first column) or leaf (second column). 
By taking the `argmax`, we can find the class label of the most probable class. `argmax(...,axis=1)` finds the highest value per row and returns the index of the column containing this highest value."],"metadata":{"id":"dBqXc40eWYjg"}},{"cell_type":"code","source":["np.argmax(test_predictions[0], axis=1)"],"metadata":{"id":"3nqOjsb6W3jR"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We can also visualize the predicted segmentation of the point cloud.\n","\n","**Exercise:**\n","* Run the code below\n","* Change the `plant_id` within the range [0-15]. You might see some point clouds that are well segmented and others that are poorly segmented. Note that we trained the network for only 5 epochs."],"metadata":{"id":"RIDsgSv9Xhoi"}},{"cell_type":"code","source":["plant_id = 0\n","\n","test_prediction = np.argmax(test_predictions[plant_id], axis=1)\n","show_labeled_seedling_pc(test_point_clouds[plant_id], test_prediction)"],"metadata":{"id":"qkodRsJkXjqp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We can quantify the performance of the network by calculating the accuracy over the full point cloud:"],"metadata":{"id":"PHIuXsSeYNfX"}},{"cell_type":"code","source":["def calc_accuracy(test_predictions, test_truth, show=True):\n","    # Calculate the accuracy of the prediction for every point cloud in the batch\n","    n_clouds = len(test_predictions)\n","    accuracies = np.zeros(n_clouds)\n","    for idx in range(n_clouds):\n","        # Get the predicted class labels per point\n","        pc_pred_labels = np.argmax(test_predictions[idx], axis=1)\n","        # Get the ground-truth class labels\n","        pc_gt_labels = np.argmax(test_truth[idx], axis=1)\n","\n","        # Calculate the accuracy for this point cloud\n","        accuracy = np.mean(pc_pred_labels == pc_gt_labels)\n","        accuracies[idx] = accuracy\n","\n","        if show:\n","            print(' Accuracy for plant %d: %1.2f' % (idx, accuracy))\n","\n","    if show:\n","        print('Overall accuracy (n=%d): %1.2f (%1.2f)' % (len(accuracies), accuracies.mean(), accuracies.std()))\n","    return accuracies\n","\n","# Calculate the accuracy by comparing test_predictions to the ground truth (in test_batch[1])\n","accuracy = calc_accuracy(test_predictions, test_batch[1])"],"metadata":{"id":"mMkP98AmYVdZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 5.2: Loading weights of the network trained for 50 epochs \n"],"metadata":{"id":"zU-Ay16vYvS6"}},{"cell_type":"markdown","source":["You see that 5 epochs is not enough to consistently predict a good segmentation. \n","Prior to the tutorial, the network was trained for 50 epochs and the network weights were stored. Let's load them:"],"metadata":{"id":"t0X7OTYzcx-o"}},{"cell_type":"code","source":["log_dir = \"point_cloud\"\n","checkpoint_50_epochs_filepath = os.path.join(log_dir, \"checkpoints_50epochs\")\n","\n","# Load the pre-trained weights\n","segmentation_model_50epochs = tf.saved_model.load(checkpoint_50_epochs_filepath)"],"metadata":{"id":"WIUlzH06aVdB"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["We can now use this network to segment the point clouds. Let's look at the results.\n","\n","**Exercise:**\n","* Run the code below to calculate the accuracy for the plants in the first test batch.\n","* Compare the results to those obtained after training for 5 epochs."],"metadata":{"id":"g3ntHT-mc6T1"}},{"cell_type":"code","source":["# Process the point clouds (whole batch) with the trained segmentation model to get the predicted output \n","test_predictions = segmentation_model_50epochs(test_batch[0])\n","\n","accuracy = calc_accuracy(test_predictions, test_batch[1])"],"metadata":{"id":"HZXHEtLba-qZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["And look at the resulting segmented point clouds:"],"metadata":{"id":"mTS2j5nrdAGS"}},{"cell_type":"code","source":["plant_id = 0\n","\n","test_prediction = np.argmax(test_predictions[plant_id], axis=1)\n","show_labeled_seedling_pc(test_point_clouds[plant_id], test_prediction)"],"metadata":{"id":"XecRgA5OdDl0"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 5.3: Test accuracy for the whole test set\n","\n","In the code above, we looked at the performance of the network for one batch from the test set, containing 16 plants. Let's look at the accuracy for all plants in the test set:"],"metadata":{"id":"szyejNAIojCW"}},{"cell_type":"code","source":["accuracies = []\n","for batch_i, test_batch in enumerate(test_dataset):\n","    # Process the point clouds (whole batch) with the trained segmentation model to get the predicted output \n","    test_predictions = segmentation_model_50epochs(test_batch[0])\n","    batch_acc = calc_accuracy(test_predictions, test_batch[1], show=False)\n","    accuracies = np.hstack((accuracies, batch_acc))\n","    print(' Mean accuracy for batch %d: %1.3f' % (batch_i, batch_acc.mean()))\n","\n","print('Overall accuracy (n=%d): %1.3f (%1.3f)' % (len(accuracies), accuracies.mean(), accuracies.std()))"],"metadata":{"id":"p3reEDQQs5Wk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# The end\n","\n","This ends the second part of the tutorial on 3D point cloud processing.\n","\n","For further reading on the use of deep neural networks for the segmentation of 3D point clouds of plants, we refer 
to:\n","\n","* Boogaard, F. P., van Henten, E., & Kootstra, G. (2022). Improved point-cloud segmentation for plant phenotyping through class-dependent sampling of training data to battle class imbalance. Frontiers in Plant Science, 13, 838190. https://doi.org/10.3389/fpls.2022.838190\n","* Boogaard, F. P., van Henten, E., & Kootstra, G. (2021). Boosting plant-part segmentation of cucumber plants by enriching incomplete 3D point clouds with spectral data. Biosystems Engineering, 211, 167–182. https://doi.org/10.1016/j.biosystemseng.2021.09.004\n","* Turgut, K., Dutagaci, H., Galopin, G., & Rousseau, D. (2022). Segmentation of structural parts of rosebush plants with 3D point-based deep learning methods. Plant Methods, 18, 20. https://doi.org/10.1186/s13007-022-00857-3\n"],"metadata":{"id":"lS41XCHZvjk5"}}]}