Change the following code:
```python
def get_batch(self, batch_size):
    # Randomly draw batch_size elements from the dataset and return them
    index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
    return self.train_data[index, :], self.train_label[index]
```
to:
```python
def get_batch(self, batch_size):
    # Randomly draw batch_size elements from the dataset and return them
    # index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
    index = np.random.choice(np.shape(self.train_data)[0], batch_size, replace=False)
    return self.train_data[index, :], self.train_label[index]
```
This prevents the same sample from appearing more than once within a single batch.
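A minimal, self-contained sketch of the difference, assuming a toy dataset of 10 samples with 3 features each. The `Loader` class, its `get_batch_with_replacement` method, and the toy arrays are hypothetical stand-ins for the class in the issue; only the two index expressions come from the proposal. `np.random.randint` draws indices with replacement, so a batch may repeat rows, while `np.random.choice(..., replace=False)` draws distinct indices.

```python
import numpy as np

# Hypothetical toy dataset: 10 samples x 3 features, label i belongs to row i.
train_data = np.arange(30).reshape(10, 3)
train_label = np.arange(10)


class Loader:
    """Hypothetical stand-in for the class that owns get_batch in the issue."""

    def __init__(self, train_data, train_label):
        self.train_data = train_data
        self.train_label = train_label

    def get_batch_with_replacement(self, batch_size):
        # Original approach: indices are drawn with replacement and may repeat.
        index = np.random.randint(0, np.shape(self.train_data)[0], batch_size)
        return self.train_data[index, :], self.train_label[index]

    def get_batch(self, batch_size):
        # Proposed approach: draw batch_size distinct row indices.
        index = np.random.choice(np.shape(self.train_data)[0], batch_size, replace=False)
        return self.train_data[index, :], self.train_label[index]


loader = Loader(train_data, train_label)
np.random.seed(0)

# Count repeated indices in a batch drawn with replacement (can be non-zero).
_, labels = loader.get_batch_with_replacement(10)
print(len(labels) - len(np.unique(labels)), "repeated indices")

# Without replacement, a full-size batch is a permutation of all samples.
_, labels = loader.get_batch(10)
print(sorted(labels.tolist()))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# replace=False cannot draw more samples than the dataset holds.
try:
    loader.get_batch(11)
except ValueError as err:
    print("ValueError:", err)
```

One design consequence: with `replace=False`, `batch_size` must not exceed the number of training samples, otherwise NumPy raises `ValueError`. The change also only removes repeats within one batch; the same sample can still show up in consecutive batches.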
Thank you for the suggestion. Could you please submit a Pull Request to fix that?
Much appreciated!