[PyTorch]中的数据类型

基本概念

torch.float64对应torch.DoubleTensor
torch.float32对应torch.FloatTensor

  • 在判断数据类型的时候只能使用torch.FloatTensor不能使用torch.float32,而且tensorA.type()返回的是数字串类型的
a=[[1.0,2.0],[3.9,4]]
b=[2.0,4]
a=torch.tensor(a)
b=torch.tensor(b).to(torch.float32)
if b.type()=="torch.FloatTensor":
    print("ha")
>>>ha
  • torch.tensor([1.2,3],dtype=troch.float)中的dtype不能用dtype=torch.FloatTensor因为设置的是数据类型而不是tensor类型。

  • .to(float32)就等于.to(float)torch.int也有但是等于torch.IntTensor不等于torch.LongTensor

进行数据转换的几种方式

  1. 使用函数tensor1.type_as(tensor2)将1的数据类型转换成2的数据类型。
tensor_1 = torch.FloatTensor(5)
tensor_2 = torch.IntTensor([10, 20])
tensor_1 = tensor_1.type_as(tensor_2)
  1. tensor.type(torch.IntTensor)
  2. tensor.long(),tensor.char(),tensor.int(),tensor.byte(),tensor.double()
  3. tenosr.to(torch.long)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容